Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 28 15:30:10 2014
@author: Jan
"""
# TODO: Document!
import numpy as np
import cPickle as pickle
import os
# pyFFTW depends on this
try:
from collections import Counter #analysis:ignore
except ImportError:
import collections_python27
import collections
collections.Counter = collections_python27.Counter
try:
import pyfftw
BACKEND = 'fftw'
except ImportError:
pyfftw = None
BACKEND = 'numpy'
print("pyFFTW module not found. Using numpy implementation.")
__all__ = ['PLANS', 'FLOAT', 'COMPLEX', 'NTHREADS', 'dump_wisdom', 'load_wisdom', 'zeros', 'empty',
'configure_backend', 'fftn', 'ifftn', 'rfftn', 'irfftn', 'rfftn_adj', 'irfftn_adj']
class FFTWCache(object):
def __init__(self):
self.cache = dict()
def add_fftw(self, fft_type, fftw_obj, s, axes, nthreads):
in_arr = fftw_obj.get_input_array()
key = (fft_type, in_arr.shape, in_arr.dtype, nthreads)
self.cache[key] = fftw_obj
def lookup_fftw(self, fft_type, in_arr, s, axes, nthreads):
key = (fft_type, in_arr.shape, in_arr.dtype, nthreads)
return self.cache.get(key, None)
def clear_cache(self):
self.cache = dict()
PLANS = FFTWCache()
FLOAT = np.float32 # One convenient place to
COMPLEX = np.complex64 # change from 32 to 64 bit
NTHREADS = 1
# Numpy functions:
def _fftn_numpy(a, s=None, axes=None):
return np.fft.fftn(a, s, axes)
def _ifftn_numpy(a, s=None, axes=None):
return np.fft.ifftn(a, s, axes)
def _rfftn_numpy(a, s=None, axes=None):
return np.fft.rfftn(a, s, axes)
def _irfftn_numpy(a, s=None, axes=None):
return np.fft.irfftn(a, s, axes)
def _rfftn_adj_numpy(a):
n = 2 * (a.shape[-1] - 1)
out_shape = a.shape[:-1] + (n,)
out_arr = zeros(out_shape, dtype=a.dtype)
out_arr[:, :n] = a
return _ifftn_numpy(out_arr).real * np.prod(out_shape)
def _irfftn_adj_numpy(a):
n = a.shape[-1] // 2 + 1
out_arr = _fftn_numpy(a, axis=-1) / a.shape[-1]
if a.shape[-1] % 2 == 0: # even
out_arr[:, 1:n - 1] += np.conj(out_arr[:, :n-1:-1])
else: # odd
out_arr[:, 1:n] += np.conj(out_arr[:, :n-1:-1])
axes = tuple(range(len(out_arr.shape[:-1])))
return _fftn_numpy(out_arr[:, :n], axes=axes) / np.prod(out_arr.shape[:-1])
# FFTW functions:
def _fftn_fftw(a, s=None, axes=None):
if a.dtype not in (FLOAT, COMPLEX):
raise TypeError('Wrong input type!')
fftw = PLANS.lookup_fftw('fftn', a, s, axes, NTHREADS)
if fftw is None:
fftw = pyfftw.builders.fftn(a, s, axes)
PLANS.add_fftw('fftn', fftw, s, axes, NTHREADS)
return fftw(a).copy()
def _ifftn_fftw(a, s=None, axes=None):
if a.dtype not in (FLOAT, COMPLEX):
raise TypeError('Wrong input type!')
fftw = PLANS.lookup_fftw('ifftn', a, s, axes, NTHREADS)
if fftw is None:
fftw = pyfftw.builders.ifftn(a, s, axes)
PLANS.add_fftw('ifftn', fftw, s, axes, NTHREADS)
return fftw(a).copy()
def _rfftn_fftw(a, s=None, axes=None):
if a.dtype != FLOAT:
raise TypeError('Wrong input type!')
fftw = PLANS.lookup_fftw('rfftn', a, s, axes, NTHREADS)
if fftw is None:
fftw = pyfftw.builders.rfftn(a, s, axes)
PLANS.add_fftw('rfftn', fftw, s, axes, NTHREADS)
return fftw(a).copy()
def _irfftn_fftw(a, s=None, axes=None):
if a.dtype != COMPLEX:
raise TypeError('Wrong input type!')
fftw = PLANS.lookup_fftw('irfftn', a, s, axes, NTHREADS)
if fftw is None:
fftw = pyfftw.builders.irfftn(a, s, axes)
PLANS.add_fftw('irfftn', fftw, s, axes, NTHREADS)
return fftw(a).copy()
def _rfftn_adj_fftw(a):
n = 2 * (a.shape[-1] - 1)
out_shape = a.shape[:-1] + (n,)
out_arr = zeros(out_shape, dtype=a.dtype)
out_arr[:, :a.shape[-1]] = a
return _ifftn_fftw(out_arr).real * np.prod(out_shape)
def _irfftn_adj_fftw(a):
out_arr = _fftn_fftw(a, axes=(-1,)) / a.shape[-1] # FFT of last axis
n = a.shape[-1] // 2 + 1
if a.shape[-1] % 2 == 0: # even
out_arr[:, 1:n-1] += np.conj(out_arr[:, :n-1:-1])
else: # odd
out_arr[:, 1:n] += np.conj(out_arr[:, :n-1:-1])
axes = tuple(range(len(out_arr.shape[:-1])))
return _fftn_fftw(out_arr[:, :n], axes=axes) / np.prod(out_arr.shape[:-1])
# These wisdom functions do nothing if pyFFTW is not available
def dump_wisdom(fname):
# TODO: Docstring!
if pyfftw is not None:
with open(fname, 'wb') as fp:
pickle.dump(pyfftw.export_wisdom(), fp, pickle.HIGHEST_PROTOCOL)
def load_wisdom(fname):
# TODO: Docstring!
if pyfftw is not None:
if not os.path.exists(fname):
print("Warning: Wisdom file does not exist. First time use?")
else:
with open(fname, 'rb') as fp:
pyfftw.import_wisdom(pickle.load(fp))
# Array setups:
def empty(shape, dtype=FLOAT):
# TODO: Docstring!
result = np.empty(shape, dtype)
if pyfftw is not None:
result = pyfftw.n_byte_align(result, pyfftw.simd_alignment)
return result
def zeros(shape, dtype=FLOAT):
# TODO: Docstring!
result = np.zeros(shape, dtype)
if pyfftw is not None:
result = pyfftw.n_byte_align(result, pyfftw.simd_alignment)
return result
# TODO: Docstring!
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
if pyfftw is not None:
result = pyfftw.n_byte_align(result, pyfftw.simd_alignment)
return result
# Configure backend:
def configure_backend(backend):
"""
Change FFT backend.
Supported values are "numpy" and "fftw".
"""
global fftn
global ifftn
global rfftn
global irfftn
global rfftn_adj
global irfftn_adj
global BACKEND
if backend == "numpy":
fftn = _fftn_numpy
ifftn = _ifftn_numpy
rfftn = _rfftn_numpy
irfftn = _irfftn_numpy
rfftn_adj = _rfftn_adj_numpy
irfftn_adj = _irfftn_adj_numpy
BACKEND = 'numpy'
elif backend == "fftw":
if pyfftw is not None:
fftn = _fftn_fftw
ifftn = _ifftn_fftw
rfftn = _rfftn_fftw
irfftn = _irfftn_fftw
rfftn_adj = _rfftn_adj_fftw
irfftn_adj = _irfftn_adj_fftw
BACKEND = 'pyfftw'
else:
print("Error: FFTW requested but not available")
# On import:
if pyfftw is not None:
configure_backend("fftw")
else:
configure_backend("numpy")