From fa66b0a16ae97261190b1489299ee69dad68a30f Mon Sep 17 00:00:00 2001
From: caron <j.caron@fz-juelich.de>
Date: Wed, 27 Nov 2019 12:26:52 +0100
Subject: [PATCH] plotting fixes and improvements rgb_from_vector fixed to
 properly generate luminance. get_avrg_kern_field now properly handles fits
 without ramp. plot_avrg_kern_field now also returns the plotting axis.
 plot_quiver ar_dens behaviour changed:     now has 'auto' functionality and
 also bins instead of showing     every nth arrow. Arrow color is now white
 with black stroke.

---
 pyramid/analytic.py    |  1 +
 pyramid/colors.py      | 29 +++++++++++--------------
 pyramid/diagnostics.py |  6 ++++--
 pyramid/fielddata.py   | 49 ++++++++++++++++++++++++++----------------
 4 files changed, 48 insertions(+), 37 deletions(-)

diff --git a/pyramid/analytic.py b/pyramid/analytic.py
index ae383a3..e44fde6 100644
--- a/pyramid/analytic.py
+++ b/pyramid/analytic.py
@@ -82,6 +82,7 @@ def phase_mag_slab(dim, a, phi, center, width, b_0=1):
 
 
 def phase_mag_disc(dim, a, phi, center, radius, height, b_0=1):
+    # TODO: Parameter order should match magcreator.examples!
     """Calculate the analytic magnetic phase for a homogeneously magnetized disc.
 
     Parameters
diff --git a/pyramid/colors.py b/pyramid/colors.py
index 01cfa1b..05aef95 100644
--- a/pyramid/colors.py
+++ b/pyramid/colors.py
@@ -80,29 +80,24 @@ class Colormap3D(colors.Colormap, metaclass=abc.ABCMeta):
         """
         self._log.debug('Calling rgb_from_vector')
         x, y, z = np.asarray(vector)
-        # Calculate spherical coordinates:
-        r = np.sqrt(x ** 2 + y ** 2 + z ** 2)
+        R = np.sqrt(x ** 2 + y ** 2 + z ** 2)
+        R_max = vmax if vmax is not None else R.max() + 1E-30
+        # FIRST color dimension: HUE (1D ring/angular direction)
         phi = np.asarray(np.arctan2(y, x))
         phi[phi < 0] += 2 * np.pi
-        theta = np.arccos(z / (r + 1E-30))
-        # Determine saturation normalisation:
-        if vmax is not None:
-            R = vmax
-        else:
-            R = r.max() + 1E-30
-        # Calculate color deterministics:
         hue = phi / (2 * np.pi)
-        lum = 1 - theta / np.pi
-        sat = r / R
-        # Calculate RGB from hue with colormap:
         rgba = np.asarray(self(hue))
         r, g, b = rgba[..., 0], rgba[..., 1], rgba[..., 2]
-        # Interpolate saturation:
+        # SECOND color dimension: SATURATION (2D, in-plane)
+        rho = np.sqrt(x ** 2 + y ** 2)
+        sat = rho / R_max
         r, g, b = interpolate_color(sat, (0.5, 0.5, 0.5), np.stack((r, g, b), axis=-1))
-        # Interpolate luminance:
-        lum_target = np.where(lum < 0.5, 0, 1)
-        lum_target = np.stack([lum_target] * 3, axis=-1)
-        fraction = np.where(lum < 0.5, 1 - 2 * lum, 2 * (lum - 0.5))
+        # THIRD color dimension: LUMINANCE (3D, color sphere)
+        theta = np.arccos(z / R_max)
+        lum = 1 - theta / np.pi  # goes from 0 (black) over 0.5 (grey) to 1 (white)!
+        lum_target = np.where(lum < 0.5, 0, 1)  # Separate upper(white)/lower(black) hemispheres!
+        lum_target = np.stack([lum_target] * 3, axis=-1)  # [0, 0, 0] -> black / [1, 1, 1] -> white!
+        fraction = 2 * np.abs(lum - 0.5)  # 0.5: difference from grey, 2: scale to range (0, 1)!
         r, g, b = interpolate_color(fraction, np.stack((r, g, b), axis=-1), lum_target)
         # Return RGB:
         return np.asarray(255 * np.stack((r, g, b), axis=-1), dtype=np.uint8)
