Skip to content
Snippets Groups Projects
Commit 0cabb8c5 authored by Jan Caron's avatar Jan Caron
Browse files

Added comp_pos argument to from_signal method

Can be used to specify which axis holds the components.
Also usable by load_field function!
Useful for backwards compatibilty with Pyramid.
parent 81ce0fe6
No related branches found
No related tags found
1 merge request!30Release 0.2.0
Pipeline #26881 passed
...@@ -101,7 +101,7 @@ class Field(NDArrayOperatorsMixin): ...@@ -101,7 +101,7 @@ class Field(NDArrayOperatorsMixin):
assert len(scale) == ndim, f'Each of the {ndim} dimensions needs a scale, but {scale} was given!' assert len(scale) == ndim, f'Each of the {ndim} dimensions needs a scale, but {scale} was given!'
self.__scale = scale self.__scale = scale
else: else:
raise AssertionError('Scaling has to be a number or a tuple of numbers!') raise AssertionError(f'Scaling has to be a number or a tuple of numbers, was {scale} instead!')
@property @property
def shape(self): def shape(self):
...@@ -141,7 +141,7 @@ class Field(NDArrayOperatorsMixin): ...@@ -141,7 +141,7 @@ class Field(NDArrayOperatorsMixin):
self.vector = vector # Set vector before scale, because scale needs to know if vector (via calling dim)! self.vector = vector # Set vector before scale, because scale needs to know if vector (via calling dim)!
self.scale = scale self.scale = scale
if self.vector: if self.vector:
assert self.ncomp in (2, 3), 'Only 2- or 3-dimensional vector fields are supported!' assert self.ncomp in (2, 3), 'Only 2- or 3-component vector fields are supported!'
self._log.debug('Created ' + str(self)) self._log.debug('Created ' + str(self))
def __repr__(self): def __repr__(self):
...@@ -292,7 +292,7 @@ class Field(NDArrayOperatorsMixin): ...@@ -292,7 +292,7 @@ class Field(NDArrayOperatorsMixin):
return cls(np.stack(scalar_list, axis=-1), scalar_list[0].scale, vector=True) return cls(np.stack(scalar_list, axis=-1), scalar_list[0].scale, vector=True)
@classmethod @classmethod
def from_signal(cls, signal, scale=None, vector=False): def from_signal(cls, signal, scale=None, vector=None, comp_pos=-1):
"""Convert a :class:`~hyperspy.signals.Signal` object to a :class:`~.Field` object. """Convert a :class:`~hyperspy.signals.Signal` object to a :class:`~.Field` object.
Parameters Parameters
...@@ -310,6 +310,11 @@ class Field(NDArrayOperatorsMixin): ...@@ -310,6 +310,11 @@ class Field(NDArrayOperatorsMixin):
If set to True, forces the signal to be interpreted as a vector field. EMPyRe will check if the first axis If set to True, forces the signal to be interpreted as a vector field. EMPyRe will check if the first axis
is named 'vector components' (EMPyRe saves vector fields like this). If this is the case, vector will be is named 'vector components' (EMPyRe saves vector fields like this). If this is the case, vector will be
automatically set to True and the signal will also be interpreted as a vector field. automatically set to True and the signal will also be interpreted as a vector field.
comp_pos: int, optoinal
The index of the axis containing the vector components (if `vector=True`). EMPyRe needs this to be the last
axis (index `-1`, which is the default). In case another position is given, the vector component will be
moved to the last axis. Old Pyramid files will have this axis at index `0`, so use this for backwards
compatibilty.
Notes Notes
----- -----
...@@ -318,10 +323,12 @@ class Field(NDArrayOperatorsMixin): ...@@ -318,10 +323,12 @@ class Field(NDArrayOperatorsMixin):
""" """
cls._log.debug('Calling from_signal') cls._log.debug('Calling from_signal')
data = signal.data data = signal.data
if signal.axes_manager[0].name == 'vector components': if vector and comp_pos != -1: # component axis should be last, but is currently first -> roll to the end:
vector = True # Automatic detection! data = np.moveaxis(data, source=comp_pos, destination=-1)
if scale is None: # If not provided, try to read from axes_manager: if vector is None: # Automatic detection:
scale = [signal.axes_manager[i].scale for i in range(len(data.shape) - vector)] # One less axis if vector! vector = True if signal.axes_manager[0].name == 'vector components' else False
if scale is None: # If not provided, try to read from axes_manager, one less axis if vector!:
scale = tuple([signal.axes_manager[i].scale for i in range(len(data.shape) - vector)])
return cls(data, scale, vector) return cls(data, scale, vector)
def to_signal(self): def to_signal(self):
......
...@@ -21,7 +21,7 @@ def load_field(filename, scale=None, vector=None, **kwargs): ...@@ -21,7 +21,7 @@ def load_field(filename, scale=None, vector=None, **kwargs):
The function loads the file according to the extension: The function loads the file according to the extension:
SCALAR??? SCALAR???
- hdf5 for HDF5. - hdf5 for HDF5. # TODO: You can use comp_pos here!!!
- EMD Electron Microscopy Dataset format (also HDF5). - EMD Electron Microscopy Dataset format (also HDF5).
- npy or npz for numpy formats. - npy or npz for numpy formats.
...@@ -67,14 +67,15 @@ def load_field(filename, scale=None, vector=None, **kwargs): ...@@ -67,14 +67,15 @@ def load_field(filename, scale=None, vector=None, **kwargs):
for plugin in plugin_list: # Iterate over all plugins: for plugin in plugin_list: # Iterate over all plugins:
if extension in plugin.file_extensions: # Check if extension is recognised: if extension in plugin.file_extensions: # Check if extension is recognised:
return plugin.reader(filename, scale=scale, vector=vector, **kwargs) return plugin.reader(filename, scale=scale, vector=vector, **kwargs)
# If nothing was found, try HyperSpy: # If nothing was found, try HyperSpy
_log.debug('Using HyperSpy') _log.debug('Using HyperSpy')
try: try:
import hyperspy.api as hs import hyperspy.api as hs
except ImportError: except ImportError:
_log.error('This extension recquires the hyperspy package!') _log.error('This extension recquires the hyperspy package!')
return return
return Field.from_signal(hs.load(filename, **kwargs), scale=scale, vector=vector) comp_pos = kwargs.pop('comp_pos', -1)
return Field.from_signal(hs.load(filename, **kwargs), scale=scale, vector=vector, comp_pos=comp_pos)
def save_field(filename, field, **kwargs): def save_field(filename, field, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment