# -*- coding: utf-8 -*-
# Copyright 2014 by Forschungszentrum Juelich GmbH
# Author: J. Caron
#
"""Custom FFT module with numpy and FFTW support.

This module provides custom methods for FFTs including inverse, adjoint and real variants. The
FFTW library is supported and is used as a default if the import succeeds. Otherwise the numpy.fft
package will be used. FFTW objects are saved in a cache after creation which speeds up further
similar FFT operations.

"""

import pickle
import logging
import os

import numpy as np

from pyramid.config import NTHREADS

_log = logging.getLogger(__name__)
try:
    import pyfftw
    BACKEND = 'fftw'
except ImportError:
    pyfftw = None
    BACKEND = 'numpy'
    _log.info('pyFFTW module not found. Using numpy implementation.')

__all__ = ['PLANS', 'FLOAT', 'COMPLEX', 'dump_wisdom', 'load_wisdom',  # analysis:ignore
           'zeros', 'empty', 'configure_backend',
           'fftn', 'ifftn', 'rfftn', 'irfftn', 'rfftn_adj', 'irfftn_adj']


class FFTWCache(object):
    """Class for adding FFTW Plans and on-demand lookups.

    This class is instantiated in this module to store FFTW plans and for the lookup of the former.

    Attributes
    ----------
    cache: dict
        Cache for storing the FFTW plans.

    Notes
    -----
    This class is used internally and is not normally intended to be used directly by the user.

    """

    _log = logging.getLogger(__name__ + '.FFTWCache')

    def __init__(self):
        self._log.debug('Calling __init__')
        self.cache = dict()
        self._log.debug('Created ' + str(self))

    @staticmethod
    def _make_key(fft_type, in_arr, s, axes, nthreads):
        """Build the dictionary key which identifies one FFTW plan.

        Lists given for `s` or `axes` are converted to tuples, because lists are not hashable and
        would make the dictionary access fail. Tuples and ``None`` are used unchanged.

        """
        if isinstance(s, list):
            s = tuple(s)
        if isinstance(axes, list):
            axes = tuple(axes)
        return (fft_type, in_arr.shape, in_arr.dtype, s, axes, nthreads)

    def add_fftw(self, fft_type, fftw_obj, s, axes, nthreads):
        """Add an FFTW object to the cache.

        Parameters
        ----------
        fft_type: basestring
            Identifier string for the FFT type ('fftn', 'ifftn', 'rfftn', 'irfftn').
        fftw_obj: :class:`~pyfftw.FFTW` object
            The FFTW object which should be added to the cache.
        s: tuple of ints
            Shape of the output array.
        axes: tuple of ints
            The axes along which the FFTW should be executed.
        nthreads: int
            Number of threads which should be used.

        Returns
        -------
        None

        """
        self._log.debug('Calling add_fftw')
        # Newer pyFFTW versions expose the input array as a property and deprecate the getter:
        in_arr = getattr(fftw_obj, 'input_array', None)
        if in_arr is None:
            in_arr = fftw_obj.get_input_array()
        self.cache[self._make_key(fft_type, in_arr, s, axes, nthreads)] = fftw_obj

    def lookup_fftw(self, fft_type, in_arr, s, axes, nthreads):
        """Look up an FFTW object in the cache.

        Parameters
        ----------
        fft_type: basestring
            Identifier string for the FFT type ('fftn', 'ifftn', 'rfftn', 'irfftn').
        in_arr:
            Input array, internally, just the `dtype` and the `shape` are used to identify the FFT.
        s: tuple of ints
            Shape of the output array.
        axes: tuple of ints
            The axes along which the FFTW should be executed.
        nthreads: int
            Number of threads which should be used.

        Returns
        -------
        fftw_obj: :class:`~pyfftw.FFTW` object
            The requested FFTW object or ``None`` if no matching object is stored.

        """
        self._log.debug('Calling lookup_fftw')
        return self.cache.get(self._make_key(fft_type, in_arr, s, axes, nthreads), None)

    def clear_cache(self):
        """Clear the cache."""
        self._log.debug('Calling clear_cache')
        self.cache = dict()


PLANS = FFTWCache()
FLOAT = np.float32  # One convenient place to
COMPLEX = np.complex64  # change from 32 to 64 bit


# Shared helpers for the adjoint transforms:

def _rfftn_adj_impl(a, ifftn_func):
    """Zero-pad the last axis of `a` to ``2 * (a.shape[-1] - 1)`` and apply `ifftn_func`.

    The real part of the inverse FFT is returned, scaled by the number of elements of the padded
    array.

    """
    # Careful: just works for even a (which is guaranteed by the kernel!)
    n_half = a.shape[-1]
    n = 2 * (n_half - 1)
    out_shape = a.shape[:-1] + (n,)
    out_arr = zeros(out_shape, dtype=a.dtype)
    # Only the first `n_half` entries of the last axis are filled, the rest stays zero:
    out_arr[..., :n_half] = a
    return ifftn_func(out_arr).real * np.prod(out_shape)