diff --git a/pyramid/diagnostics.py b/pyramid/diagnostics.py
index 8410577..bfeab03 100644
--- a/pyramid/diagnostics.py
+++ b/pyramid/diagnostics.py
@@ -195,7 +195,9 @@ class Diagnostics(object):
         if pos is not None:
             self.pos = pos
         magdata_avrg_kern = VectorData(self.cost.fwd_model.data_set.a, np.zeros((3,) + self.dim))
-        vector = self.avrg_kern_row[:-self.fwd_model.ramp.n]  # Only take vector field, not ramp!
+        # Only take vector field, not ramp [Special case of n=0 is caught by coalescing or:
+        # "x or y" returns x if x is True (here: -0 -> False), else y (None -> return whole array)]:
+        vector = self.avrg_kern_row[:(-self.fwd_model.ramp.n or None)]
         magdata_avrg_kern.set_vector(vector, mask=self.mask)
         return magdata_avrg_kern
 
@@ -401,7 +403,7 @@ class Diagnostics(object):
         artist = axis.add_patch(patches.Ellipse(xy, width, height, fill=False, edgecolor='w',
                                                 linewidth=2, alpha=0.5))
         artist.set_path_effects([patheffects.withStroke(linewidth=4, foreground='k', alpha=0.5)])
-        # TODO: Return axis on every plot?
+        return axis  # TODO: Return axis on every plot?
 
     def plot_avrg_kern_field3d(self, pos=None, mask=True, ellipsoid=True, **kwargs):
         avrg_kern_field = self.get_avrg_kern_field(pos)
diff --git a/pyramid/fielddata.py b/pyramid/fielddata.py
index 6eeee12..e11a9b7 100644
--- a/pyramid/fielddata.py
+++ b/pyramid/fielddata.py
@@ -884,7 +884,7 @@ class VectorData(FieldData):
     # # plt.savefig(directory + '/ch5-0-magnetic_distributions_v.png', bbox_inches='tight')
     # plt.savefig(directory + '/ch5-0-magnetic_distributions_v.pdf', bbox_inches='tight')
 
