From 0cabb8c55824daabd2447d55252e7d0305214d02 Mon Sep 17 00:00:00 2001
From: Jan Caron <j.caron@fz-juelich.de>
Date: Fri, 20 Mar 2020 17:35:32 +0100
Subject: [PATCH] Added comp_pos argument to from_signal method Can be used to
 specify which axis holds the components. Also usable by load_field function!
 Useful for backwards compatibilty with Pyramid.

---
 src/empyre/fields/field.py | 21 ++++++++++++++-------
 src/empyre/io/io_field.py  |  7 ++++---
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/src/empyre/fields/field.py b/src/empyre/fields/field.py
index afb7111..66f3f00 100644
--- a/src/empyre/fields/field.py
+++ b/src/empyre/fields/field.py
@@ -101,7 +101,7 @@ class Field(NDArrayOperatorsMixin):
             assert len(scale) == ndim, f'Each of the {ndim} dimensions needs a scale, but {scale} was given!'
             self.__scale = scale
         else:
-            raise AssertionError('Scaling has to be a number or a tuple of numbers!')
+            raise AssertionError(f'Scaling has to be a number or a tuple of numbers, was {scale} instead!')
 
     @property
     def shape(self):
@@ -141,7 +141,7 @@ class Field(NDArrayOperatorsMixin):
         self.vector = vector  # Set vector before scale, because scale needs to know if vector (via calling dim)!
         self.scale = scale
         if self.vector:
-            assert self.ncomp in (2, 3), 'Only 2- or 3-dimensional vector fields are supported!'
+            assert self.ncomp in (2, 3), 'Only 2- or 3-component vector fields are supported!'
         self._log.debug('Created ' + str(self))
 
     def __repr__(self):
@@ -292,7 +292,7 @@ class Field(NDArrayOperatorsMixin):
         return cls(np.stack(scalar_list, axis=-1), scalar_list[0].scale, vector=True)
 
     @classmethod
-    def from_signal(cls, signal, scale=None, vector=False):
+    def from_signal(cls, signal, scale=None, vector=None, comp_pos=-1):
         """Convert a :class:`~hyperspy.signals.Signal` object to a :class:`~.Field` object.
 
         Parameters
@@ -310,6 +310,11 @@ class Field(NDArrayOperatorsMixin):
             If set to True, forces the signal to be interpreted as a vector field. EMPyRe will check if the first axis
             is named 'vector components' (EMPyRe saves vector fields like this). If this is the case, vector will be
             automatically set to True and the signal will also be interpreted as a vector field.
+        comp_pos: int, optoinal
+            The index of the axis containing the vector components (if `vector=True`). EMPyRe needs this to be the last
+            axis (index `-1`, which is the default). In case another position is given, the vector component will be
+            moved to the last axis. Old Pyramid files will have this axis at index `0`, so use this for backwards
+            compatibilty.
 
         Notes
         -----
@@ -318,10 +323,12 @@ class Field(NDArrayOperatorsMixin):
         """
         cls._log.debug('Calling from_signal')
         data = signal.data
-        if signal.axes_manager[0].name == 'vector components':
-            vector = True  # Automatic detection!
-        if scale is None:  # If not provided, try to read from axes_manager:
-            scale = [signal.axes_manager[i].scale for i in range(len(data.shape) - vector)]  # One less axis if vector!
+        if vector and comp_pos != -1:  # component axis should be last, but is currently first -> roll to the end:
+            data = np.moveaxis(data, source=comp_pos, destination=-1)
+        if vector is None:  # Automatic detection:
+            vector = True if signal.axes_manager[0].name == 'vector components' else False
+        if scale is None:  # If not provided, try to read from axes_manager, one less axis if vector!:
+            scale = tuple([signal.axes_manager[i].scale for i in range(len(data.shape) - vector)])
         return cls(data, scale, vector)
 
     def to_signal(self):
diff --git a/src/empyre/io/io_field.py b/src/empyre/io/io_field.py
index 6dd3e32..20a011e 100644
--- a/src/empyre/io/io_field.py
+++ b/src/empyre/io/io_field.py
@@ -21,7 +21,7 @@ def load_field(filename, scale=None, vector=None, **kwargs):
 
     The function loads the file according to the extension:
         SCALAR???
-        - hdf5 for HDF5.
+        - hdf5 for HDF5.  # TODO: You can use comp_pos here!!!
         - EMD Electron Microscopy Dataset format (also HDF5).
         - npy or npz for numpy formats.
 
@@ -67,14 +67,15 @@ def load_field(filename, scale=None, vector=None, **kwargs):
     for plugin in plugin_list:  # Iterate over all plugins:
         if extension in plugin.file_extensions:  # Check if extension is recognised:
             return plugin.reader(filename, scale=scale, vector=vector, **kwargs)
-    # If nothing was found, try HyperSpy:
+    # If nothing was found, try HyperSpy
     _log.debug('Using HyperSpy')
     try:
         import hyperspy.api as hs
     except ImportError:
         _log.error('This extension recquires the hyperspy package!')
         return
-    return Field.from_signal(hs.load(filename, **kwargs), scale=scale, vector=vector)
+    comp_pos = kwargs.pop('comp_pos', -1)
+    return Field.from_signal(hs.load(filename, **kwargs), scale=scale, vector=vector, comp_pos=comp_pos)
 
 
 def save_field(filename, field, **kwargs):
-- 
GitLab