def _irfftn_adj_impl(a, fftn_func):
    """Transform the last axis of `a`, fold it onto its first half and transform the other axes.

    `fftn_func` is the backend specific complex FFT (called with the `axes` keyword).

    """
    n_last = a.shape[-1]
    n = n_last // 2 + 1
    out_arr = fftn_func(a, axes=(-1,)) / n_last  # FFT of last axis
    # Add the conjugate of the reversed upper part of the last axis to the lower part. For an
    # even length the entry at index n - 1 has no partner and is left untouched:
    stop = n - 1 if n_last % 2 == 0 else n
    out_arr[..., 1:stop] += np.conj(out_arr[..., :n - 1:-1])
    axes = tuple(range(out_arr.ndim - 1))  # all axes except the last one
    return fftn_func(out_arr[..., :n], axes=axes) / np.prod(out_arr.shape[:-1])


# Numpy functions:

def _fftn_numpy(a, s=None, axes=None):
    """Numpy backend for `fftn`, wrapping :func:`numpy.fft.fftn`."""
    return np.fft.fftn(a, s, axes)


def _ifftn_numpy(a, s=None, axes=None):
    """Numpy backend for `ifftn`, wrapping :func:`numpy.fft.ifftn`."""
    return np.fft.ifftn(a, s, axes)


def _rfftn_numpy(a, s=None, axes=None):
    """Numpy backend for `rfftn`, wrapping :func:`numpy.fft.rfftn`."""
    return np.fft.rfftn(a, s, axes)


def _irfftn_numpy(a, s=None, axes=None):
    """Numpy backend for `irfftn`, wrapping :func:`numpy.fft.irfftn`."""
    return np.fft.irfftn(a, s, axes)


def _rfftn_adj_numpy(a):
    """Numpy backend for `rfftn_adj`."""
    return _rfftn_adj_impl(a, _ifftn_numpy)


def _irfftn_adj_numpy(a):
    """Numpy backend for `irfftn_adj`."""
    return _irfftn_adj_impl(a, _fftn_numpy)


# FFTW functions:

def _execute_fftw(fft_type, a, s, axes, dtypes):
    """Run the pyFFTW transform `fft_type` on `a`, reusing a cached plan if one exists.

    Parameters
    ----------
    fft_type: basestring
        Name of the builder function in :mod:`pyfftw.builders` which is also used as the
        identifier in the plan cache ('fftn', 'ifftn', 'rfftn', 'irfftn').
    a: :class:`~numpy.ndarray`
        Input array.
    s: tuple of ints or None
        Shape argument which is handed to the builder.
    axes: tuple of ints or None
        Axes argument which is handed to the builder.
    dtypes: tuple
        Data types which are accepted for `a`.

    Returns
    -------
    out: :class:`~numpy.ndarray`
        A copy of the output array of the FFTW object.

    Raises
    ------
    TypeError
        If the data type of `a` is not in `dtypes`.

    """
    if a.dtype not in dtypes:
        raise TypeError('Wrong input type!')
    fftw = PLANS.lookup_fftw(fft_type, a, s, axes, NTHREADS)
    if fftw is None:
        builder = getattr(pyfftw.builders, fft_type)
        fftw = builder(a, s, axes, threads=NTHREADS)
        PLANS.add_fftw(fft_type, fftw, s, axes, NTHREADS)
    # The FFTW object is cached and called again later, so a copy of its output is handed out:
    return fftw(a).copy()


def _fftn_fftw(a, s=None, axes=None):
    """FFTW backend for `fftn`, accepts `FLOAT` and `COMPLEX` input."""
    return _execute_fftw('fftn', a, s, axes, (FLOAT, COMPLEX))


def _ifftn_fftw(a, s=None, axes=None):
    """FFTW backend for `ifftn`, accepts `FLOAT` and `COMPLEX` input."""
    return _execute_fftw('ifftn', a, s, axes, (FLOAT, COMPLEX))


def _rfftn_fftw(a, s=None, axes=None):
    """FFTW backend for `rfftn`, accepts only `FLOAT` input."""
    return _execute_fftw('rfftn', a, s, axes, (FLOAT,))


def _irfftn_fftw(a, s=None, axes=None):
    """FFTW backend for `irfftn`, accepts only `COMPLEX` input."""
    return _execute_fftw('irfftn', a, s, axes, (COMPLEX,))


def _rfftn_adj_fftw(a):
    """FFTW backend for `rfftn_adj`."""
    return _rfftn_adj_impl(a, _ifftn_fftw)


