Commit 6d2fa198 authored by Jan Caron

Added io readers/writers and plot3d (untested)

parent 0db0c54b
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing EMPyRe IO functionality for several EMPyRe classes."""
from .io_field import *
__all__ = []
__all__.extend(io_field.__all__)
del io_field
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Subpackage containing EMPyRe IO functionality for the Field class."""
from . import llg, numpy, ovf, tec, text, vtk
plugin_list = [llg, numpy, ovf, tec, text, vtk]
__all__ = ['plugin_list']
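# Note on the plugin protocol (illustrative summary, not enforced anywhere in
# code): every module in `plugin_list` exposes a `file_extensions` tuple and
# `reader`/`writer` functions, so dispatching code can route a filename by its
# extension. A minimal sketch of a new plugin for a hypothetical '.foo' format
# could look like this:
#
#     import numpy as np
#     from ..fields.field import Field
#     file_extensions = ('.foo',)  # Recognised file extensions
#     def reader(filename, scale=None, vector=None, **kwargs):
#         data = np.loadtxt(filename)  # Format specific loading goes here!
#         return Field(data, scale if scale is not None else 1, vector=bool(vector))
#     def writer(filename, field, **kwargs):
#         np.savetxt(filename, field.data)  # Format specific saving goes here!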
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for LLG format."""
import logging
import numpy as np
from ..fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.llg',) # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
_log.debug('Call reader')
if vector is None:
vector = True
assert vector is True, 'Only vector fields can be read from the llg file format!'
SCALE = 1.0E-9 / 1.0E-2 # From cm to nm
data_columns = np.genfromtxt(filename, skip_header=2)
dim = tuple(np.genfromtxt(filename, dtype=int, skip_header=1, skip_footer=len(data_columns[:, 0])))
if scale is None: # Otherwise overwrite!
stride_x = 1 # x varies fastest
stride_y = dim[2] # one y step for dim[2] x steps
stride_z = dim[1] * dim[2] # one z step for one x-y layer (dim[1]*dim[2])
scale_x = (data_columns[stride_x, 0] - data_columns[0, 0]) / SCALE # first column varies in x
scale_y = (data_columns[stride_y, 1] - data_columns[0, 1]) / SCALE # second column varies in y
scale_z = (data_columns[stride_z, 2] - data_columns[0, 2]) / SCALE # third column varies in z
scale = (scale_z, scale_y, scale_x)
    data = data_columns[:, 3:6].reshape(dim + (3,))  # Reshape to field with 3 components along the last axis!
return Field(data, scale, vector=True)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
assert field.vector and len(field.dim) == 3, 'Only 3D vector fields can be saved to the llg file format!'
SCALE = 1.0E-9 / 1.0E-2 # from nm to cm
# Create 3D meshgrid and reshape it and the field into a list where x varies first:
zzz, yyy, xxx = (np.indices(field.dim) + 0.5) * np.reshape(field.scale, (3, 1, 1, 1)) * SCALE # broadcast shape!
z_coord, y_coord, x_coord = np.ravel(zzz), np.ravel(yyy), np.ravel(xxx) # Turn into vectors!
x_comp, y_comp, z_comp = field.comp # Extract scalar field components!
x_vec, y_vec, z_vec = np.ravel(x_comp.data), np.ravel(y_comp.data), np.ravel(z_comp.data) # Turn into vectors!
data = np.array([x_coord, y_coord, z_coord, x_vec, y_vec, z_vec]).T
# Save data to file:
with open(filename, 'w') as mag_file:
mag_file.write('LLGFileCreator: EMPyRe vector Field\n')
        mag_file.write('\t{:d}\t{:d}\t{:d}\n'.format(*field.dim))
mag_file.writelines('\n'.join(' '.join('{:7.6e}'.format(cell) for cell in row) for row in data))
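# Usage sketch (illustrative, untested like the module itself): round-trip a
# 3D vector field through the LLG format. The two header lines hold a creator
# string and the dimensions; each data row is `x y z mx my mz` with positions
# in cm, which is why the reader can infer `scale` from the coordinates.
#
#     field = Field(np.zeros((4, 5, 6, 3)), scale=(2., 2., 2.), vector=True)
#     writer('example.llg', field)
#     loaded = reader('example.llg')  # scale is reconstructed from columns 0-2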
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for the numpy format."""
import logging
import numpy as np
from ..fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.npy', '.npz') # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
_log.debug('Call reader')
if vector is None:
vector = False
if scale is None:
scale = 1
return Field(np.load(filename, **kwargs), scale, vector)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
np.save(filename, field.data, **kwargs)
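# Usage sketch (illustrative): `.npy` files store only the raw array, so
# `scale` and `vector` are not persisted and have to be passed to the reader
# again if they differ from the defaults (1 and False).
#
#     field = Field(np.ones((8, 8)), scale=5., vector=False)
#     writer('example.npy', field)
#     loaded = reader('example.npy', scale=5.)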
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
import os
import numpy as np
from ..fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.ovf', '.omf', '.ohf', '.obf')  # Recognised file extensions
def reader(filename, scale=None, vector=True, segment=None, **kwargs):
"""More info at:
http://math.nist.gov/oommf/doc/userguide11b2/userguide/vectorfieldformat.html
http://math.nist.gov/oommf/doc/userguide12a5/userguide/OVF_2.0_format.html
"""
_log.debug('Calling reader')
assert vector is True, 'Only vector fields can be loaded from ovf-files!'
with open(filename, 'rb') as load_file:
line = load_file.readline()
assert line.startswith(b'# OOMMF') # File has OVF format!
read_header, read_data = False, False
header = {'version': line.split()[-1].decode('utf-8')}
x_mag, y_mag, z_mag = [], [], []
data_mode = None
while True:
# --- READ START OF FILE OR IN BETWEEN SEGMENTS LINE BY LINE ---------------------------
if not read_header and not read_data: # Start of file or between segments!
line = load_file.readline()
if line == b'':
break # End of file is reached!
if line.startswith(b'# Segment count'):
seg_count = int(line.split()[-1]) # Total number of segments (often just 1)!
seg_curr = 0 # Current segment (0: not in first segment, yet!)
if seg_count > 1: # If multiple segments, check if "segment" was set correctly:
                        assert segment is not None, (f'Multiple ({seg_count}) segments were found! '
                                                     'Choose one via the segment parameter!')
                    elif segment is None:  # Only one segment AND parameter not set:
                        segment = 1  # Default to the first/only segment!
                    assert 0 < segment <= seg_count, (f'Parameter segment={segment} out of bounds! '
                                                      f'Use a value between 1 and {seg_count}!')
header['segment_count'] = seg_count
elif line.startswith(b'# Begin: Segment'): # Segment start!
seg_curr += 1
if seg_curr > segment:
break # Stop reading the file!
elif line.startswith(b'# Begin: Header'): # Header start!
read_header = True
elif line.startswith(b'# Begin: Data'): # Data start!
read_data = True
data_mode = ' '.join(line.decode('utf-8').split()[3:])
assert data_mode in ['text', 'Text', 'Binary 4', 'Binary 8'], \
f'Data mode {data_mode} is currently not supported by this reader!'
assert header.get('meshtype') == 'rectangular', \
'Only rectangular grids can be currently read!'
# --- READ HEADER LINE BY LINE ---------------------------------------------------------
elif read_header:
line = load_file.readline()
if line.startswith(b'# End: Header'): # Header is done:
read_header = False
continue
line = line.decode('utf-8') # Decode to use strings here!
line_list = line.split()
if '##' in line_list: # Strip trailing comments:
del line_list[line_list.index('##'):]
if len(line_list) <= 1: # Just '#' or empty line:
continue
key, value = line_list[1].strip(':'), ' '.join(line_list[2:])
if key not in header: # Add new key, value pair if not existant:
header[key] = value
elif key == 'Desc': # Description can go over several lines:
header['Desc'] = ' '.join([header['Desc'], value])
# --- READ DATA LINE BY LINE -----------------------------------------------------------
elif read_data: # Currently in a data block:
if data_mode in ['text', 'Text']: # Read data as text, line by line:
line = load_file.readline()
if line.startswith(b'# End: Data'):
read_data = False # Stop reading data and search for new segments!
continue
elif seg_curr < segment: # Do nothing with the line if wrong segment!
continue
else:
x, y, z = [float(i) for i in line.split()]
x_mag.append(x)
y_mag.append(y)
z_mag.append(z)
elif 'Binary' in data_mode: # Read data as binary, all bytes at the same time:
# Currently every segment is read until the wanted one is processed. Only that one is returned!
count = int(data_mode.split()[-1]) # Either 4 or 8!
if header['version'] == '1.0': # Big endian float:
dtype = f'>f{count}'
elif header['version'] == '2.0': # Little endian float:
dtype = f'<f{count}'
test = np.fromfile(load_file, dtype=dtype, count=1) # Read test byte!
if count == 4: # Binary 4:
assert test == 1234567.0, 'Wrong test bytes!'
elif count == 8: # Binary 8:
assert test == 123456789012345.0, 'Wrong test bytes!'
dim = (int(header['znodes']), int(header['ynodes']), int(header['xnodes']))
data_raw = np.fromfile(load_file, dtype=dtype, count=3*np.prod(dim))
x_mag, y_mag, z_mag = data_raw[0::3], data_raw[1::3], data_raw[2::3]
read_data = False # Stop reading data and search for new segments (if any).
# --- READING DONE -------------------------------------------------------------------------
# Format after reading:
dim = (int(header['znodes']), int(header['ynodes']), int(header['xnodes']))
x_mag = np.asarray(x_mag).reshape(dim)
y_mag = np.asarray(y_mag).reshape(dim)
z_mag = np.asarray(z_mag).reshape(dim)
data = np.stack((x_mag, y_mag, z_mag), axis=-1) * float(header.get('valuemultiplier', 1))
if scale is None:
unit = header.get('meshunit', 'nm')
if unit == 'unspecified':
unit = 'nm'
_log.info(f'unit: {unit}')
unit_scale = {'m': 1e9, 'mm': 1e6, 'µm': 1e3, 'nm': 1}[unit]
xstep = float(header.get('xstepsize')) * unit_scale
ystep = float(header.get('ystepsize')) * unit_scale
zstep = float(header.get('zstepsize')) * unit_scale
scale = (zstep, ystep, xstep)
return Field(data, scale=scale, vector=True)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
assert field.vector and len(field.dim) == 3, 'Only 3D vector fields can be saved to ovf files!'
with open(filename, 'w') as save_file:
save_file.write('# OOMMF OVF 2.0\n')
save_file.write('# Segment count: 1\n')
save_file.write('# Begin: Segment\n')
# Write Header:
save_file.write('# Begin: Header\n')
        name = os.path.split(filename)[1]
save_file.write(f'# Title: PYRAMID-VECTORDATA {name}\n')
save_file.write('# meshtype: rectangular\n')
save_file.write('# meshunit: nm\n')
save_file.write('# valueunit: A/m\n')
save_file.write('# valuemultiplier: 1.\n')
save_file.write('# xmin: 0.\n')
save_file.write('# ymin: 0.\n')
save_file.write('# zmin: 0.\n')
save_file.write(f'# xmax: {field.scale[2] * field.dim[2]}\n')
save_file.write(f'# ymax: {field.scale[1] * field.dim[1]}\n')
save_file.write(f'# zmax: {field.scale[0] * field.dim[0]}\n')
save_file.write('# xbase: 0.\n')
save_file.write('# ybase: 0.\n')
save_file.write('# zbase: 0.\n')
save_file.write(f'# xstepsize: {field.scale[2]}\n')
save_file.write(f'# ystepsize: {field.scale[1]}\n')
save_file.write(f'# zstepsize: {field.scale[0]}\n')
save_file.write(f'# xnodes: {field.dim[2]}\n')
save_file.write(f'# ynodes: {field.dim[1]}\n')
save_file.write(f'# znodes: {field.dim[0]}\n')
save_file.write('# End: Header\n')
# Write data:
save_file.write('# Begin: Data Text\n')
x_mag, y_mag, z_mag = field.comp
x_mag = x_mag.data.ravel()
y_mag = y_mag.data.ravel()
z_mag = z_mag.data.ravel()
for i in range(np.prod(field.dim)):
save_file.write(f'{x_mag[i]:g} {y_mag[i]:g} {z_mag[i]:g}\n')
save_file.write('# End: Data Text\n')
save_file.write('# End: Segment\n')
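# For reference, the text layout that the reader and writer above agree on,
# abbreviated here for a single segment in 'Data Text' mode (see the writer
# for the full list of header keys):
#
#     # OOMMF OVF 2.0
#     # Segment count: 1
#     # Begin: Segment
#     # Begin: Header
#     # meshtype: rectangular
#     # meshunit: nm
#     # xstepsize: 1.0
#     # xnodes: 1
#     ...
#     # End: Header
#     # Begin: Data Text
#     1.0 0.0 0.0
#     # End: Data Text
#     # End: Segment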
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
import re
from numbers import Number
import numpy as np
from ..fields.field import Field
from ..utils.misc import interp_to_regular_grid
_log = logging.getLogger(__name__)
file_extensions = ('.tec',) # Recognised file extensions
def reader(filename, dim, scale=None, vector=None, **kwargs):
_log.debug('Call reader')
if vector is None:
        vector = True
assert vector is True, 'Only vector fields can be loaded from tec-files!'
with open(filename, 'r') as mag_file:
lines = mag_file.readlines() # Read in lines!
match = re.search(R'N=(\d+)', lines[2]) # Extract number of points from third line!
if match:
n_points = int(match.group(1))
else:
raise IOError('File does not seem to match .tec format!')
n_head, n_foot = 3, len(lines) - (3 + n_points)
# Read in data:
data_raw = np.genfromtxt(filename, skip_header=n_head, skip_footer=n_foot)
if scale is None:
raise ValueError('For the interpolation of unstructured grids, the `scale` parameter is required!')
elif isinstance(scale, Number): # Scale is the same for each dimension!
scale = (scale,) * len(dim)
elif isinstance(scale, tuple):
assert len(scale) == len(dim), f'Each of the {len(dim)} dimensions needs a scale, but {scale} was given!'
data = interp_to_regular_grid(data_raw[:, :3], data_raw[:, 3:], scale, **kwargs)
    return Field(data, scale, vector=True)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
raise NotImplementedError('A writer for this extension is not yet implemented!')
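# Usage sketch (illustrative): .tec files hold an unstructured point cloud,
# so the reader needs the target `dim` and a `scale` to interpolate onto a
# regular grid; any extra kwargs are handed to `interp_to_regular_grid`.
# Note that a writer for this format is not implemented yet.
#
#     field = reader('example.tec', dim=(16, 32, 32), scale=1.)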
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
import ast
import numpy as np
from ..fields.field import Field
_log = logging.getLogger(__name__)
file_extensions = ('.txt',) # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
_log.debug('Call reader')
if vector is None:
vector = False
assert vector is False, 'Only scalar 2D fields can currently be read with this file reader!'
with open(filename, 'r') as load_file: # Read data:
empyre_format = load_file.readline().startswith('EMPYRE-FORMAT')
if empyre_format: # File has EMPyRe structure:
scale = ast.literal_eval(load_file.readline()[8:-4]) # [8:-4] takes just the scale string!
data = np.loadtxt(filename, delimiter='\t', skiprows=2) # skips header!
else: # Try default with provided kwargs:
scale = 1 if scale is None else scale # Set default if not provided!
data = np.loadtxt(filename, **kwargs)
return Field(data, scale=scale, vector=False)
def writer(filename, field, with_header=True, **kwargs):
_log.debug('Call writer')
assert not field.vector, 'Vector fields can currently not be saved to text!'
assert len(field.dim) == 2, 'Only 2D fields can currenty be saved to text!'
if with_header: # write header:
with open(filename, 'w') as save_file:
save_file.write('EMPYRE-FORMAT\n')
save_file.write(f'scale = {field.scale} nm\n')
save_kwargs = {'fmt': '%7.6e', 'delimiter': '\t'}
else:
save_kwargs = kwargs
    mode = 'ab' if with_header else 'wb'  # Append after the header, otherwise overwrite!
    with open(filename, mode) as save_file:
np.savetxt(save_file, field.data, **save_kwargs)
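# Usage sketch (illustrative): with `with_header=True` the first two lines of
# the file carry the EMPYRE-FORMAT marker and the scale, which the reader
# recovers via `ast.literal_eval`; without the header, plain `np.loadtxt`
# kwargs apply instead.
#
#     field = Field(np.eye(4), scale=(2., 3.), vector=False)
#     writer('example.txt', field)
#     loaded = reader('example.txt')  # scale (2., 3.) is read from the header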
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO plugin for simple text format."""
import logging
from numbers import Number
import numpy as np
from ..fields.field import Field
from ..utils.misc import interp_to_regular_grid
from ..vis import colors
_log = logging.getLogger(__name__)
file_extensions = ('.vtk',) # Recognised file extensions
def reader(filename, scale=None, vector=None, **kwargs):
"""More infos at:
overview: https://docs.enthought.com/mayavi/mayavi/data.html
writing: https://vtk.org/Wiki/VTK/Writing_VTK_files_using_python
format: https://www.vtk.org/wp-content/uploads/2015/04/file-formats.pdf
"""
_log.debug('Calling reader')
try:
from tvtk.api import tvtk
except ImportError:
        _log.error('This extension requires the tvtk package!')
return
if vector is None:
vector = True
# Setting up reader:
reader = tvtk.DataSetReader(file_name=filename, read_all_scalars=True, read_all_vectors=vector)
reader.update()
# Getting output:
output = reader.output
assert output is not None, 'File reader could not find data or file "{}"!'.format(filename)
# Reading points and vectors:
if isinstance(output, tvtk.ImageData): # tvtk.StructuredPoints is a subclass of tvtk.ImageData!
# Connectivity: implicit; described by: 3D data array and spacing along each axis!
_log.info('geometry: ImageData')
# Load relevant information from output (reverse to get typical Python order z,y,x):
dim = output.dimensions[::-1]
origin = output.origin[::-1]
spacing = output.spacing[::-1]
_log.info(f'dim: {dim}, origin: {origin}, spacing: {spacing}')
assert len(dim) == 3, 'Data has to be three-dimensional!'
if scale is None:
scale = tuple(spacing)
if vector: # Extract vector compontents and create magnitude array:
            vector_array = np.asarray(output.point_data.vectors, dtype=float)
x_mag, y_mag, z_mag = vector_array.T
data = np.stack((x_mag.reshape(dim), y_mag.reshape(dim), z_mag.reshape(dim)), axis=-1)
else: # Extract scalar data and create magnitude array:
            scalar_array = np.asarray(output.point_data.scalars, dtype=float)
data = scalar_array.reshape(dim)
elif isinstance(output, tvtk.RectilinearGrid):
# Connectivity: implicit; described by: 3D data array and 1D array of spacing for each axis!
_log.info('geometry: RectilinearGrid')
raise NotImplementedError('RectilinearGrid is currently not supported!')
elif isinstance(output, tvtk.StructuredGrid):
# Connectivity: implicit; described by: 3D data array and 3D position arrays for each axis!
_log.info('geometry: StructuredGrid')
raise NotImplementedError('StructuredGrid is currently not supported!')
elif isinstance(output, tvtk.PolyData):
# Connectivity: explicit; described by: x, y, z positions of vertices and arrays of surface cells!
_log.info('geometry: PolyData')
raise NotImplementedError('PolyData is currently not supported!')
elif isinstance(output, tvtk.UnstructuredGrid):
# Connectivity: explicit; described by: x, y, z positions of vertices and arrays of volume cells!
_log.info('geometry: UnstructuredGrid')
# Load relevant information from output:
        point_array = np.asarray(output.points, dtype=float)
        if vector:
            data_array = np.asarray(output.point_data.vectors, dtype=float)
        else:
            data_array = np.asarray(output.point_data.scalars, dtype=float)
if scale is None:
raise ValueError('For the interpolation of unstructured grids, the `scale` parameter is required!')
        elif isinstance(scale, Number):  # Scale is the same for each dimension!
            scale = (scale,) * 3  # Data is always three-dimensional here!
        elif isinstance(scale, tuple):
            assert len(scale) == 3, f'Each of the 3 dimensions needs a scale, but {scale} was given!'
data = interp_to_regular_grid(point_array, data_array, scale, **kwargs)
else:
raise TypeError('Data type of {} not understood!'.format(output))
    return Field(data, scale, vector=vector)
def writer(filename, field, **kwargs):
_log.debug('Call writer')
assert len(field.dim) == 3, 'Currently only 3D fields can be saved to vtk!'
try:
from tvtk.api import tvtk, write_data
except ImportError:
        _log.error('This extension requires the tvtk package!')
return
# Create dataset:
origin = (0, 0, 0)
    spacing = (field.scale[2], field.scale[1], field.scale[0])
dimensions = (field.dim[2], field.dim[1], field.dim[0])
sp = tvtk.StructuredPoints(origin=origin, spacing=spacing, dimensions=dimensions)
# Fill with data from field:
if field.vector: # Handle vector fields:
# Put vector components in corresponding array:
vectors = field.data.reshape(-1, 3)
sp.point_data.vectors = vectors
sp.point_data.vectors.name = 'vectors'
# Calculate colors:
x_mag, y_mag, z_mag = field.comp
magvec = np.asarray((x_mag.data.ravel(), y_mag.data.ravel(), z_mag.data.ravel()))
rgb = colors.CMAP_CIRCULAR_DEFAULT.rgb_from_vector(magvec)
point_colors = tvtk.UnsignedIntArray()
point_colors.number_of_components = 3
point_colors.name = 'colors'
point_colors.from_array(rgb)
sp.point_data.scalars = point_colors
sp.point_data.scalars.name = 'colors'
else: # Handle scalar fields:
scalars = field.data.ravel()
sp.point_data.scalars = scalars
sp.point_data.scalars.name = 'scalars'
# Write the data to file:
write_data(sp, filename)
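# Usage sketch (illustrative, requires the tvtk package): structured data
# round-trips directly; only for unstructured grids does the reader need an
# explicit `scale` to build the regular target grid.
#
#     field = Field(np.random.rand(8, 8, 8, 3) - 0.5, scale=1., vector=True)
#     writer('example.vtk', field)
#     loaded = reader('example.vtk')  # vector=True is the default here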
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""IO functionality for Field objects."""
import logging
import os
from ..fields.field import Field
from .field_plugins import plugin_list
__all__ = ['load_field', 'save_field']
_log = logging.getLogger(__name__)
def load_field(filename, scale=None, vector=None, **kwargs):
"""Load supported file into a :class:`~..fields.Field` instance.
The function loads the file according to the extension:
SCALAR???
- hdf5 for HDF5.
- EMD Electron Microscopy Dataset format (also HDF5).
- npy or npz for numpy formats.
PHASEMAP
- hdf5 for HDF5.
- rpl for Ripple (useful to export to Digital Micrograph).
- dm3 and dm4 for Digital Micrograph files.
- unf for SEMPER unf binary format.
- txt format.
- npy or npz for numpy formats.
- Many image formats such as png, tiff, jpeg...
VECTOR
- hdf5 for HDF5.
- EMD Electron Microscopy Dataset format (also HDF5).
- llg format.
- ovf format.
- npy or npz for numpy formats.
Any extra keyword is passed to the corresponsing reader. For available options see their individual documentation.
Parameters
----------
filename: str
The filename to be loaded.
scale: tuple of float, optional
Scaling along the dimensions of the underlying data. Default is 1.
    vector: bool, optional
        True if the field should be a vector field, False if it should be interpreted as a scalar field. If not
        set, the plugin chooses a default suitable for its format.
Returns
-------
field: `Field`
A `Field` object containing the loaded data.
Notes
-----
    Falls back to HyperSpy routines for loading data; make sure it is installed if you need the full capabilities.
"""
_log.debug('Calling load_field')
extension = os.path.splitext(filename)[1]
    for plugin in plugin_list:  # Iterate over all plugins:
        if extension in plugin.file_extensions:  # Check if the extension is recognised:
            return plugin.reader(filename, scale=scale, vector=vector, **kwargs)
# If nothing was found, try HyperSpy:
_log.debug('Using HyperSpy')
try:
import hyperspy.api as hs
except ImportError:
_log.error('This extension recquires the hyperspy package!')
return
return Field.from_signal(hs.load(filename, **kwargs), scale, vector)
def save_field(filename, field, **kwargs):
"""Saves the Field in the specified format.
    The function gets the format from the extension:
    - llg for LLG format.
    - npy for numpy format.
    - ovf for OVF format.
    - txt for simple text format.
    - vtk for VTK format.
    If no extension is provided, 'hdf5' is used. Most other formats are saved with the HyperSpy package (internally
    the field is first converted to a HyperSpy Signal).
Each format accepts a different set of parameters. For details see the specific format documentation.
Parameters
----------
filename : str, optional
Name of the file which the Field is saved into. The extension determines the saving procedure.
"""
"""Saves the phasemap in the specified format.
The function gets the format from the extension:
- hdf5 for HDF5.
- rpl for Ripple (useful to export to Digital Micrograph).
- unf for SEMPER unf binary format.
- txt format.
- Many image formats such as png, tiff, jpeg...
If no extension is provided, 'hdf5' is used. Most formats are
saved with the HyperSpy package (internally the phasemap is first
converted to a HyperSpy Signal.
Each format accepts a different set of parameters. For details
see the specific format documentation.
Parameters
----------
filename: str, optional
Name of the file which the phasemap is saved into. The extension
determines the saving procedure.
save_mask: boolean, optional
If True, the `mask` is saved, too. For all formats, except HDF5, a separate file will
be created. HDF5 always saves the `mask` in the metadata, independent of this flag. The
default is False.
save_conf: boolean, optional
If True, the `confidence` is saved, too. For all formats, except HDF5, a separate file
will be created. HDF5 always saves the `confidence` in the metadata, independent of
this flag. The default is False
pyramid_format: boolean, optional
Only used for saving to '.txt' files. If this is True, the grid spacing is saved
in an appropriate header. Otherwise just the phase is written with the
corresponding `kwargs`.
"""
_log.debug('Calling save_field')
extension = os.path.splitext(filename)[1]
    for plugin in plugin_list:  # Iterate over all plugins:
        if extension in plugin.file_extensions:  # Check if the extension is recognised:
            return plugin.writer(filename, field, **kwargs)
# If nothing was found, try HyperSpy:
_log.debug('Using HyperSpy')
field.to_signal().save(filename, **kwargs)
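# Usage sketch (illustrative): the extension picks the plugin; anything
# unrecognised falls through to HyperSpy.
#
#     field = load_field('magnetization.ovf', vector=True)
#     save_field('magnetization.npy', field)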
# -*- coding: utf-8 -*-
# Copyright 2020 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""This module provides functions for 3D plots based on the `mayavi` library."""
import logging
import numpy as np
from . import colors
__all__ = ['contour3d', 'mask3d', 'quiver3d']
_log = logging.getLogger(__name__)
# TODO: Docstrings and signature!
def contour3d(field, title='Field Distribution', contours=10, opacity=0.25, size=None, new_fig=True, **kwargs):
"""Plot a field as a 3D-contour plot.
Parameters
----------
title: string, optional
The title for the plot.
contours: int, optional
Number of contours which should be plotted.
opacity: float, optional
Defines the opacity of the contours. Default is 0.25.
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
_log.debug('Calling contour3d')
try:
from mayavi import mlab
except ImportError:
        _log.error('This extension requires the mayavi package!')
return
if size is None:
size = (750, 700)
if new_fig:
mlab.figure(size=size, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.))
    zzz, yyy, xxx = np.indices(field.dim) + 0.5  # Positions shifted by half a cell (extent is in cell indices)!
zzz, yyy, xxx = zzz.T, yyy.T, xxx.T # Transpose because of VTK order!
field_amp = field.amp.T # Transpose because of VTK order!
if not isinstance(contours, (list, tuple, np.ndarray)): # Calculate the contours:
contours = list(np.linspace(field_amp.min(), field_amp.max(), contours))
extent = np.ravel(list(zip((0, 0, 0), field_amp.shape)))
cont = mlab.contour3d(xxx, yyy, zzz, field_amp, contours=contours, opacity=opacity, **kwargs)
mlab.outline(cont, extent=extent)
mlab.axes(cont, extent=extent)
mlab.title(title, height=0.95, size=0.35)
mlab.orientation_axes()
cont.scene.isometric_view()
return cont
def mask3d(field, title='Mask', threshold=0, grid=True, labels=True,
orientation=True, size=None, new_fig=True, **kwargs):
"""Plot the mask as a 3D-contour plot.
Parameters
----------
title: string, optional
The title for the plot.
threshold : float, optional
        A pixel only gets masked if it lies above this threshold. The default is 0.
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
_log.debug('Calling mask3d')
try:
from mayavi import mlab
except ImportError:
        _log.error('This extension requires the mayavi package!')
return
if size is None:
size = (750, 700)
if new_fig:
mlab.figure(size=size, bgcolor=(0.5, 0.5, 0.5), fgcolor=(0., 0., 0.))
    zzz, yyy, xxx = np.indices(field.dim) + 0.5  # Positions shifted by half a cell (extent is in cell indices)!
zzz, yyy, xxx = zzz.T, yyy.T, xxx.T # Transpose because of VTK order!
mask = field.mask.astype(int).T # Transpose because of VTK order!
extent = np.ravel(list(zip((0, 0, 0), mask.shape)))
cont = mlab.contour3d(xxx, yyy, zzz, mask, contours=[1], **kwargs)
if grid:
mlab.outline(cont, extent=extent)
if labels:
mlab.axes(cont, extent=extent)
mlab.title(title, height=0.95, size=0.35)
if orientation:
oa = mlab.orientation_axes()
oa.marker.set_viewport(0, 0, 0.4, 0.4)
mlab.draw()
engine = mlab.get_engine()
scene = engine.scenes[0]
scene.scene.isometric_view()
return cont
def quiver3d(field, title='Vector Field', limit=None, cmap='jet', mode='2darrow',
coloring='angle', ar_dens=1, opacity=1.0, grid=True, labels=True,
orientation=True, size=(700, 750), new_fig=True, view='isometric',
position=None, bgcolor=(0.5, 0.5, 0.5)):
"""Plot the vector field as 3D-vectors in a quiverplot.
Parameters
----------
title : string, optional
The title for the plot.
limit : float, optional
Plotlimit for the vector field arrow length used to scale the colormap.
cmap : string, optional
String describing the colormap which is used for amplitude encoding (default is 'jet').
ar_dens: int, optional
Number defining the arrow density which is plotted. A higher ar_dens number skips more
arrows (a number of 2 plots every second arrow). Default is 1.
mode: string, optional
Mode, determining the glyphs used in the 3D plot. Default is '2darrow', which
corresponds to 2D arrows. For smaller amounts of arrows, 'arrow' (3D) is prettier.
coloring : {'angle', 'amplitude'}, optional
Color coding mode of the arrows. Use 'angle' (default) or 'amplitude'.
opacity: float, optional
Defines the opacity of the arrows. Default is 1.0 (completely opaque).
Returns
-------
plot : :class:`mayavi.modules.vectors.Vectors`
The plot object.
"""
    _log.debug('Calling quiver3d')
try:
from mayavi import mlab
except ImportError:
        _log.error('This extension requires the mayavi package!')
return
if limit is None:
limit = np.max(np.nan_to_num(field.amp))
ad = ar_dens
# Create points and vector components as lists:
zzz, yyy, xxx = (np.indices(field.dim) + 1 / 2)
zzz = zzz[::ad, ::ad, ::ad].ravel()
yyy = yyy[::ad, ::ad, ::ad].ravel()
xxx = xxx[::ad, ::ad, ::ad].ravel()
x_mag = field.data[::ad, ::ad, ::ad, 0].ravel()
y_mag = field.data[::ad, ::ad, ::ad, 1].ravel()
z_mag = field.data[::ad, ::ad, ::ad, 2].ravel()
# Plot them as vectors:
if new_fig:
        mlab.figure(size=size, bgcolor=bgcolor, fgcolor=(0., 0., 0.))
if coloring == 'angle': # Encodes the full angle via colorwheel and saturation:
_log.debug('Encoding full 3D angles')
vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag, mode=mode, opacity=opacity,
scalars=np.arange(len(xxx)), line_width=2)
        vector = np.asarray((x_mag, y_mag, z_mag))  # Components are already raveled above!
rgb = colors.CMAP_CIRCULAR_DEFAULT.rgb_from_vector(vector)
rgba = np.hstack((rgb, 255 * np.ones((len(xxx), 1), dtype=np.uint8)))
vecs.glyph.color_mode = 'color_by_scalar'
vecs.module_manager.scalar_lut_manager.lut.table = rgba
mlab.draw()
elif coloring == 'amplitude': # Encodes the amplitude of the arrows with the jet colormap:
_log.debug('Encoding amplitude')
vecs = mlab.quiver3d(xxx, yyy, zzz, x_mag, y_mag, z_mag,
mode=mode, colormap=cmap, opacity=opacity, line_width=2)
        mlab.colorbar(label_fmt='%.2f', orientation='vertical')
else:
        raise ValueError(f'Coloring mode {coloring} not supported!')
vecs.glyph.glyph_source.glyph_position = 'center'
vecs.module_manager.vector_lut_manager.data_range = np.array([0, limit])
extent = np.ravel(list(zip((0, 0, 0), (field.dim[2], field.dim[1], field.dim[0]))))
if grid:
mlab.outline(vecs, extent=extent)
if labels:
mlab.axes(vecs, extent=extent)
mlab.title(title, height=0.95, size=0.35)
if orientation:
oa = mlab.orientation_axes()
oa.marker.set_viewport(0, 0, 0.4, 0.4)
mlab.draw()
engine = mlab.get_engine()
scene = engine.scenes[0]
if view == 'isometric':
scene.scene.isometric_view()
elif view == 'x_plus_view':
scene.scene.x_plus_view()
elif view == 'y_plus_view':
scene.scene.y_plus_view()
if position:
scene.scene.camera.position = position
return vecs
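# Usage sketch (illustrative, requires the mayavi package): all three plot
# functions take an EMPyRe Field instance; `quiver3d` can thin out the arrows
# via `ar_dens` and color them by angle (default) or amplitude.
#
#     from empyre.fields import Field  # import path is an assumption!
#     field = Field(np.random.rand(16, 16, 16, 3) - 0.5, scale=1., vector=True)
#     quiver3d(field, ar_dens=2, coloring='amplitude')
#     mask3d(field, threshold=0.1)
#     contour3d(field, contours=5)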