Skip to content
Snippets Groups Projects
fft.py 6.6 KiB
Newer Older
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 28 15:30:10 2014

@author: Jan
"""

# TODO: Document!

import numpy as np
import cPickle as pickle
import os

# pyFFTW depends on this
try:
    from collections import Counter  #analysis:ignore
except ImportError:
    import collections_python27
    import collections
    collections.Counter = collections_python27.Counter

try:
    import pyfftw
    BACKEND = 'fftw'
except ImportError:
    pyfftw = None
    BACKEND = 'numpy'
    print("pyFFTW module not found. Using numpy implementation.")


# Public API exported by ``from <module> import *``. The FFT entry points
# (fftn, ifftn, ...) are bound at import time by configure_backend().
__all__ = ['PLANS', 'FLOAT', 'COMPLEX', 'NTHREADS', 'dump_wisdom', 'load_wisdom', 'zeros', 'empty',
           'ones', 'configure_backend', 'fftn', 'ifftn', 'rfftn', 'irfftn', 'rfftn_adj',
           'irfftn_adj']


class FFTWCache(object):
    """Cache of pyFFTW plan objects so repeated transforms can skip re-planning.

    Entries are keyed on the transform type, the input array's shape and dtype,
    the requested output shape ``s``, the transform ``axes`` and the thread
    count.
    """

    def __init__(self):
        self.cache = dict()

    @staticmethod
    def _as_key_part(value):
        # Lists are unhashable: normalise sequences to tuples so e.g. axes=[-1]
        # and axes=(-1,) share an entry. None and scalars pass through.
        if value is None:
            return None
        try:
            return tuple(value)
        except TypeError:
            return value

    def _make_key(self, fft_type, in_arr, s, axes, nthreads):
        # `s` and `axes` must be part of the key. Without them, a last-axis-only
        # plan would be reused for a full N-d transform of an array with the same
        # shape and dtype (and vice versa), silently giving wrong results.
        return (fft_type, in_arr.shape, in_arr.dtype,
                self._as_key_part(s), self._as_key_part(axes), nthreads)

    def add_fftw(self, fft_type, fftw_obj, s, axes, nthreads):
        """Store the plan `fftw_obj` for later reuse.

        The shape/dtype part of the key is taken from the plan's own input
        array (``fftw_obj.get_input_array()``).
        """
        # NOTE(review): if pyFFTW pads (s != input shape) or converts the dtype
        # (e.g. real input to fftn), the plan's internal input array may differ
        # from the caller's array, and lookups with the caller's array would
        # then miss. Verify against the pyFFTW builders in use.
        in_arr = fftw_obj.get_input_array()
        self.cache[self._make_key(fft_type, in_arr, s, axes, nthreads)] = fftw_obj

    def lookup_fftw(self, fft_type, in_arr, s, axes, nthreads):
        """Return the cached plan matching these parameters, or None."""
        return self.cache.get(self._make_key(fft_type, in_arr, s, axes, nthreads))

    def clear_cache(self):
        """Drop all cached plans."""
        self.cache = dict()


# Working precision: change these two together (e.g. to np.float64 /
# np.complex128) to switch every helper in this module to double precision.
FLOAT = np.float32
COMPLEX = np.complex64
# Thread count consulted by the pyFFTW wrappers (part of the plan-cache key).
NTHREADS = 1
# Module-wide cache of pyFFTW plans shared by all FFTW wrappers.
PLANS = FFTWCache()


# Numpy functions:
def _fftn_numpy(a, s=None, axes=None):
    """N-dimensional complex FFT of `a` computed with ``numpy.fft.fftn``."""
    return np.fft.fftn(a, s=s, axes=axes)


def _ifftn_numpy(a, s=None, axes=None):
    """N-dimensional inverse complex FFT of `a` computed with ``numpy.fft.ifftn``."""
    return np.fft.ifftn(a, s=s, axes=axes)


def _rfftn_numpy(a, s=None, axes=None):
    """N-dimensional real-to-complex FFT of `a` computed with ``numpy.fft.rfftn``."""
    return np.fft.rfftn(a, s=s, axes=axes)


def _irfftn_numpy(a, s=None, axes=None):
    """N-dimensional complex-to-real inverse FFT of `a` via ``numpy.fft.irfftn``."""
    return np.fft.irfftn(a, s=s, axes=axes)


def _rfftn_adj_numpy(a):
    """Adjoint of ``rfftn`` (numpy backend) w.r.t. the real inner product.

    Maps an array of shape ``(..., m)`` to a real array of shape
    ``(..., 2 * (m - 1))``: the input is zero-padded along the last axis to
    full length and pushed through an unnormalised inverse FFT, keeping the
    real part. The forward transform is assumed to have had an even length
    ``2 * (m - 1)`` along the last axis.
    """
    m = a.shape[-1]
    n = 2 * (m - 1)
    out_shape = a.shape[:-1] + (n,)
    # Plain np.zeros: the numpy FFT path gains nothing from SIMD alignment.
    out_arr = np.zeros(out_shape, dtype=a.dtype)
    # Fill only the first m bins of the LAST axis (Ellipsis keeps this N-d safe).
    out_arr[..., :m] = a
    # ifftn divides by the total size; multiply back for the unnormalised F^H.
    return np.fft.ifftn(out_arr).real * np.prod(out_shape)


def _irfftn_adj_numpy(a):
    """Adjoint of ``irfftn`` (numpy backend) w.r.t. the real inner product.

    Maps a real array of shape ``(..., N)`` to a complex array of shape
    ``(..., N // 2 + 1)``. Along the last axis, the normalised FFT of `a` is
    folded onto the non-redundant half: bin k also receives conj(bin N - k),
    which is the adjoint of irfft's Hermitian extension. The remaining axes
    then get the adjoint of ifft, i.e. fft divided by the number of elements.
    """
    n_full = a.shape[-1]
    n = n_full // 2 + 1
    out_arr = np.fft.fftn(a, axes=(-1,)) / n_full
    # Mirror bins N-1, N-2, ..., n onto 1, 2, ...; np.conj makes a copy, so the
    # in-place add does not read bins it is writing.
    if n_full % 2 == 0:  # even: bin N/2 (Nyquist) is self-paired, leave it alone
        out_arr[..., 1:n - 1] += np.conj(out_arr[..., :n - 1:-1])
    else:  # odd: no Nyquist bin
        out_arr[..., 1:n] += np.conj(out_arr[..., :n - 1:-1])
    out_arr = out_arr[..., :n]
    axes = tuple(range(out_arr.ndim - 1))
    if not axes:  # 1-D input: no leading axes to transform
        return out_arr
    return np.fft.fftn(out_arr, axes=axes) / np.prod(out_arr.shape[:-1])


# FFTW functions:
def _cached_fftw(fft_type, builder, a, s, axes):
    """Return a pyFFTW plan for `a`, building and caching it in PLANS on a miss."""
    fftw = PLANS.lookup_fftw(fft_type, a, s, axes, NTHREADS)
    if fftw is None:
        # NTHREADS is part of the cache key, so it must also reach the planner;
        # otherwise changing it would have no effect on the transforms.
        fftw = builder(a, s, axes, threads=NTHREADS)
        PLANS.add_fftw(fft_type, fftw, s, axes, NTHREADS)
    return fftw


def _fftn_fftw(a, s=None, axes=None):
    """N-dimensional complex FFT of `a` via a cached pyFFTW plan.

    Raises TypeError unless ``a.dtype`` is FLOAT or COMPLEX.
    """
    if a.dtype not in (FLOAT, COMPLEX):
        raise TypeError('Wrong input type!')
    fftw = _cached_fftw('fftn', pyfftw.builders.fftn, a, s, axes)
    # The plan reuses its output buffer on every call, so hand back a copy.
    return fftw(a).copy()


def _ifftn_fftw(a, s=None, axes=None):
    """N-dimensional inverse complex FFT of `a` via a cached pyFFTW plan.

    Raises TypeError unless ``a.dtype`` is FLOAT or COMPLEX.
    """
    if a.dtype not in (FLOAT, COMPLEX):
        raise TypeError('Wrong input type!')
    fftw = _cached_fftw('ifftn', pyfftw.builders.ifftn, a, s, axes)
    return fftw(a).copy()


def _rfftn_fftw(a, s=None, axes=None):
    """N-dimensional real-to-complex FFT of `a` via a cached pyFFTW plan.

    Raises TypeError unless ``a.dtype`` is FLOAT.
    """
    if a.dtype != FLOAT:
        raise TypeError('Wrong input type!')
    fftw = _cached_fftw('rfftn', pyfftw.builders.rfftn, a, s, axes)
    return fftw(a).copy()


def _irfftn_fftw(a, s=None, axes=None):
    """N-dimensional complex-to-real inverse FFT of `a` via a cached pyFFTW plan.

    Raises TypeError unless ``a.dtype`` is COMPLEX.
    """
    if a.dtype != COMPLEX:
        raise TypeError('Wrong input type!')
    fftw = _cached_fftw('irfftn', pyfftw.builders.irfftn, a, s, axes)
    return fftw(a).copy()


def _rfftn_adj_fftw(a):
    """Adjoint of ``rfftn`` (pyFFTW backend) w.r.t. the real inner product.

    Maps an array of shape ``(..., m)`` to a real array of shape
    ``(..., 2 * (m - 1))``: zero-pads the last axis to full length, applies an
    unnormalised inverse FFT and keeps the real part. The forward transform is
    assumed to have had an even length ``2 * (m - 1)`` along the last axis.
    """
    m = a.shape[-1]
    n = 2 * (m - 1)
    out_shape = a.shape[:-1] + (n,)
    out_arr = zeros(out_shape, dtype=a.dtype)
    # Ellipsis targets the LAST axis for any rank (`[:, :m]` was 2-D only).
    out_arr[..., :m] = a
    # ifftn divides by the total size; multiply back for the unnormalised F^H.
    return _ifftn_fftw(out_arr).real * np.prod(out_shape)


def _irfftn_adj_fftw(a):
    """Adjoint of ``irfftn`` (pyFFTW backend) w.r.t. the real inner product.

    Maps a real array of shape ``(..., N)`` to a complex array of shape
    ``(..., N // 2 + 1)``. Along the last axis, the normalised FFT of `a` is
    folded onto the non-redundant half (bin k also receives conj(bin N - k)).
    The remaining axes then get fft divided by their number of elements.
    """
    n_full = a.shape[-1]
    n = n_full // 2 + 1
    out_arr = _fftn_fftw(a, axes=(-1,)) / n_full  # FFT of last axis
    # Mirror bins N-1, N-2, ..., n onto 1, 2, ...; np.conj returns a copy, so the
    # in-place add does not read bins it is writing.
    if n_full % 2 == 0:  # even: Nyquist bin N/2 is self-paired
        out_arr[..., 1:n - 1] += np.conj(out_arr[..., :n - 1:-1])
    else:  # odd
        out_arr[..., 1:n] += np.conj(out_arr[..., :n - 1:-1])
    out_arr = out_arr[..., :n]
    axes = tuple(range(out_arr.ndim - 1))
    if not axes:  # 1-D input: nothing left to transform
        return out_arr
    return _fftn_fftw(out_arr, axes=axes) / np.prod(out_arr.shape[:-1])


# These wisdom functions do nothing if pyFFTW is not available
def dump_wisdom(fname):
    """Pickle pyFFTW's accumulated planner wisdom into the file `fname`.

    Does nothing when pyFFTW is not installed.
    """
    if pyfftw is None:
        return
    with open(fname, 'wb') as out_file:
        wisdom = pyfftw.export_wisdom()
        pickle.dump(wisdom, out_file, pickle.HIGHEST_PROTOCOL)


def load_wisdom(fname):
    """Load pickled pyFFTW planner wisdom from `fname` and import it into pyFFTW.

    Does nothing when pyFFTW is not installed; prints a warning when `fname`
    does not exist.
    """
    if pyfftw is None:
        return
    if os.path.exists(fname):
        # NOTE(review): pickle.load can execute arbitrary code; only load wisdom
        # files from a trusted location.
        with open(fname, 'rb') as in_file:
            wisdom = pickle.load(in_file)
            pyfftw.import_wisdom(wisdom)
    else:
        print("Warning: Wisdom file does not exist. First time use?")


# Array setups:
def _simd_aligned(arr):
    """Return `arr` aligned to pyFFTW's SIMD boundary, or unchanged without pyFFTW.

    pyFFTW's alignment helpers return a copy when `arr` is not already aligned,
    so callers must use the return value.
    """
    if pyfftw is None:
        return arr
    # `n_byte_align` is deprecated (since pyFFTW 0.10) in favour of `byte_align`;
    # prefer the new name and fall back for older pyFFTW releases.
    align = getattr(pyfftw, 'byte_align', None) or pyfftw.n_byte_align
    return align(arr, pyfftw.simd_alignment)


def empty(shape, dtype=FLOAT):
    """Return an uninitialised array, SIMD-aligned when pyFFTW is available."""
    return _simd_aligned(np.empty(shape, dtype))


def zeros(shape, dtype=FLOAT):
    """Return a zero-filled array, SIMD-aligned when pyFFTW is available."""
    return _simd_aligned(np.zeros(shape, dtype))


def ones(shape, dtype=FLOAT):
    """Return an array filled with ones, SIMD-aligned when pyFFTW is available."""
    return _simd_aligned(np.ones(shape, dtype))


# Configure backend:
def configure_backend(backend):
    """
    Change FFT backend.

    Supported values are "numpy" and "fftw". "pyfftw" is accepted as an alias
    for "fftw", because that is the value stored in ``BACKEND`` when FFTW is
    selected; this makes ``configure_backend(BACKEND)`` round-trip.

    Rebinds the module-level ``fftn``, ``ifftn``, ``rfftn``, ``irfftn``,
    ``rfftn_adj`` and ``irfftn_adj`` and updates ``BACKEND``. If FFTW is
    requested but pyFFTW is not installed, an error is printed and the
    current backend is left unchanged.

    Raises
    ------
    ValueError
        If `backend` is not one of the supported names.
    """
    global fftn, ifftn, rfftn, irfftn, rfftn_adj, irfftn_adj, BACKEND
    if backend not in ('numpy', 'fftw', 'pyfftw'):
        raise ValueError('Unsupported FFT backend: {!r} (use "numpy" or "fftw")'.format(backend))
    if backend == 'numpy':
        fftn = _fftn_numpy
        ifftn = _ifftn_numpy
        rfftn = _rfftn_numpy
        irfftn = _irfftn_numpy
        rfftn_adj = _rfftn_adj_numpy
        irfftn_adj = _irfftn_adj_numpy
        BACKEND = 'numpy'
    elif pyfftw is None:
        # Best-effort: keep whatever backend is currently active.
        print("Error: FFTW requested but not available")
    else:
        fftn = _fftn_fftw
        ifftn = _ifftn_fftw
        rfftn = _rfftn_fftw
        irfftn = _irfftn_fftw
        rfftn_adj = _rfftn_adj_fftw
        irfftn_adj = _irfftn_adj_fftw
        # Value kept as 'pyfftw' (not 'fftw') for backward compatibility.
        BACKEND = 'pyfftw'


# On import:
# Prefer FFTW whenever pyFFTW could be imported; otherwise fall back to numpy.
configure_backend("fftw" if pyfftw is not None else "numpy")