def _irfftn_adj_fftw(a):
    """FFTW backend for `irfftn_adj`."""
    return _irfftn_adj_impl(a, _fftn_fftw)


# These wisdom functions do nothing if pyFFTW is not available:

def dump_wisdom(fname):
    """Wrapper function for the pyfftw.export_wisdom(), which uses a pickle dump.

    Parameters
    ----------
    fname: string
        Name of the file in which the wisdom is saved.

    Returns
    -------
    None

    """
    _log.debug('Calling dump_wisdom')
    if pyfftw is not None:
        with open(fname, 'wb') as fp:
            pickle.dump(pyfftw.export_wisdom(), fp, pickle.HIGHEST_PROTOCOL)


def load_wisdom(fname):
    """Wrapper function for the pyfftw.import_wisdom(), which uses a pickle to load a file.

    Parameters
    ----------
    fname: string
        Name of the file from which the wisdom is loaded.

    Returns
    -------
    None

    """
    _log.debug('Calling load_wisdom')
    if pyfftw is not None:
        if not os.path.exists(fname):
            print("Warning: Wisdom file does not exist. First time use?")
        else:
            # NOTE(security): unpickling can execute arbitrary code, only load wisdom files
            # which were written by dump_wisdom() and come from a trusted location!
            with open(fname, 'rb') as fp:
                pyfftw.import_wisdom(pickle.load(fp))


# Array setups:

def _aligned(array):
    """Return `array` aligned to the SIMD alignment of pyFFTW (unchanged without pyFFTW)."""
    if pyfftw is None:
        return array
    # `n_byte_align` was renamed to `byte_align` in newer pyFFTW versions, prefer the new name:
    align = getattr(pyfftw, 'byte_align', None) or pyfftw.n_byte_align
    return align(array, pyfftw.simd_alignment)


def empty(shape, dtype=FLOAT):
    """Return a new array of given shape and type without initializing entries.

    Parameters
    ----------
    shape: int or tuple of int
        Shape of the array.
    dtype: data-type, optional
        Desired output data-type.

    Returns
    -------
    out: :class:`~numpy.ndarray`
        The created array.

    """
    _log.debug('Calling empty')
    return _aligned(np.empty(shape, dtype))


def zeros(shape, dtype=FLOAT):
    """Return a new array of given shape and type, filled with zeros.

    Parameters
    ----------
    shape: int or tuple of int
        Shape of the array.
    dtype: data-type, optional
        Desired output data-type.

    Returns
    -------
    out: :class:`~numpy.ndarray`
        The created array.

    """
    _log.debug('Calling zeros')
    return _aligned(np.zeros(shape, dtype))


def ones(shape, dtype=FLOAT):
    """Return a new array of given shape and type, filled with ones.

    Parameters
    ----------
    shape: int or tuple of int
        Shape of the array.
    dtype: data-type, optional
        Desired output data-type.

    Returns
    -------
    out: :class:`~numpy.ndarray`
        The created array.

    """
    _log.debug('Calling ones')
    return _aligned(np.ones(shape, dtype))


# Configure backend:

def configure_backend(backend):
    """Change FFT backend.

    Parameters
    ----------
    backend: string
        Backend to use. Supported values are "numpy" and "fftw". "pyfftw" (the value which
        `BACKEND` holds while FFTW is active) is accepted as an alias for "fftw".

    Returns
    -------
    None

    Notes
    -----
    If FFTW is requested but not available or if the backend is unknown, the current backend is
    kept.

    """
    _log.debug('Calling configure_backend')
    global fftn
    global ifftn
    global rfftn
    global irfftn
    global rfftn_adj
    global irfftn_adj
    global BACKEND
    if backend == 'numpy':
        fftn = _fftn_numpy
        ifftn = _ifftn_numpy
        rfftn = _rfftn_numpy
        irfftn = _irfftn_numpy
        rfftn_adj = _rfftn_adj_numpy
        irfftn_adj = _irfftn_adj_numpy
        BACKEND = 'numpy'
    elif backend in ('fftw', 'pyfftw'):
        # 'pyfftw' is accepted so that configure_backend(BACKEND) restores a saved backend.
        if pyfftw is not None:
            fftn = _fftn_fftw
            ifftn = _ifftn_fftw
            rfftn = _rfftn_fftw
            irfftn = _irfftn_fftw
            rfftn_adj = _rfftn_adj_fftw
            irfftn_adj = _irfftn_adj_fftw
            BACKEND = 'pyfftw'
        else:
            print('Error: FFTW requested but not available')
    else:
        _log.warning('Unknown FFT backend %r requested, keeping %r', backend, BACKEND)


# On import:
ifftn = None
fftn = None
rfftn = None
irfftn = None
rfftn_adj = None
irfftn_adj = None
if pyfftw is not None:
    configure_backend('fftw')
else:
    configure_backend('numpy')