-    def plot_quiver(self, ar_dens=1, log=False, scaled=True, scale=1., b_0=None, qkey_unit='T',
+    def plot_quiver(self, ar_dens='auto', log=False, scaled=True, scale=1., b_0=None, qkey_unit='T',
                     coloring='angle', cmap=None,  # Used here and plot_streamlines!
                     proj_axis='z', ax_slice=None, show_mask=True, bgcolor=None, axis=None,
                     figsize=None, stroke=None, fontsize=None, **kwargs):
@@ -947,20 +947,30 @@ class VectorData(FieldData):
             stroke = plottools.STROKE_DEFAULT
         assert proj_axis == 'z' or proj_axis == 'y' or proj_axis == 'x', \
             "Axis has to be 'x', 'y' or 'z'."
+        # TODO: Deprecate ar_dens in favor of scale_it, call it "binsize" or something...
+        # TODO: None or 'auto' as default for ar_dens/bin_size, set something sensible!!!
+        if ar_dens == 'auto':
+            ar_dens = np.max((1,  np.max(self.dim) // 16))
+            # TODO: Just for now, delete later, when power of 2 no longer needed:
+            ar_dens = int(2**(np.log2(ar_dens)//1))
+        if ar_dens > 1:
+            scale_it = np.log2(ar_dens)
+            assert scale_it % 1 == 0, 'ar_dens has to be power of 2 (for now)!'  # TODO: Delete!
+            vecdata = self.scale_down(int(scale_it))
+        else:
+            vecdata = self
+        # Extract slice and mask:
         if ax_slice is None:
             ax_slice = self.dim[{'z': 0, 'y': 1, 'x': 2}[proj_axis]] // 2
-        # Extract slice and mask:
-        u_mag, v_mag = self.get_slice(ax_slice, proj_axis)[:2]
+        ax_slice //= ar_dens
+        # TODO: Currently, the slice is taken AFTER the scale down. In EMPyRe, this plotting
+        # TODO: function should be called on the slice (another 2D Vector Field class) instead!
+        # TODO: The slicing is therefore done before the scale down and more accurate!(??)
+        u_mag, v_mag = vecdata.get_slice(ax_slice, proj_axis)[:2]
         submask = np.where(np.hypot(u_mag, v_mag) > 0, True, False)
         # Prepare quiver (select only used arrows if ar_dens is specified):
-        # TODO: None or 'auto' as default for ar_dens, set something sensible!!!
-        # TODO: ALSO mean instead of every nth arrow (use extend to cover the same space!)
         dim_uv = u_mag.shape
-        vv, uu = np.indices(dim_uv) + 0.5  # shift to center of pixel
-        uu = uu[::ar_dens, ::ar_dens]
-        vv = vv[::ar_dens, ::ar_dens]
-        u_mag = u_mag[::ar_dens, ::ar_dens]
-        v_mag = v_mag[::ar_dens, ::ar_dens]
+        vv, uu = (np.indices(dim_uv) + 0.5) * ar_dens  # 0.5: shift to center of pixel!
         amplitudes = np.hypot(u_mag, v_mag)
         # TODO: Delete if only used in log:
         # angles = np.angle(u_mag + 1j * v_mag, deg=True).tolist()
@@ -983,14 +993,12 @@ class VectorData(FieldData):
         elif coloring == 'uniform':
             self._log.debug('Automatic uniform color encoding')
             hue = amplitudes / amplitudes.max()
-            if bgcolor == 'white':
-                cmap = colors.cmaps['transparent_black']
-            else:
-                cmap = colors.cmaps['transparent_white']
+            cmap = colors.cmaps['transparent_white']
         else:
             self._log.debug('Specified uniform color encoding')
             hue = np.zeros_like(u_mag)
             cmap = ListedColormap([coloring])
+        edgecolors = colors.cmaps['transparent_black'](amplitudes / amplitudes.max()).reshape(-1, 4)
         if cmap_overwrite is not None:
             cmap = cmap_overwrite
         # If no axis is specified, a new figure is created:
@@ -1019,9 +1027,9 @@ class VectorData(FieldData):
         quiv = axis.quiver(uu, vv, u_mag, v_mag, hue, cmap=cmap, clim=(0, 1),  # angles=angles,
                            pivot='middle', units='xy', scale_units='xy', scale=scale / ar_dens,
                            minlength=0.05, width=1*ar_dens, headlength=2, headaxislength=2,
-                           headwidth=2, minshaft=2)
-        axis.set_xlim(0, dim_uv[1])
-        axis.set_ylim(0, dim_uv[0])
+                           headwidth=2, minshaft=2, edgecolor=edgecolors, linewidths=1)
+        axis.set_xlim(0, dim_uv[1]*ar_dens)
+        axis.set_ylim(0, dim_uv[0]*ar_dens)
         # Determine colormap if necessary:
         if coloring == 'amplitude':
             cbar_mappable, cbar_label = quiv, 'amplitude'
@@ -1035,7 +1043,7 @@ class VectorData(FieldData):
             mask_color = 'white' if bgcolor == 'black' else 'black'
             axis.contour(uu, vv, submask, levels=[0.5], colors=mask_color,
                          linestyles='dotted', linewidths=2)
-        # Plot quiverkey if B_0 is specified):
+        # Plot quiverkey if B_0 is specified:
         if b_0 and not log:  # The angles needed for log would break the quiverkey!
             label = '{:.3g} {}'.format(amplitudes.max() * b_0, qkey_unit)
             quiv.angles = 'uv'  # With a list of angles, the quiverkey would break!
@@ -1506,6 +1514,11 @@ class ScalarData(FieldData):
         Only possible, if each axis length is a power of 2!
 
         """
+        # TODO: Florian: make more general, not just power of 2, n should denote bin size, not
+        # TODO: number of scale down operations! Same for scale up!
+        # TODO: def rebin(a, shape):
+        # TODO:     sh = (shape[0], a.shape[0] // shape[0], shape[1], a.shape[1] // shape[1])
+        # TODO:     return a.reshape(sh).mean(-1).mean(1)
         self._log.debug('Calling scale_down')
         assert n > 0 and isinstance(n, int), 'n must be a positive integer!'
         a_new = self.a * 2 ** n
-- 
GitLab