From 08c92d06de51d4f02810bb3b925817d676b17386 Mon Sep 17 00:00:00 2001
From: Jan Caron <j.caron@fz-juelich.de>
Date: Mon, 29 Oct 2018 13:29:26 +0100
Subject: [PATCH] Last commit before mayavi update try!

---
 pyramid/colors.py                | 9 +++++++--
 pyramid/file_io/io_vectordata.py | 4 ++--
 pyramid/phasemap.py              | 9 +++++----
 3 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/pyramid/colors.py b/pyramid/colors.py
index 4051ef1..cdcaa56 100644
--- a/pyramid/colors.py
+++ b/pyramid/colors.py
@@ -65,7 +65,7 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
 
     _log = logging.getLogger(__name__ + '.Colormap3D')
 
-    def rgb_from_vector(self, vector):
+    def rgb_from_vector(self, vector, vmax=None):
         """Construct a hls tuple from three coordinates representing a 3D direction.
 
         Parameters
@@ -87,10 +87,15 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
         phi = np.asarray(np.arctan2(y, x))
         phi[phi < 0] += 2 * np.pi
         theta = np.arccos(z / (r + 1E-30))
+        # Determine saturation normalisation:
+        if vmax is not None:
+            R = vmax
+        else:
+            R = r.max() + 1E-30
         # Calculate color deterministics:
         hue = phi / (2 * np.pi)
         lum = 1 - theta / np.pi
-        sat = r / (r.max() + 1E-30)
+        sat = r / R
         # Calculate RGB from hue with colormap:
         rgba = np.asarray(self(hue))
         r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2]
diff --git a/pyramid/file_io/io_vectordata.py b/pyramid/file_io/io_vectordata.py
index 93da805..9373c4c 100644
--- a/pyramid/file_io/io_vectordata.py
+++ b/pyramid/file_io/io_vectordata.py
@@ -170,7 +170,7 @@ def _load_from_ovf(filename, a):
             if not np.allclose(xstep, ystep) and np.allclose(xstep, zstep):
                 _log.warning('Grid spacing is not equal in x, y and z (x will be used)!\n'
                              'Found step sizes are x:{}, y:{}, z:{} (all in {})!'.format(
-                                xstep, ystep, zstep, header.get('meshunit')))
+                              xstep, ystep, zstep, header.get('meshunit')))
             # Extract grid spacing from xstepsize and convert according to meshunit:
             unit = header.get('meshunit', 'nm')
             if unit == 'unspecified':
@@ -221,7 +221,7 @@ def _load_from_vtk(filename, a=None, **kwargs):
         _log.info('geometry: StructuredPoints')
         # Load relevant information from output (reverse to get typical Python order z,y,x):
         dim = output.origin[::-1]
-        origin = output.spacing[::-1]   
+        origin = output.spacing[::-1]
         spacing = output.dimensions[::-1]
         assert len(dim) == 3, 'Data has to be three-dimensional!'
         assert spacing[0] == spacing[1] == spacing[2], \
diff --git a/pyramid/phasemap.py b/pyramid/phasemap.py
index 93d3652..2555b22 100644
--- a/pyramid/phasemap.py
+++ b/pyramid/phasemap.py
@@ -730,7 +730,7 @@ class PhaseMap(object):
         return plottools.format_axis(axis, sampling=a, cbar_mappable=cbar_mappable,
                                      cbar_label=cbar_label, tight_layout=tight, **kwargs)
 
-    def plot_holo(self, gain='auto',  # specific to plot_holo!
+    def plot_holo(self, gain='auto', colorwheel=True,  # specific to plot_holo!
                   cmap=None, interpolation='none', axis=None, figsize=None, sigma_clip=2,
                   **kwargs):
         """Display the color coded holography image.
@@ -808,9 +808,10 @@ class PhaseMap(object):
             note = 'gain: {:g}'.format(gain)
         stroke = kwargs.pop('stroke', 'k')  # Default for holo is white with black outline!
         return plottools.format_axis(axis, sampling=a, note=note, tight_layout=tight,
-                                     stroke=stroke, **kwargs)
+                                     colorwheel=colorwheel, stroke=stroke, **kwargs)
 
-    def plot_combined(self, title='', phase_title='', holo_title='', figsize=None, **kwargs):
+    def plot_combined(self, title='', phase_title='', holo_title='', figsize=None,
+                      colorwheel=True, **kwargs):
         """Display the phase map and the resulting color coded holography image in one plot.
 
         Parameters
@@ -844,7 +845,7 @@ class PhaseMap(object):
         note = kwargs.pop('note', None)
         # Plot holography image:
         holo_axis = fig.add_subplot(1, 2, 1, aspect='equal')
-        self.plot_holo(axis=holo_axis, title=holo_title, note=None, **kwargs)
+        self.plot_holo(axis=holo_axis, title=holo_title, note=None, colorwheel=colorwheel, **kwargs)
         # Plot phase map:
         phase_axis = fig.add_subplot(1, 2, 2, aspect='equal')
         self.plot_phase(axis=phase_axis, title=phase_title, note=note, **kwargs)
-- 
GitLab