Skip to content
Snippets Groups Projects
Commit 4f1991fd authored by Jan Caron's avatar Jan Caron
Browse files

Switched from Pyramid fft (obsolete) to Jutil fft.

Renaming of variables in PhaseMapperRDFC.jac_T_dot to make more sense...
parent e1014166
No related branches found
No related tags found
No related merge requests found
......@@ -38,14 +38,11 @@ quaternion
Class which is used for easy rotations in the Projector classes.
colormap
Class which implements a custom direction encoding colormap.
fft
Class for custom FFT functions using numpy or FFTW.
"""
from . import analytic
from . import reconstruction
from . import fft
from . import fieldconverter
from . import magcreator
from . import colors
......@@ -71,7 +68,7 @@ _log = logging.getLogger(__name__)
_log.info("Starting Pyramid V{} HG{}".format(__version__, __hg_revision__))
del logging
__all__ = ['analytic', 'magcreator', 'reconstruction', 'fft', 'fieldconverter',
__all__ = ['analytic', 'magcreator', 'reconstruction', 'fieldconverter',
'colors', 'utils', 'load_phasemap', 'load_vectordata']
__all__.extend(costfunction.__all__)
__all__.extend(dataset.__all__)
......
......@@ -7,7 +7,8 @@ specified costfunction for a fixed magnetization distribution."""
import logging
from pyramid import fft
from jutil import fft
from pyramid.fielddata import VectorData
from pyramid.phasemap import PhaseMap
......@@ -64,7 +65,7 @@ class Diagnostics(object):
the calculation of the gain and averaging kernel matrizes and which ideally contains the
variance at position `row_idx` for the current component and position in 3D."""
if not self._updated_cov_row:
e_i = fft.zeros(self.cost.n, dtype=fft.FLOAT)
e_i = np.zeros(self.cost.n, dtype=self.x_rec.dtype)
e_i[self.row_idx] = 1
row = 2 * jutil.cg.conj_grad_solve(self._A, e_i, P=self._P, max_iter=self.max_iter)
self._std_row = np.asarray(row)
......@@ -101,7 +102,7 @@ class Diagnostics(object):
the solution is determined by the measurement (close to `1`) or by a priori information
(close to `0`)."""
if not self._updated_measure_contribution:
cache = self.fwd_model.jac_dot(self.x_rec, fft.ones(self.cost.n, fft.FLOAT))
cache = self.fwd_model.jac_dot(self.x_rec, np.ones(self.cost.n, self.x_rec.dtype))
cache = self.fwd_model.jac_T_dot(self.x_rec, self.Se_inv.dot(cache))
mc = 2 * jutil.cg.conj_grad_solve(self._A, cache, P=self._P, max_iter=self.max_iter)
self._measure_contribution = mc
......@@ -165,7 +166,7 @@ class Diagnostics(object):
self._log.debug('Calling get_avg_kern_row')
if pos is not None:
self.pos = pos
magdata_avg_kern = VectorData(self.cost.data_set.a, fft.zeros((3,) + self.dim))
magdata_avg_kern = VectorData(self.cost.data_set.a, np.zeros((3,) + self.dim))
magdata_avg_kern.set_vector(self.avrg_kern_row, mask=self.mask)
return magdata_avg_kern
......
......@@ -11,7 +11,8 @@ import logging
import numpy as np
from pyramid import fft
from jutil import fft
from pyramid.fielddata import VectorData
__all__ = ['convert_M_to_A', 'convert_A_to_B', 'convert_M_to_B']
......@@ -39,12 +40,10 @@ def convert_M_to_A(magdata, b_0=1.0):
assert isinstance(magdata, VectorData), 'Only VectorData objects can be mapped!'
dim = magdata.dim
dim_kern = tuple(2 * np.array(dim) - 1) # Dimensions of the kernel
if fft.BACKEND == 'pyfftw':
if fft.HAVE_FFTW:
dim_pad = tuple(2 * np.array(dim)) # is at least even (not neccessary a power of 2)
elif fft.BACKEND == 'numpy':
dim_pad = tuple(2 ** np.ceil(np.log2(2 * np.array(dim))).astype(int)) # pow(2)
else:
raise ValueError('Backend of the fft module is not correctly initiated!')
dim_pad = tuple(2 ** np.ceil(np.log2(2 * np.array(dim))).astype(int)) # pow(2)
slice_B = (slice(dim[0] - 1, dim_kern[0]), # Shift because kernel center
slice(dim[1] - 1, dim_kern[1]), # is not at (0, 0, 0)!
slice(dim[2] - 1, dim_kern[2]))
......@@ -57,9 +56,9 @@ def convert_M_to_A(magdata, b_0=1.0):
xxx -= dim[2] - 1
yyy -= dim[1] - 1
zzz -= dim[0] - 1
k_x = fft.empty(dim_kern, dtype=fft.FLOAT)
k_y = fft.empty(dim_kern, dtype=fft.FLOAT)
k_z = fft.empty(dim_kern, dtype=fft.FLOAT)
k_x = np.empty(dim_kern, dtype=magdata.field.dtype)
k_y = np.empty(dim_kern, dtype=magdata.field.dtype)
k_z = np.empty(dim_kern, dtype=magdata.field.dtype)
k_x[...] = coeff * xxx / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3
k_y[...] = coeff * yyy / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3
k_z[...] = coeff * zzz / np.abs(xxx ** 2 + yyy ** 2 + zzz ** 2 + 1E-30) ** 3
......@@ -68,9 +67,9 @@ def convert_M_to_A(magdata, b_0=1.0):
k_y_fft = fft.rfftn(k_y, dim_pad)
k_z_fft = fft.rfftn(k_z, dim_pad)
# Prepare magnetization:
x_mag = fft.zeros(dim_pad, dtype=fft.FLOAT)
y_mag = fft.zeros(dim_pad, dtype=fft.FLOAT)
z_mag = fft.zeros(dim_pad, dtype=fft.FLOAT)
x_mag = np.zeros(dim_pad, dtype=magdata.field.dtype)
y_mag = np.zeros(dim_pad, dtype=magdata.field.dtype)
z_mag = np.zeros(dim_pad, dtype=magdata.field.dtype)
x_mag[slice_M] = magdata.field[0, ...]
y_mag[slice_M] = magdata.field[1, ...]
z_mag[slice_M] = magdata.field[2, ...]
......
......@@ -19,7 +19,8 @@ from PIL import Image
from scipy.ndimage.interpolation import zoom
from . import fft
from jutil import fft
from . import colors
__all__ = ['VectorData', 'ScalarData']
......@@ -79,7 +80,7 @@ class FieldData(object, metaclass=abc.ABCMeta):
assert 3 <= len(field.shape) <= 4, 'Field has to be 3- or 4-dimensional (scalar / vector)!'
if len(field.shape) == 4:
assert field.shape[0] == 3, 'A vector field has to have exactly 3 components!'
self._field = np.asarray(field, dtype=fft.FLOAT)
self._field = field
@property
def field_amp(self):
......@@ -97,7 +98,6 @@ class FieldData(object, metaclass=abc.ABCMeta):
@field_vec.setter
def field_vec(self, mag_vec):
mag_vec = np.asarray(mag_vec, dtype=fft.FLOAT)
assert np.size(mag_vec) == np.prod(self.shape), \
'Vector has to match field shape! {} {}'.format(mag_vec.shape, np.prod(self.shape))
self.field = mag_vec.reshape((3,) + self.dim)
......@@ -565,7 +565,6 @@ class VectorData(FieldData):
"""
self._log.debug('Calling set_vector')
vector = np.asarray(vector, dtype=fft.FLOAT)
assert np.size(vector) % 3 == 0, 'Vector has to contain all 3 components for every pixel!'
count = np.size(vector) // 3
if mask is not None:
......@@ -1298,7 +1297,6 @@ class ScalarData(FieldData):
"""
self._log.debug('Calling set_vector')
vector = np.asarray(vector, dtype=fft.FLOAT)
if mask is not None:
self.field[mask] = vector
else:
......
......@@ -9,7 +9,7 @@ import logging
import numpy as np
from pyramid import fft
from jutil import fft
__all__ = ['Kernel', 'PHI_0']
......@@ -64,12 +64,14 @@ class Kernel(object):
prw_vec: tuple of 2 int, optional
A two-component vector describing the displacement of the reference wave to include
perturbation of this reference by the object itself (via fringing fields).
dtype: numpy dtype, optional
Data type of the kernel. Default is np.float32.
"""
_log = logging.getLogger(__name__ + '.Kernel')
def __init__(self, a, dim_uv, b_0=1., prw_vec=None, geometry='disc'):
def __init__(self, a, dim_uv, b_0=1., prw_vec=None, geometry='disc', dtype=np.float32):
self._log.debug('Calling __init__')
# Set basic properties:
self.dim_uv = dim_uv # Dimensions of the FOV
......@@ -77,9 +79,9 @@ class Kernel(object):
self.a = a
self.geometry = geometry
# Set up FFT:
if fft.BACKEND == 'pyfftw':
if fft.HAVE_FFTW:
self.dim_pad = tuple(2 * np.array(dim_uv)) # is at least even (not nec. power of 2)
elif fft.BACKEND == 'numpy':
else:
self.dim_pad = tuple(2 ** np.ceil(np.log2(2 * np.array(dim_uv))).astype(int)) # pow(2)
self.dim_fft = (self.dim_pad[0], self.dim_pad[1] // 2 + 1) # last axis is real
self.slice_phase = (slice(dim_uv[0] - 1, self.dim_kern[0]), # Shift because kernel center
......@@ -96,8 +98,8 @@ class Kernel(object):
u = np.linspace(-(u_dim - 1), u_dim - 1, num=2 * u_dim - 1)
v = np.linspace(-(v_dim - 1), v_dim - 1, num=2 * v_dim - 1)
uu, vv = np.meshgrid(u, v)
self.u = fft.empty(self.dim_kern, dtype=fft.FLOAT)
self.v = fft.empty(self.dim_kern, dtype=fft.FLOAT)
self.u = np.empty(self.dim_kern, dtype=dtype)
self.v = np.empty(self.dim_kern, dtype=dtype)
self.u[...] = coeff * self._get_elementary_phase(geometry, uu, vv, a)
self.v[...] = coeff * -self._get_elementary_phase(geometry, vv, uu, a)
# Include perturbed reference wave:
......
......@@ -12,7 +12,8 @@ import logging
import numpy as np
from . import fft
from jutil import fft
from .fielddata import VectorData, ScalarData
from .phasemap import PhaseMap
......@@ -107,9 +108,9 @@ class PhaseMapperRDFC(PhaseMapper):
self.kernel = kernel
self.m = np.prod(kernel.dim_uv)
self.n = 2 * self.m
self.u_mag = fft.zeros(kernel.dim_pad, dtype=fft.FLOAT)
self.v_mag = fft.zeros(kernel.dim_pad, dtype=fft.FLOAT)
self.mag_adj = fft.zeros(kernel.dim_pad, dtype=fft.FLOAT)
self.u_mag = np.zeros(kernel.dim_pad, dtype=kernel.u.dtype)
self.v_mag = np.zeros(kernel.dim_pad, dtype=kernel.u.dtype)
self.phase_adj = np.zeros(kernel.dim_pad, dtype=kernel.u.dtype)
self._log.debug('Created ' + str(self))
def __repr__(self):
......@@ -179,14 +180,13 @@ class PhaseMapperRDFC(PhaseMapper):
"""
assert len(vector) == self.m, \
'vector size not compatible! vector: {}, size: {}'.format(len(vector), self.m)
self.mag_adj[self.kernel.slice_phase] = vector.reshape(self.kernel.dim_uv)
mag_adj_fft = fft.irfftn_adj(self.mag_adj)
u_phase_adj_fft = mag_adj_fft * np.conj(self.kernel.u_fft)
v_phase_adj_fft = mag_adj_fft * np.conj(self.kernel.v_fft)
u_phase_adj = fft.rfftn_adj(u_phase_adj_fft)[self.kernel.slice_mag]
v_phase_adj = fft.rfftn_adj(v_phase_adj_fft)[self.kernel.slice_mag]
result = np.concatenate((u_phase_adj.ravel(), v_phase_adj.ravel()))
# TODO: Why minus?
self.phase_adj[self.kernel.slice_phase] = vector.reshape(self.kernel.dim_uv)
phase_adj_fft = fft.irfft2_adj(self.phase_adj)
u_mag_adj_fft = phase_adj_fft * np.conj(self.kernel.u_fft)
v_mag_adj_fft = phase_adj_fft * np.conj(self.kernel.v_fft)
u_mag_adj = fft.rfft2_adj(u_mag_adj_fft)[self.kernel.slice_mag]
v_mag_adj = fft.rfft2_adj(v_mag_adj_fft)[self.kernel.slice_mag]
result = np.concatenate((u_mag_adj.ravel(), v_mag_adj.ravel()))
return result
......
......@@ -79,12 +79,12 @@ class Projector(object, metaclass=abc.ABCMeta):
def __call__(self, field_data):
if isinstance(field_data, VectorData):
field_empty = fft.zeros((3, 1) + self.dim_uv, dtype=fft.FLOAT)
field_empty = np.zeros((3, 1) + self.dim_uv, dtype=field_data.field.dtype)
field_data_proj = VectorData(field_data.a, field_empty)
field_proj = self.jac_dot(field_data.field_vec).reshape((2,) + self.dim_uv)
field_data_proj.field[0:2, 0, ...] = field_proj
elif isinstance(field_data, ScalarData):
field_empty = fft.zeros((1,) + self.dim_uv, dtype=fft.FLOAT)
field_empty = np.zeros((1,) + self.dim_uv, dtype=field_data.field.dtype)
field_data_proj = ScalarData(field_data.a, field_empty)
field_proj = self.jac_dot(field_data.field_vec).reshape(self.dim_uv)
field_data_proj.field[0, ...] = field_proj
......@@ -93,7 +93,7 @@ class Projector(object, metaclass=abc.ABCMeta):
return field_data_proj
def _vector_field_projection(self, vector):
result = fft.zeros(2 * self.size_2d, dtype=fft.FLOAT)
result = np.zeros(2 * self.size_2d, dtype=vector.dtype)
# Go over all possible component projections (z, y, x) to (u, v):
vec_x, vec_y, vec_z = np.split(vector, 3)
vec_x_weighted = self.weight.dot(vec_x)
......@@ -116,7 +116,7 @@ class Projector(object, metaclass=abc.ABCMeta):
return result
def _vector_field_projection_T(self, vector):
result = np.zeros(3 * self.size_3d, dtype=fft.FLOAT)
result = np.zeros(3 * self.size_3d)
# Go over all possible component projections (u, v) to (z, y, x):
vec_u, vec_v = np.split(vector, 2)
vec_u_weighted = self.weight.T.dot(vec_u)
......
......@@ -143,7 +143,7 @@ setup(name=DISTNAME,
packages=find_packages(exclude=['tests']),
include_dirs=[numpy.get_include()],
requires=['numpy', 'scipy', 'matplotlib', 'Pillow',
'mayavi', 'pyfftw', 'hyperspy', 'nose'],
'mayavi', 'pyfftw', 'hyperspy', 'nose', 'jutil'],
test_suite='nose.collector',
cmdclass={'build_ext': build_ext, 'build': build})
print('-------------------------------------------------------------------------------\n')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment