# -*- coding: utf-8 -*-
# Copyright 2014 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Custom FFT module with numpy and FFTW support.
This module provides custom methods for FFTs, including inverse, adjoint and real variants. The
FFTW library is supported and used by default if its import succeeds; otherwise the numpy.fft
package is used. FFTW objects are cached after creation, which speeds up subsequent similar
FFT operations.
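Examples
--------
A minimal usage sketch (assuming the module is importable as ``pyramid.fft``):
>>> from pyramid import fft
>>> arr = fft.zeros((4, 8))  # SIMD-aligned float32 zeros if pyFFTW is available
>>> spec = fft.rfftn(arr)    # real-valued FFT with the active backend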
"""
import pickle
import logging
import os
import numpy as np
from pyramid.config import NTHREADS
_log = logging.getLogger(__name__)
try:
import pyfftw
BACKEND = 'fftw'
except ImportError:
pyfftw = None
BACKEND = 'numpy'
_log.info('pyFFTW module not found. Using numpy implementation.')
__all__ = ['PLANS', 'FLOAT', 'COMPLEX', 'dump_wisdom', 'load_wisdom', # analysis:ignore
'zeros', 'empty', 'configure_backend',
'fftn', 'ifftn', 'rfftn', 'irfftn', 'rfftn_adj', 'irfftn_adj']
class FFTWCache(object):
"""Class for adding FFTW Plans and on-demand lookups.
This class is instantiated in this module to store FFTW plans and for the lookup of the former.
Attributes
----------
cache: dict
Cache for storing the FFTW plans.
Notes
-----
    This class is used internally and is not normally intended to be used directly by the user.
"""
_log = logging.getLogger(__name__ + '.FFTWCache')
def __init__(self):
self._log.debug('Calling __init__')
self.cache = dict()
self._log.debug('Created ' + str(self))
def add_fftw(self, fft_type, fftw_obj, s, axes, nthreads):
"""Add an FFTW object to the cache.
Parameters
----------
        fft_type: string
            Identifier string for the FFT type ('fftn', 'ifftn', 'rfftn', 'irfftn').
fftw_obj: :class:`~pyfftw.FFTW` object
The FFTW object which should be added to the cache.
s: tuple of ints
Shape of the output array.
axes: tuple of ints
The axes along which the FFTW should be executed.
nthreads: int
Number of threads which should be used.
"""
self._log.debug('Calling add_fftw')
in_arr = fftw_obj.get_input_array()
key = (fft_type, in_arr.shape, in_arr.dtype, s, axes, nthreads)
self.cache[key] = fftw_obj
def lookup_fftw(self, fft_type, in_arr, s, axes, nthreads):
"""
Parameters
----------
        fft_type: string
            Identifier string for the FFT type ('fftn', 'ifftn', 'rfftn', 'irfftn').
        in_arr: :class:`~numpy.ndarray`
            Input array; internally only its `dtype` and `shape` are used to identify the FFT.
s: tuple of ints
Shape of the output array.
axes: tuple of ints
The axes along which the FFTW should be executed.
nthreads: int
Number of threads which should be used.
Returns
-------
fftw_obj: :class:`~pyfftw.FFTW` object
The requested FFTW object.
"""
self._log.debug('Calling lookup_fftw')
key = (fft_type, in_arr.shape, in_arr.dtype, s, axes, nthreads)
return self.cache.get(key, None)
def clear_cache(self):
"""Clear the cache."""
self._log.debug('Calling clear_cache')
self.cache = dict()
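# Module-level plan cache; plans are keyed by FFT type, input shape and dtype, output shape,
# axes and thread count.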
PLANS = FFTWCache()
FLOAT = np.float32 # One convenient place to
COMPLEX = np.complex64 # change from 32 to 64 bit
# Numpy functions:
def _fftn_numpy(a, s=None, axes=None):
return np.fft.fftn(a, s, axes)
def _ifftn_numpy(a, s=None, axes=None):
return np.fft.ifftn(a, s, axes)
def _rfftn_numpy(a, s=None, axes=None):
return np.fft.rfftn(a, s, axes)
def _irfftn_numpy(a, s=None, axes=None):
return np.fft.irfftn(a, s, axes)
def _rfftn_adj_numpy(a):
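    # Adjoint of rfftn: embed the half spectrum `a` into a zero-padded full spectrum along the
    # last axis, apply the inverse FFT and keep the real part; multiplying by prod(out_shape)
    # undoes the 1/N scaling of ifftn. Like the FFTW variant further below, this assumes the
    # last dimension of the original real array was even.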
n = 2 * (a.shape[-1] - 1)
out_shape = a.shape[:-1] + (n,)
out_arr = zeros(out_shape, dtype=a.dtype)
    out_arr[:, :a.shape[-1]] = a
return _ifftn_numpy(out_arr).real * np.prod(out_shape)
def _irfftn_adj_numpy(a):
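    # Adjoint of irfftn: forward FFT along the last axis scaled by 1/n (the adjoint of the
    # normalized inverse transform), fold the conjugate of the redundant negative-frequency
    # half onto the retained half spectrum, then transform the remaining axes with the
    # matching 1/N normalization.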
n = a.shape[-1] // 2 + 1
out_arr = _fftn_numpy(a, axes=(-1,)) / a.shape[-1]
if a.shape[-1] % 2 == 0: # even
out_arr[:, 1:n - 1] += np.conj(out_arr[:, :n - 1:-1])
else: # odd
out_arr[:, 1:n] += np.conj(out_arr[:, :n - 1:-1])
axes = tuple(range(len(out_arr.shape[:-1])))
return _fftn_numpy(out_arr[:, :n], axes=axes) / np.prod(out_arr.shape[:-1])
# FFTW functions:
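# Each wrapper first looks up a cached plan for the given input; if none exists, a new plan is
# built via pyfftw.builders and added to the cache. The result is copied because the FFTW
# object reuses its internal output array on subsequent calls.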
def _fftn_fftw(a, s=None, axes=None):
if a.dtype not in (FLOAT, COMPLEX):
raise TypeError('Wrong input type!')
fftw = PLANS.lookup_fftw('fftn', a, s, axes, NTHREADS)
if fftw is None:
fftw = pyfftw.builders.fftn(a, s, axes, threads=NTHREADS)
PLANS.add_fftw('fftn', fftw, s, axes, NTHREADS)
return fftw(a).copy()
def _ifftn_fftw(a, s=None, axes=None):
if a.dtype not in (FLOAT, COMPLEX):
raise TypeError('Wrong input type!')
fftw = PLANS.lookup_fftw('ifftn', a, s, axes, NTHREADS)
if fftw is None:
fftw = pyfftw.builders.ifftn(a, s, axes, threads=NTHREADS)
PLANS.add_fftw('ifftn', fftw, s, axes, NTHREADS)
return fftw(a).copy()
def _rfftn_fftw(a, s=None, axes=None):
if a.dtype != FLOAT:
raise TypeError('Wrong input type!')
fftw = PLANS.lookup_fftw('rfftn', a, s, axes, NTHREADS)
if fftw is None:
fftw = pyfftw.builders.rfftn(a, s, axes, threads=NTHREADS)
PLANS.add_fftw('rfftn', fftw, s, axes, NTHREADS)
return fftw(a).copy()
def _irfftn_fftw(a, s=None, axes=None):
if a.dtype != COMPLEX:
raise TypeError('Wrong input type!')
fftw = PLANS.lookup_fftw('irfftn', a, s, axes, NTHREADS)
if fftw is None:
fftw = pyfftw.builders.irfftn(a, s, axes, threads=NTHREADS)
PLANS.add_fftw('irfftn', fftw, s, axes, NTHREADS)
return fftw(a).copy()
def _rfftn_adj_fftw(a):
    # Careful: this only works if the last dimension of the original (real-space) array is even
    # (which is guaranteed by the kernel!)
n = 2 * (a.shape[-1] - 1)
out_shape = a.shape[:-1] + (n,)
out_arr = zeros(out_shape, dtype=a.dtype)
out_arr[:, :a.shape[-1]] = a
return _ifftn_fftw(out_arr).real * np.prod(out_shape)
def _irfftn_adj_fftw(a):
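    # Same folding scheme as _irfftn_adj_numpy above, but built on the FFTW-backed wrappers.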
out_arr = _fftn_fftw(a, axes=(-1,)) / a.shape[-1] # FFT of last axis
n = a.shape[-1] // 2 + 1
if a.shape[-1] % 2 == 0: # even
out_arr[:, 1:n - 1] += np.conj(out_arr[:, :n - 1:-1])
else: # odd
out_arr[:, 1:n] += np.conj(out_arr[:, :n - 1:-1])
axes = tuple(range(len(out_arr.shape[:-1])))
return _fftn_fftw(out_arr[:, :n], axes=axes) / np.prod(out_arr.shape[:-1])
# These wisdom functions do nothing if pyFFTW is not available:
def dump_wisdom(fname):
"""Wrapper function for the pyfftw.export_wisdom(), which uses a pickle dump.
Parameters
----------
fname: string
Name of the file in which the wisdom is saved.
Returns
-------
None
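    Examples
    --------
    Illustrative only; the file name is arbitrary:
    >>> dump_wisdom('fftw_wisdom.pickle')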
"""
_log.debug('Calling dump_wisdom')
if pyfftw is not None:
with open(fname, 'wb') as fp:
pickle.dump(pyfftw.export_wisdom(), fp, pickle.HIGHEST_PROTOCOL)
def load_wisdom(fname):
"""Wrapper function for the pyfftw.import_wisdom(), which uses a pickle to load a file.
Parameters
----------
fname: string
Name of the file from which the wisdom is loaded.
Returns
-------
None
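    Examples
    --------
    Illustrative only; the file name is arbitrary and should match a previous dump_wisdom call:
    >>> load_wisdom('fftw_wisdom.pickle')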
"""
_log.debug('Calling load_wisdom')
if pyfftw is not None:
if not os.path.exists(fname):
print("Warning: Wisdom file does not exist. First time use?")
else:
with open(fname, 'rb') as fp:
pyfftw.import_wisdom(pickle.load(fp))
# Array setups:
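# When pyFFTW is available, the constructors below return SIMD-aligned arrays so that FFTW
# can use its faster, vectorized code paths.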
def empty(shape, dtype=FLOAT):
"""Return a new array of given shape and type without initializing entries.
Parameters
----------
shape: int or tuple of int
Shape of the array.
dtype: data-type, optional
Desired output data-type.
Returns
-------
out: :class:`~numpy.ndarray`
The created array.
"""
_log.debug('Calling empty')
result = np.empty(shape, dtype)
if pyfftw is not None:
result = pyfftw.n_byte_align(result, pyfftw.simd_alignment)
return result
def zeros(shape, dtype=FLOAT):
"""Return a new array of given shape and type, filled with zeros.
Parameters
----------
shape: int or tuple of int
Shape of the array.
dtype: data-type, optional
Desired output data-type.
Returns
-------
out: :class:`~numpy.ndarray`
The created array.
"""
_log.debug('Calling zeros')
result = np.zeros(shape, dtype)
if pyfftw is not None:
result = pyfftw.n_byte_align(result, pyfftw.simd_alignment)
return result
def ones(shape, dtype=FLOAT):
"""Return a new array of given shape and type, filled with ones.
Parameters
----------
shape: int or tuple of int
Shape of the array.
dtype: data-type, optional
Desired output data-type.
Returns
-------
out: :class:`~numpy.ndarray`
The created array.
"""
_log.debug('Calling ones')
result = np.ones(shape, dtype)
if pyfftw is not None:
result = pyfftw.n_byte_align(result, pyfftw.simd_alignment)
return result
# Configure backend:
def configure_backend(backend):
"""Change FFT backend.
Parameters
----------
backend: string
Backend to use. Supported values are "numpy" and "fftw".
Returns
-------
None
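    Examples
    --------
    Illustrative sketch; the 'fftw' call has no effect if pyFFTW is unavailable:
    >>> configure_backend('numpy')
    >>> configure_backend('fftw')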
"""
_log.debug('Calling configure_backend')
global fftn
global ifftn
global rfftn
global irfftn
global rfftn_adj
global irfftn_adj
global BACKEND
if backend == 'numpy':
fftn = _fftn_numpy
ifftn = _ifftn_numpy
rfftn = _rfftn_numpy
irfftn = _irfftn_numpy
rfftn_adj = _rfftn_adj_numpy
irfftn_adj = _irfftn_adj_numpy
BACKEND = 'numpy'
elif backend == 'fftw':
if pyfftw is not None:
fftn = _fftn_fftw
ifftn = _ifftn_fftw
rfftn = _rfftn_fftw
irfftn = _irfftn_fftw
rfftn_adj = _rfftn_adj_fftw
irfftn_adj = _irfftn_adj_fftw
            BACKEND = 'fftw'
else:
print('Error: FFTW requested but not available')
# On import:
ifftn = None
fftn = None
rfftn = None
irfftn = None
rfftn_adj = None
irfftn_adj = None
if pyfftw is not None:
configure_backend('fftw')
else:
configure_backend('numpy')