Skip to content
Snippets Groups Projects
Commit ccb5f9dd authored by Jan Caron's avatar Jan Caron
Browse files

Overhaul of plotting capabilities. Plots now look nicer. Now all in nm.

fielddata: Added plot_quiver_field convenience function.
io_phasemap: Added functionality for images separately from HyperSpy.
phasemap: Automatically normalise confidence.
projector: Quaternion for tilt around x-axis corrected (was falsely around y).
reconstruction: Added verbose option (True by default).
reconstruction_2d_from_phasemap: Plotting optimised.
parent 4f1991fd
No related branches found
No related tags found
No related merge requests found
...@@ -12,15 +12,13 @@ from numbers import Number ...@@ -12,15 +12,13 @@ from numbers import Number
import numpy as np import numpy as np
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator from matplotlib.ticker import MaxNLocator, FuncFormatter
from matplotlib.colors import ListedColormap from matplotlib.colors import ListedColormap
from PIL import Image from PIL import Image
from scipy.ndimage.interpolation import zoom from scipy.ndimage.interpolation import zoom
from jutil import fft
from . import colors from . import colors
__all__ = ['VectorData', 'ScalarData'] __all__ = ['VectorData', 'ScalarData']
...@@ -727,7 +725,7 @@ class VectorData(FieldData): ...@@ -727,7 +725,7 @@ class VectorData(FieldData):
from .file_io.io_vectordata import save_vectordata from .file_io.io_vectordata import save_vectordata
save_vectordata(self, filename, **kwargs) save_vectordata(self, filename, **kwargs)
def plot_field(self, title='Vector Field', axis=None, proj_axis='z', figsize=(8.5, 7), def plot_field(self, title='Vector Field', axis=None, proj_axis='z', figsize=(9, 8),
ax_slice=None, show_mask=True, bgcolor='white', hue_mode='triadic'): ax_slice=None, show_mask=True, bgcolor='white', hue_mode='triadic'):
"""Plot a slice of the vector field as a quiver plot. """Plot a slice of the vector field as a quiver plot.
...@@ -765,16 +763,16 @@ class VectorData(FieldData): ...@@ -765,16 +763,16 @@ class VectorData(FieldData):
ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2
u_mag, v_mag = self.get_slice(ax_slice, proj_axis) u_mag, v_mag = self.get_slice(ax_slice, proj_axis)
if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice
u_label = 'x [px]' u_label = 'x-axis [nm]'
v_label = 'y [px]' v_label = 'y-axis [nm]'
submask = self.get_mask()[ax_slice, ...] submask = self.get_mask()[ax_slice, ...]
elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice
u_label = 'x [px]' u_label = 'x-axis [nm]'
v_label = 'z [px]' v_label = 'z-axis [nm]'
submask = self.get_mask()[:, ax_slice, :] submask = self.get_mask()[:, ax_slice, :]
elif proj_axis == 'x': # Slice of the yz-plane with x = ax_slice elif proj_axis == 'x': # Slice of the yz-plane with x = ax_slice
u_label = 'z [px]' u_label = 'z-axis [nm]'
v_label = 'y [px]' v_label = 'y-axis [nm]'
submask = self.get_mask()[..., ax_slice] submask = self.get_mask()[..., ax_slice]
else: else:
raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis))
...@@ -817,10 +815,12 @@ class VectorData(FieldData): ...@@ -817,10 +815,12 @@ class VectorData(FieldData):
u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1]))) u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1])))
axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True))
axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True))
axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
# Return plotting axis: # Return plotting axis:
return axis return axis
def plot_streamline(self, title='Vector Field', axis=None, proj_axis='z', figsize=(8.5, 7), def plot_streamline(self, title='Vector Field', axis=None, proj_axis='z', figsize=(9, 8),
coloring='angle', ax_slice=None, density=2, linewidth=2, coloring='angle', ax_slice=None, density=2, linewidth=2,
show_mask=True, bgcolor='white', hue_mode='triadic'): show_mask=True, bgcolor='white', hue_mode='triadic'):
"""Plot a slice of the vector field as a quiver plot. """Plot a slice of the vector field as a quiver plot.
...@@ -869,16 +869,16 @@ class VectorData(FieldData): ...@@ -869,16 +869,16 @@ class VectorData(FieldData):
ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2
u_mag, v_mag = self.get_slice(ax_slice, proj_axis) u_mag, v_mag = self.get_slice(ax_slice, proj_axis)
if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice
u_label = 'x [px]' u_label = 'x-axis [nm]'
v_label = 'y [px]' v_label = 'y-axis [nm]'
submask = self.get_mask()[ax_slice, ...] submask = self.get_mask()[ax_slice, ...]
elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice
u_label = 'x [px]' u_label = 'x-axis [nm]'
v_label = 'z [px]' v_label = 'z-axis [nm]'
submask = self.get_mask()[:, ax_slice, :] submask = self.get_mask()[:, ax_slice, :]
elif proj_axis == 'x': # Slice of the yz-plane with x = ax_slice elif proj_axis == 'x': # Slice of the yz-plane with x = ax_slice
u_label = 'z [px]' u_label = 'z-axis [nm]'
v_label = 'y [px]' v_label = 'y-axis [nm]'
submask = self.get_mask()[..., ax_slice] submask = self.get_mask()[..., ax_slice]
else: else:
raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis))
...@@ -948,10 +948,12 @@ class VectorData(FieldData): ...@@ -948,10 +948,12 @@ class VectorData(FieldData):
u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1]))) u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1])))
axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True))
axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True))
axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
# Return plotting axis: # Return plotting axis:
return axis return axis
def plot_quiver(self, title='Vector Field', axis=None, proj_axis='z', figsize=(8.5, 7), def plot_quiver(self, title='Vector Field', axis=None, proj_axis='z', figsize=(9, 8),
coloring='angle', ar_dens=1, ax_slice=None, log=False, scaled=True, coloring='angle', ar_dens=1, ax_slice=None, log=False, scaled=True,
scale=1., show_mask=True, bgcolor='white', hue_mode='triadic'): scale=1., show_mask=True, bgcolor='white', hue_mode='triadic'):
"""Plot a slice of the vector field as a quiver plot. """Plot a slice of the vector field as a quiver plot.
...@@ -1004,16 +1006,16 @@ class VectorData(FieldData): ...@@ -1004,16 +1006,16 @@ class VectorData(FieldData):
ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2 ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2
u_mag, v_mag = self.get_slice(ax_slice, proj_axis) u_mag, v_mag = self.get_slice(ax_slice, proj_axis)
if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice if proj_axis == 'z': # Slice of the xy-plane with z = ax_slice
u_label = 'x [px]' u_label = 'x-axis [nm]'
v_label = 'y [px]' v_label = 'y-axis [nm]'
submask = self.get_mask()[ax_slice, ...] submask = self.get_mask()[ax_slice, ...]
elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice elif proj_axis == 'y': # Slice of the xz-plane with y = ax_slice
u_label = 'x [px]' u_label = 'x-axis [nm]'
v_label = 'z [px]' v_label = 'z-axis [nm]'
submask = self.get_mask()[:, ax_slice, :] submask = self.get_mask()[:, ax_slice, :]
elif proj_axis == 'x': # Slice of the yz-plane with x = ax_slice elif proj_axis == 'x': # Slice of the yz-plane with x = ax_slice
u_label = 'z [px]' u_label = 'z-axis [nm]'
v_label = 'y [px]' v_label = 'y-axis [nm]'
submask = self.get_mask()[..., ax_slice] submask = self.get_mask()[..., ax_slice]
else: else:
raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis)) raise ValueError('{} is not a valid argument (use x, y or z)'.format(proj_axis))
...@@ -1102,9 +1104,53 @@ class VectorData(FieldData): ...@@ -1102,9 +1104,53 @@ class VectorData(FieldData):
u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1]))) u_bin, v_bin = 9, np.max((2, np.floor(9 * dim_uv[0] / dim_uv[1])))
axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True))
axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True))
axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
# Return plotting axis: # Return plotting axis:
return axis return axis
def plot_quiver_field(self, title='Vector Field', axis=None, proj_axis='z',
figsize=(9, 8), ar_dens=1, ax_slice=None, show_mask=True,
bgcolor='white', hue_mode='triadic'):
"""Plot the vector field as a field plot with uniformly colored arrows overlayed.
Parameters
----------
title : string, optional
The title for the plot.
axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis on which the graph is plotted. Creates a new figure if none is specified.
proj_axis : {'z', 'y', 'x'}, optional
The axis, from which a slice is plotted. The default is 'z'.
figsize : tuple of floats (N=2)
Size of the plot figure.
ar_dens: int, optional
Number defining the arrow density which is plotted. A higher ar_dens number skips more
arrows (a number of 2 plots every second arrow). Default is 1.
ax_slice : int, optional
The slice-index of the axis specified in `proj_axis`. Is set to the center of
`proj_axis` if not specified.
show_mask: boolean
Default is True. Shows the outlines of the mask slice if available.
bgcolor: {'black', 'white'}, optional
Determines the background color of the plot.
hue_mode : {'triadic', 'tetradic'}
Optional string for determining the hue scheme. Use either a triadic or tetradic
scheme (see the according colormaps for more information).
Returns
-------
axis: :class:`~matplotlib.axes.AxesSubplot`
The axis on which the graph is plotted.
"""
axis = self.plot_field(title=title, axis=axis, proj_axis=proj_axis, figsize=figsize,
ax_slice=ax_slice, show_mask=show_mask, bgcolor=bgcolor,
hue_mode=hue_mode)
self.plot_quiver(axis=axis, proj_axis=proj_axis, figsize=figsize, coloring='uniform',
ar_dens=ar_dens, ax_slice=ax_slice, show_mask=show_mask, bgcolor=bgcolor)
return axis
def plot_quiver3d(self, title='Vector Field', limit=None, cmap='jet', mode='2darrow', def plot_quiver3d(self, title='Vector Field', limit=None, cmap='jet', mode='2darrow',
coloring='full', ar_dens=1, opacity=1.0, hue_mode='triadic'): coloring='full', ar_dens=1, opacity=1.0, hue_mode='triadic'):
"""Plot the vector field as 3D-vectors in a quiverplot. """Plot the vector field as 3D-vectors in a quiverplot.
......
...@@ -10,6 +10,8 @@ import os ...@@ -10,6 +10,8 @@ import os
import numpy as np import numpy as np
from PIL import Image
from ..phasemap import PhaseMap from ..phasemap import PhaseMap
__all__ = ['load_phasemap'] __all__ = ['load_phasemap']
...@@ -73,6 +75,8 @@ def _load(filename, as_phasemap=False, a=1., **kwargs): ...@@ -73,6 +75,8 @@ def _load(filename, as_phasemap=False, a=1., **kwargs):
# Load from npy-files: # Load from npy-files:
elif extension in ['.npy', '.npz']: elif extension in ['.npy', '.npz']:
return _load_from_npy(filename, as_phasemap, a, **kwargs) return _load_from_npy(filename, as_phasemap, a, **kwargs)
elif extension in ['.jpeg', '.jpg', '.png', '.bmp']:
return _load_from_img(filename, as_phasemap, a, **kwargs)
# Load with HyperSpy: # Load with HyperSpy:
else: else:
if extension == '': if extension == '':
...@@ -124,6 +128,17 @@ def _load_from_npy(filename, as_phasemap, a, **kwargs): ...@@ -124,6 +128,17 @@ def _load_from_npy(filename, as_phasemap, a, **kwargs):
return result return result
def _load_from_img(filename, as_phasemap, a, **kwargs):
result = np.asarray(Image.open(filename, **kwargs).convert('L'))
if as_phasemap:
if a is None:
a = 1. # Use default!
return PhaseMap(a, result)
else:
return result
def _load_from_hs(filename, as_phasemap, a, **kwargs): def _load_from_hs(filename, as_phasemap, a, **kwargs):
try: try:
import hyperspy.api as hs import hyperspy.api as hs
......
...@@ -7,12 +7,18 @@ ...@@ -7,12 +7,18 @@
import logging import logging
from numbers import Number from numbers import Number
import matplotlib.pyplot as plt
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MaxNLocator, FuncFormatter from matplotlib.ticker import MaxNLocator, FuncFormatter
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.ndimage.interpolation import zoom from scipy.ndimage.interpolation import zoom
from . import colors from . import colors
...@@ -51,8 +57,6 @@ class PhaseMap(object): ...@@ -51,8 +57,6 @@ class PhaseMap(object):
""" """
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
UNITDICT = {u'rad': 1E0, UNITDICT = {u'rad': 1E0,
...@@ -160,9 +164,11 @@ class PhaseMap(object): ...@@ -160,9 +164,11 @@ class PhaseMap(object):
if confidence is not None: if confidence is not None:
assert confidence.shape == self.phase.shape, \ assert confidence.shape == self.phase.shape, \
'Confidence and phase dimensions must match!' 'Confidence and phase dimensions must match!'
confidence = confidence.astype(dtype=np.float32)
confidence /= confidence.max() # Normalise!
else: else:
confidence = np.ones_like(self.phase) confidence = np.ones_like(self.phase, dtype=np.float32)
self._confidence = confidence.astype(dtype=np.float32) self._confidence = confidence
def __init__(self, a, phase, mask=None, confidence=None): def __init__(self, a, phase, mask=None, confidence=None):
self._log.debug('Calling __init__') self._log.debug('Calling __init__')
...@@ -525,9 +531,9 @@ class PhaseMap(object): ...@@ -525,9 +531,9 @@ class PhaseMap(object):
from .file_io.io_phasemap import save_phasemap from .file_io.io_phasemap import save_phasemap
save_phasemap(self, filename, save_mask, save_conf, pyramid_format, **kwargs) save_phasemap(self, filename, save_mask, save_conf, pyramid_format, **kwargs)
def plot_phase(self, title='Phase Map', cbar_title=None, unit='rad', cmap='RdBu', limit=None, def plot_phase(self, title='Phase Map', cbar_title=None, unit='rad', cmap='RdBu', vmin=None,
norm=None, axis=None, cbar=True, show_mask=True, show_conf=True, vmax=None, symmetric=True, norm=None, axis=None, cbar=True, figsize=(9, 8),
sigma_clip=None, interpolation='none'): show_mask=True, show_conf=True, sigma_clip=None, interpolation='none'):
"""Display the phasemap as a colormesh. """Display the phasemap as a colormesh.
Parameters Parameters
...@@ -542,9 +548,15 @@ class PhaseMap(object): ...@@ -542,9 +548,15 @@ class PhaseMap(object):
cmap : string, optional cmap : string, optional
The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string.
The default is 'RdBu'. The default is 'RdBu'.
limit : float, optional vmin : float, optional
Plotlimit for the phase in both negative and positive direction (symmetric around 0). Minimum value used for determining the plot limits. If not set, it will be
If not specified, the maximum amplitude of the phase is used. determined by the minimum of the phase directly.
vmax : float, optional
Maximum value used for determining the plot limits. If not set, it will be
determined by the minimum of the phase directly.
symmetric : boolean, optional
If True (default), a zero symmetric colormap is assumed and a zero value (which
will always be present) will be set to the central color color of the colormap.
norm : :class:`~matplotlib.colors.Normalize` or subclass, optional norm : :class:`~matplotlib.colors.Normalize` or subclass, optional
Norm, which is used to determine the colors to encode the phase information. Norm, which is used to determine the colors to encode the phase information.
If not specified, :class:`~matplotlib.colors.Normalize` is automatically used. If not specified, :class:`~matplotlib.colors.Normalize` is automatically used.
...@@ -552,6 +564,8 @@ class PhaseMap(object): ...@@ -552,6 +564,8 @@ class PhaseMap(object):
Axis on which the graph is plotted. Creates a new figure if none is specified. Axis on which the graph is plotted. Creates a new figure if none is specified.
cbar : bool, optional cbar : bool, optional
A switch determining if the colorbar should be plotted or not. Default is True. A switch determining if the colorbar should be plotted or not. Default is True.
figsize : tuple of floats (N=2)
Size of the plot figure.
show_mask : bool, optional show_mask : bool, optional
A switch determining if the mask should be plotted or not. Default is True. A switch determining if the mask should be plotted or not. Default is True.
show_conf : float, optional show_conf : float, optional
...@@ -571,8 +585,8 @@ class PhaseMap(object): ...@@ -571,8 +585,8 @@ class PhaseMap(object):
self._log.debug('Calling plot_phase') self._log.debug('Calling plot_phase')
# Take units into consideration: # Take units into consideration:
phase = self.phase * self.UNITDICT[unit] phase = self.phase * self.UNITDICT[unit]
# Calculate limit if necessary: # Calculate limits if necessary (not necessary if both limits are already set):
if limit is None: if vmin is None and vmax is None:
phase_l = phase phase_l = phase
# Clip non-trustworthy regions for the limit calculation: # Clip non-trustworthy regions for the limit calculation:
if show_conf: if show_conf:
...@@ -585,15 +599,28 @@ class PhaseMap(object): ...@@ -585,15 +599,28 @@ class PhaseMap(object):
phase_sigma = np.where(outlier, phase_l, np.nan) phase_sigma = np.where(outlier, phase_l, np.nan)
phase_min, phase_max = np.nanmin(phase_sigma), np.nanmax(phase_sigma) phase_min, phase_max = np.nanmin(phase_sigma), np.nanmax(phase_sigma)
phase_l = np.clip(phase_l, phase_min, phase_max) phase_l = np.clip(phase_l, phase_min, phase_max)
# Calculate the limit: # Calculate the limits if necessary (zero has to be present!):
limit = np.max(np.abs(phase_l)) if vmin is None:
vmin = np.min(phase_l)
if vmax is None:
vmax = np.max(phase_l)
# Configure colormap, to fix white to zero if colormap is symmetric:
if symmetric:
if isinstance(cmap, str): # Get colormap if given as string:
cmap = plt.get_cmap(cmap)
vmin, vmax = np.min([vmin, 0]), np.max([0, vmax]) # Make sure zero is present!
limit = np.max(np.abs([vmin, vmax]))
start = (vmin + limit) / (2 * limit)
end = (vmax + limit) / (2 * limit)
cmap_colors = cmap(np.linspace(start, end, 256))
cmap = LinearSegmentedColormap.from_list('Symmetric', cmap_colors)
# If no axis is specified, a new figure is created: # If no axis is specified, a new figure is created:
if axis is None: if axis is None:
fig = plt.figure(figsize=(7, 7)) fig = plt.figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1) axis = fig.add_subplot(1, 1, 1)
axis.set_aspect('equal') axis.set_aspect('equal')
# Plot the phasemap: # Plot the phasemap:
im = axis.imshow(phase, cmap=cmap, vmin=-limit, vmax=limit, interpolation=interpolation, im = axis.imshow(phase, cmap=cmap, vmin=vmin, vmax=vmax, interpolation=interpolation,
norm=norm, origin='lower', extent=(0, self.dim_uv[1], 0, self.dim_uv[0])) norm=norm, origin='lower', extent=(0, self.dim_uv[1], 0, self.dim_uv[0]))
if show_mask or show_conf: if show_mask or show_conf:
vv, uu = np.indices(self.dim_uv) + 0.5 vv, uu = np.indices(self.dim_uv) + 0.5
...@@ -611,8 +638,8 @@ class PhaseMap(object): ...@@ -611,8 +638,8 @@ class PhaseMap(object):
u_bin, v_bin = 9, np.max((2, np.floor(9 * self.dim_uv[0] / self.dim_uv[1]))) u_bin, v_bin = 9, np.max((2, np.floor(9 * self.dim_uv[0] / self.dim_uv[1])))
axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True))
axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True))
axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:g}'.format(x * self.a))) axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:g}'.format(x * self.a))) axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
axis.tick_params(axis='both', which='major', labelsize=14) axis.tick_params(axis='both', which='major', labelsize=14)
axis.set_title(title, fontsize=18) axis.set_title(title, fontsize=18)
axis.set_xlim(0, self.dim_uv[1]) axis.set_xlim(0, self.dim_uv[1])
...@@ -621,10 +648,9 @@ class PhaseMap(object): ...@@ -621,10 +648,9 @@ class PhaseMap(object):
axis.set_ylabel('v-axis [nm]', fontsize=15) axis.set_ylabel('v-axis [nm]', fontsize=15)
# # Add colorbar: # # Add colorbar:
if cbar: if cbar:
fig = plt.gcf() divider = make_axes_locatable(axis)
fig.subplots_adjust(right=0.8) cbar_ax = divider.append_axes('right', size='5%', pad=0.1)
cbar_ax = fig.add_axes([0.82, 0.15, 0.02, 0.7]) cbar = plt.colorbar(im, cax=cbar_ax)
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.tick_params(labelsize=14) cbar.ax.tick_params(labelsize=14)
if cbar_title is None: if cbar_title is None:
cbar_title = u'phase shift [{}]'.format(unit) cbar_title = u'phase shift [{}]'.format(unit)
...@@ -667,11 +693,17 @@ class PhaseMap(object): ...@@ -667,11 +693,17 @@ class PhaseMap(object):
axis.set_xlabel('u-axis [px]') axis.set_xlabel('u-axis [px]')
axis.set_ylabel('v-axis [px]') axis.set_ylabel('v-axis [px]')
axis.set_zlabel('phase shift [{}]'.format(unit)) axis.set_zlabel('phase shift [{}]'.format(unit))
if self.dim_uv[0] >= self.dim_uv[1]:
u_bin, v_bin = np.max((2, np.floor(9 * self.dim_uv[1] / self.dim_uv[0]))), 9
else:
u_bin, v_bin = 9, np.max((2, np.floor(9 * self.dim_uv[0] / self.dim_uv[1])))
axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True))
axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True))
# Return plotting axis: # Return plotting axis:
return axis return axis
def plot_holo(self, title=None, gain='auto', axis=None, hue_mode='triadic', def plot_holo(self, title=None, gain='auto', axis=None, hue_mode='triadic',
interpolation='none'): interpolation='none', figsize=(8, 8)):
"""Display the color coded holography image. """Display the color coded holography image.
Parameters Parameters
...@@ -683,11 +715,13 @@ class PhaseMap(object): ...@@ -683,11 +715,13 @@ class PhaseMap(object):
which means that the gain will be determined automatically to look pretty. which means that the gain will be determined automatically to look pretty.
axis : :class:`~matplotlib.axes.AxesSubplot`, optional axis : :class:`~matplotlib.axes.AxesSubplot`, optional
Axis on which the graph is plotted. Creates a new figure if none is specified. Axis on which the graph is plotted. Creates a new figure if none is specified.
interpolation : {'none, 'bilinear', 'cubic', 'nearest'}, optional
Defines the interpolation method. No interpolation is used in the default case.
hue_mode : {'triadic', 'tetradic'} hue_mode : {'triadic', 'tetradic'}
Optional string for determining the hue scheme. Use either a triadic or tetradic Optional string for determining the hue scheme. Use either a triadic or tetradic
scheme (see the according colormaps for more information). scheme (see the according colormaps for more information).
interpolation : {'none, 'bilinear', 'cubic', 'nearest'}, optional
Defines the interpolation method. No interpolation is used in the default case.
figsize : tuple of floats (N=2)
Size of the plot figure.
Returns Returns
------- -------
...@@ -699,6 +733,7 @@ class PhaseMap(object): ...@@ -699,6 +733,7 @@ class PhaseMap(object):
# Calculate gain if 'auto' is selected: # Calculate gain if 'auto' is selected:
if gain == 'auto': if gain == 'auto':
gain = 4 * 2 * np.pi / (np.abs(self.phase).max() + 1E-30) gain = 4 * 2 * np.pi / (np.abs(self.phase).max() + 1E-30)
gain = round(gain, -int(np.floor(np.log10(abs(gain)))))
# Set title if not set: # Set title if not set:
if title is None: if title is None:
title = 'Holographic contour Map (gain: {:g})'.format(gain) title = 'Holographic contour Map (gain: {:g})'.format(gain)
...@@ -706,40 +741,56 @@ class PhaseMap(object): ...@@ -706,40 +741,56 @@ class PhaseMap(object):
holo = np.cos(gain * self.phase) holo = np.cos(gain * self.phase)
holo += 1 # Shift to positive values holo += 1 # Shift to positive values
holo /= 2 # Rescale to [0, 1] holo /= 2 # Rescale to [0, 1]
# Calculate the phase gradients and calculate colors: # Calculate the phase gradients:
# B = rot(A) --> B_x = grad_y(A_z), B_y = -grad_x(A_z); phi_m ~ -int(A_z) # B = rot(A) --> B_x = grad_y(A_z), B_y = -grad_x(A_z); phi_m ~ -int(A_z)
# sign switch --> B_x = -grad_y(phi_m), B_y = grad_x(phi_m) # sign switch --> B_x = -grad_y(phi_m), B_y = grad_x(phi_m)
grad_x, grad_y = np.gradient(self.phase, self.a, self.a) grad_x, grad_y = np.gradient(self.phase, self.a, self.a)
rgb = colors.rgb_from_vector(grad_x, -grad_y, np.zeros_like(grad_x), mode=hue_mode) # Clip outliers:
sigma_clip = 2
outlier_x = np.abs(grad_x - np.mean(grad_x)) < sigma_clip * np.std(grad_x)
grad_x_sigma = np.where(outlier_x, grad_x, np.nan)
grad_x_min, grad_x_max = np.nanmin(grad_x_sigma), np.nanmax(grad_x_sigma)
grad_x = np.clip(grad_x, grad_x_min, grad_x_max)
outlier_y = np.abs(grad_y - np.mean(grad_y)) < sigma_clip * np.std(grad_y)
grad_y_sigma = np.where(outlier_y, grad_y, np.nan)
grad_y_min, grad_y_max = np.nanmin(grad_y_sigma), np.nanmax(grad_y_sigma)
grad_y = np.clip(grad_y, grad_y_min, grad_y_max)
# Calculate colors:
rgb = colors.rgb_from_vector(grad_x, -grad_y,
np.zeros_like(grad_x), mode=hue_mode)
rgb = (holo.T * rgb.T).T.astype(np.uint8) rgb = (holo.T * rgb.T).T.astype(np.uint8)
holo_image = Image.fromarray(rgb) holo_image = Image.fromarray(rgb)
# If no axis is specified, a new figure is created: # If no axis is specified, a new figure is created:
if axis is None: if axis is None:
fig = plt.figure() fig = plt.figure(figsize=figsize)
axis = fig.add_subplot(1, 1, 1) axis = fig.add_subplot(1, 1, 1)
axis.set_aspect('equal') axis.set_aspect('equal')
# Plot the image and set axes: # Plot the image and set axes:
axis.imshow(holo_image, origin='lower', interpolation=interpolation, axis.imshow(holo_image, origin='lower', interpolation=interpolation,
extent=(0, self.dim_uv[1], 0, self.dim_uv[0])) extent=(0, self.dim_uv[1], 0, self.dim_uv[0]))
# Set the title and the axes labels: # Set the title and the axes labels:
axis.set_title(title) # Set the axes ticks and labels:
axis.tick_params(axis='both', which='major', labelsize=14)
axis.set_title(title, fontsize=18)
axis.set_xlabel('u-axis [px]', fontsize=15)
axis.set_ylabel('v-axis [px]', fontsize=15)
if self.dim_uv[0] >= self.dim_uv[1]: if self.dim_uv[0] >= self.dim_uv[1]:
u_bin, v_bin = np.max((2, np.floor(9 * self.dim_uv[1] / self.dim_uv[0]))), 9 u_bin, v_bin = np.max((2, np.floor(9 * self.dim_uv[1] / self.dim_uv[0]))), 9
else: else:
u_bin, v_bin = 9, np.max((2, np.floor(9 * self.dim_uv[0] / self.dim_uv[1]))) u_bin, v_bin = 9, np.max((2, np.floor(9 * self.dim_uv[0] / self.dim_uv[1])))
axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True)) axis.xaxis.set_major_locator(MaxNLocator(nbins=u_bin, integer=True))
axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True)) axis.yaxis.set_major_locator(MaxNLocator(nbins=v_bin, integer=True))
axis.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
axis.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.3g}'.format(x * self.a)))
axis.tick_params(axis='both', which='major', labelsize=14)
axis.set_title(title, fontsize=18)
axis.set_xlim(0, self.dim_uv[1])
axis.set_ylim(0, self.dim_uv[0])
axis.set_xlabel('u-axis [nm]', fontsize=15)
axis.set_ylabel('v-axis [nm]', fontsize=15)
# Return plotting axis: # Return plotting axis:
return axis return axis
def plot_combined(self, sup_title='Combined Plot', phase_title='Phase Map', holo_title=None, def plot_combined(self, sup_title='Combined Plot', phase_title='Phase Map', holo_title=None,
cbar_title=None, unit='rad', cmap='RdBu', limit=None, norm=None, gain='auto', cbar_title=None, unit='rad', cmap='RdBu', vmin=None, vmax=None,
interpolation='none', cbar=True, show_mask=True, show_conf=True, symmetric=True, norm=None, gain='auto', interpolation='none', cbar=True,
hue_mode='triadic'): show_mask=True, show_conf=True, hue_mode='triadic'):
"""Display the phase map and the resulting color coded holography image in one plot. """Display the phase map and the resulting color coded holography image in one plot.
Parameters Parameters
...@@ -757,9 +808,15 @@ class PhaseMap(object): ...@@ -757,9 +808,15 @@ class PhaseMap(object):
cmap : string, optional cmap : string, optional
The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string. The :class:`~matplotlib.colors.Colormap` which is used for the plot as a string.
The default is 'RdBu'. The default is 'RdBu'.
limit : float, optional vmin : float, optional
Plotlimit for the phase in both negative and positive direction (symmetric around 0). Minimum value used for determining the plot limits. If not set, it will be
If not specified, the maximum amplitude of the phase is used. determined by the minimum of the phase directly.
vmax : float, optional
Maximum value used for determining the plot limits. If not set, it will be
determined by the minimum of the phase directly.
symmetric : boolean, optional
If True (default), a zero symmetric colormap is assumed and a zero value (which
will always be present) will be set to the central color color of the colormap.
norm : :class:`~matplotlib.colors.Normalize` or subclass, optional norm : :class:`~matplotlib.colors.Normalize` or subclass, optional
Norm, which is used to determine the colors to encode the phase information. Norm, which is used to determine the colors to encode the phase information.
If not specified, :class:`~matplotlib.colors.Normalize` is automatically used. If not specified, :class:`~matplotlib.colors.Normalize` is automatically used.
...@@ -787,17 +844,21 @@ class PhaseMap(object): ...@@ -787,17 +844,21 @@ class PhaseMap(object):
""" """
self._log.debug('Calling plot_combined') self._log.debug('Calling plot_combined')
# Create combined plot and set title: # Create combined plot and set title:
fig = plt.figure(figsize=(15, 7)) fig = plt.figure(figsize=(19, 9))
fig.suptitle(sup_title, fontsize=20) fig.suptitle(sup_title, fontsize=20)
# Plot holography image: # Plot holography image:
holo_axis = fig.add_subplot(1, 2, 1, aspect='equal') holo_axis = fig.add_subplot(1, 2, 1, aspect='equal')
self.plot_holo(title=holo_title, gain=gain, axis=holo_axis, interpolation=interpolation, self.plot_holo(title=holo_title, gain=gain, axis=holo_axis, interpolation=interpolation,
hue_mode=hue_mode) hue_mode=hue_mode)
if cbar: # Make space for colorbar without adding one, so that both plots have same size:
divider = make_axes_locatable(holo_axis)
cbar_ax = divider.append_axes('right', size='5%', pad=0.1)
cbar_ax.axis('off')
# Plot phase map: # Plot phase map:
phase_axis = fig.add_subplot(1, 2, 2, aspect='equal') phase_axis = fig.add_subplot(1, 2, 2, aspect='equal')
fig.subplots_adjust(right=0.85)
self.plot_phase(title=phase_title, cbar_title=cbar_title, unit=unit, cmap=cmap, self.plot_phase(title=phase_title, cbar_title=cbar_title, unit=unit, cmap=cmap,
limit=limit, norm=norm, axis=phase_axis, cbar=cbar, vmin=vmin, vmax=vmax, symmetric=symmetric, norm=norm, axis=phase_axis,
show_mask=show_mask, show_conf=show_conf) cbar=cbar, show_mask=show_mask, show_conf=show_conf)
# Return the plotting axes: # Return the plotting axes:
return phase_axis, holo_axis return phase_axis, holo_axis
...@@ -258,7 +258,7 @@ class RotTiltProjector(Projector): ...@@ -258,7 +258,7 @@ class RotTiltProjector(Projector):
# Calculate vectors to voxels relative to rotation center: # Calculate vectors to voxels relative to rotation center:
voxel_vecs = (np.asarray(voxels) + 0.5 - np.asarray(center)).T voxel_vecs = (np.asarray(voxels) + 0.5 - np.asarray(center)).T
# Create tilt, rotation and combined quaternion, careful: Quaternion(w,x,y,z), not (z,y,x): # Create tilt, rotation and combined quaternion, careful: Quaternion(w,x,y,z), not (z,y,x):
quat_x = Quaternion.from_axisangle((0, 1, 0), tilt) # Tilt around y-axis quat_x = Quaternion.from_axisangle((1, 0, 0), tilt) # Tilt around x-axis
quat_z = Quaternion.from_axisangle((0, 0, 1), rotation) # Rotate around z-axis quat_z = Quaternion.from_axisangle((0, 0, 1), rotation) # Rotate around z-axis
quat = quat_x * quat_z # Combined quaternion (first rotate around z, then tilt around x) quat = quat_x * quat_z # Combined quaternion (first rotate around z, then tilt around x)
# Calculate impact positions on the projected pixel coordinate grid (flip because quat.): # Calculate impact positions on the projected pixel coordinate grid (flip because quat.):
......
...@@ -22,7 +22,7 @@ __all__ = ['optimize_linear', 'optimize_nonlin', 'optimize_splitbregman'] ...@@ -22,7 +22,7 @@ __all__ = ['optimize_linear', 'optimize_nonlin', 'optimize_splitbregman']
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
def optimize_linear(costfunction, max_iter=None): def optimize_linear(costfunction, max_iter=None, verbose=False):
"""Reconstruct a three-dimensional magnetic distribution from given phase maps via the """Reconstruct a three-dimensional magnetic distribution from given phase maps via the
conjugate gradient optimizaion method :func:`~.scipy.sparse.linalg.cg`. conjugate gradient optimizaion method :func:`~.scipy.sparse.linalg.cg`.
Blazingly fast for l2-based cost functions. Blazingly fast for l2-based cost functions.
...@@ -34,6 +34,9 @@ def optimize_linear(costfunction, max_iter=None): ...@@ -34,6 +34,9 @@ def optimize_linear(costfunction, max_iter=None):
regularisator which is minimized in the optimization process. regularisator which is minimized in the optimization process.
max_iter : int, optional max_iter : int, optional
The maximum number of iterations for the opimization. The maximum number of iterations for the opimization.
verbose: bool, optional
If set to True, information like a progressbar is displayed during reconstruction.
The default is False.
Returns Returns
------- -------
...@@ -44,7 +47,7 @@ def optimize_linear(costfunction, max_iter=None): ...@@ -44,7 +47,7 @@ def optimize_linear(costfunction, max_iter=None):
import jutil.cg as jcg import jutil.cg as jcg
_log.debug('Calling optimize_linear') _log.debug('Calling optimize_linear')
_log.info('Cost before optimization: {:.3e}'.format(costfunction(np.zeros(costfunction.n)))) _log.info('Cost before optimization: {:.3e}'.format(costfunction(np.zeros(costfunction.n))))
x_opt = jcg.conj_grad_minimize(costfunction, max_iter=max_iter).x x_opt = jcg.conj_grad_minimize(costfunction, max_iter=max_iter, verbose=verbose).x
_log.info('Cost after optimization: {:.3e}'.format(costfunction(x_opt))) _log.info('Cost after optimization: {:.3e}'.format(costfunction(x_opt)))
# Cut ramp parameters if necessary (this also saves the final parameters in the ramp class!): # Cut ramp parameters if necessary (this also saves the final parameters in the ramp class!):
x_opt = costfunction.fwd_model.ramp.extract_ramp_params(x_opt) x_opt = costfunction.fwd_model.ramp.extract_ramp_params(x_opt)
......
...@@ -8,6 +8,9 @@ import logging ...@@ -8,6 +8,9 @@ import logging
import numpy as np import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from jutil.taketime import TakeTime from jutil.taketime import TakeTime
from .. import reconstruction from .. import reconstruction
...@@ -23,7 +26,7 @@ _log = logging.getLogger(__name__) ...@@ -23,7 +26,7 @@ _log = logging.getLogger(__name__)
def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ramp_order=1, def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ramp_order=1,
plot_results=False, ar_dens=None): plot_results=False, ar_dens=None, verbose=True):
"""Convenience function for reconstructing a projected distribution from a single phasemap. """Convenience function for reconstructing a projected distribution from a single phasemap.
Parameters Parameters
...@@ -46,6 +49,9 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram ...@@ -46,6 +49,9 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram
ar_dens: int, optional ar_dens: int, optional
Number defining the arrow density which is plotted. A higher ar_dens number skips more Number defining the arrow density which is plotted. A higher ar_dens number skips more
arrows (a number of 2 plots every second arrow). Default is 1. arrows (a number of 2 plots every second arrow). Default is 1.
verbose: bool, optional
If set to True, information like a progressbar is displayed during reconstruction.
The default is False.
Returns Returns
------- -------
...@@ -64,7 +70,7 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram ...@@ -64,7 +70,7 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram
cost = Costfunction(fwd_model, reg) cost = Costfunction(fwd_model, reg)
# Reconstruct: # Reconstruct:
with TakeTime('reconstruction time'): with TakeTime('reconstruction time'):
magdata_rec = reconstruction.optimize_linear(cost, max_iter=max_iter) magdata_rec = reconstruction.optimize_linear(cost, max_iter=max_iter, verbose=verbose)
param_cache = cost.fwd_model.ramp.param_cache param_cache = cost.fwd_model.ramp.param_cache
if ramp_order is None: if ramp_order is None:
offset, ramp = 0, (0, 0) offset, ramp = 0, (0, 0)
...@@ -78,12 +84,16 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram ...@@ -78,12 +84,16 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram
if plot_results: if plot_results:
if ar_dens is None: if ar_dens is None:
ar_dens = np.max([1, np.max(dim) // 64]) ar_dens = np.max([1, np.max(dim) // 64])
axis = magdata_rec.plot_field('Reconstructed Distribution', figsize=(15, 15)) magdata_rec.plot_quiver_field('Reconstructed Distribution',
magdata_rec.plot_quiver(axis=axis, ar_dens=ar_dens, coloring='uniform') ar_dens=ar_dens, figsize=(16, 16))
phasemap.plot_combined('Input Phase')
phasemap -= fwd_model.ramp(index=0)
phasemap.plot_combined('Input Phase (ramp corrected)')
phasemap_rec = pm(magdata_rec) phasemap_rec = pm(magdata_rec)
gain = 4 * 2 * np.pi / (np.abs(phasemap_rec.phase).max() + 1E-30)
gain = round(gain, -int(np.floor(np.log10(abs(gain)))))
vmin = phasemap_rec.phase.min()
vmax = phasemap_rec.phase.max()
phasemap.plot_combined('Input Phase', gain=gain)
phasemap -= fwd_model.ramp(index=0)
phasemap.plot_combined('Input Phase (ramp corrected)', gain=gain, vmin=vmin, vmax=vmax)
title = 'Reconstructed Phase' title = 'Reconstructed Phase'
if ramp_order is not None: if ramp_order is not None:
if ramp_order >= 0: if ramp_order >= 0:
...@@ -92,11 +102,12 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram ...@@ -92,11 +102,12 @@ def reconstruction_2d_from_phasemap(phasemap, b_0=1, lam=1E-3, max_iter=100, ram
if ramp_order >= 1: if ramp_order >= 1:
print('ramp:', ramp) print('ramp:', ramp)
title += ', (Fitted Ramp: (u:{:.2g}, v:{:.2g}) [rad/nm]'.format(*ramp) title += ', (Fitted Ramp: (u:{:.2g}, v:{:.2g}) [rad/nm]'.format(*ramp)
phasemap_rec.plot_combined(title) phasemap_rec.plot_combined(title, gain=gain, vmin=vmin, vmax=vmax)
diff = (phasemap_rec - phasemap).phase diff = (phasemap_rec - phasemap).phase
diff_name = 'Difference (mean: {:.2g})'.format(diff.mean()) diff_name = 'Difference (RMS: {:.2g})'.format(np.sqrt(np.mean(diff) ** 2))
(phasemap_rec - phasemap).plot_phase(diff_name, sigma_clip=3) (phasemap_rec - phasemap).plot_phase(diff_name, sigma_clip=3)
if ramp_order is not None: if ramp_order is not None:
fwd_model.ramp(0).plot_combined('Fitted Ramp') ramp = fwd_model.ramp(0)
ramp.plot_phase('Fitted Ramp')
# Return reconstructed magnetisation distribution and cost function: # Return reconstructed magnetisation distribution and cost function:
return magdata_rec, cost return magdata_rec, cost
...@@ -28,7 +28,7 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_ ...@@ -28,7 +28,7 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_
angles=np.linspace(-90, 90, num=19), dim_uv=None, angles=np.linspace(-90, 90, num=19), dim_uv=None,
axes=(True, True), noise=0, offset_max=0, ramp_max=0, axes=(True, True), noise=0, offset_max=0, ramp_max=0,
use_internal_mask=True, plot_results=False, plot_input=False, use_internal_mask=True, plot_results=False, plot_input=False,
ar_dens=None, multicore=True): ar_dens=None, multicore=True, verbose=True):
"""Convenience function for reconstructing a projected distribution from a single phasemap. """Convenience function for reconstructing a projected distribution from a single phasemap.
Parameters Parameters
...@@ -78,6 +78,9 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_ ...@@ -78,6 +78,9 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_
multicore: boolean, optional multicore: boolean, optional
Determines if multiprocessing should be used. Default is True. Phasemap calculations Determines if multiprocessing should be used. Default is True. Phasemap calculations
will be divided onto the separate cores. will be divided onto the separate cores.
verbose: bool, optional
If set to True, information like a progressbar is displayed during reconstruction.
The default is False.
Returns Returns
------- -------
...@@ -128,7 +131,7 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_ ...@@ -128,7 +131,7 @@ def reconstruction_3d_from_magdata(magdata, b_0=1, lam=1E-3, max_iter=100, ramp_
cost = Costfunction(fwd_model, reg) cost = Costfunction(fwd_model, reg)
# Reconstruct and save: # Reconstruct and save:
with TakeTime('reconstruction time'): with TakeTime('reconstruction time'):
magdata_rec = reconstruction.optimize_linear(cost, max_iter=max_iter) magdata_rec = reconstruction.optimize_linear(cost, max_iter=max_iter, verbose=verbose)
# Finalize ForwardModel (returns workers if multicore): # Finalize ForwardModel (returns workers if multicore):
fwd_model.finalize() fwd_model.finalize()
# Plot input: # Plot input:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment