Loading...
Searching...
No Matches
TensorFFT.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2015 Jianwei Cui <thucjw@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_FFT_H
11#define EIGEN_CXX11_TENSOR_TENSOR_FFT_H
12
13// This code requires the ability to initialize arrays of constant
14// values directly inside a class.
15#if __cplusplus >= 201103L || EIGEN_COMP_MSVC >= 1900
16
17namespace Eigen {
18
30template <bool NeedUprade> struct MakeComplex {
31 template <typename T>
32 EIGEN_DEVICE_FUNC
33 T operator() (const T& val) const { return val; }
34};
35
36template <> struct MakeComplex<true> {
37 template <typename T>
38 EIGEN_DEVICE_FUNC
39 std::complex<T> operator() (const T& val) const { return std::complex<T>(val, 0); }
40};
41
42template <> struct MakeComplex<false> {
43 template <typename T>
44 EIGEN_DEVICE_FUNC
45 std::complex<T> operator() (const std::complex<T>& val) const { return val; }
46};
47
// Selects which part of a transformed coefficient is kept. The generic case
// (any ResultType without a specialization) keeps the whole value.
template <int ResultType>
struct PartOf {
  template <typename T>
  T operator()(const T& value) const {
    const T& whole = value;
    return whole;
  }
};
51
52template <> struct PartOf<RealPart> {
53 template <typename T> T operator() (const std::complex<T>& val) const { return val.real(); }
54};
55
56template <> struct PartOf<ImagPart> {
57 template <typename T> T operator() (const std::complex<T>& val) const { return val.imag(); }
58};
59
60namespace internal {
61template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
62struct traits<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir> > : public traits<XprType> {
63 typedef traits<XprType> XprTraits;
64 typedef typename NumTraits<typename XprTraits::Scalar>::Real RealScalar;
65 typedef typename std::complex<RealScalar> ComplexScalar;
66 typedef typename XprTraits::Scalar InputScalar;
67 typedef typename conditional<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>::type OutputScalar;
68 typedef typename XprTraits::StorageKind StorageKind;
69 typedef typename XprTraits::Index Index;
70 typedef typename XprType::Nested Nested;
71 typedef typename remove_reference<Nested>::type _Nested;
72 static const int NumDimensions = XprTraits::NumDimensions;
73 static const int Layout = XprTraits::Layout;
74};
75
76template <typename FFT, typename XprType, int FFTResultType, int FFTDirection>
77struct eval<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>, Eigen::Dense> {
78 typedef const TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>& type;
79};
80
81template <typename FFT, typename XprType, int FFTResultType, int FFTDirection>
82struct nested<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>, 1, typename eval<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection> >::type> {
83 typedef TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection> type;
84};
85
86} // end namespace internal
87
88template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
89class TensorFFTOp : public TensorBase<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir>, ReadOnlyAccessors> {
90 public:
91 typedef typename Eigen::internal::traits<TensorFFTOp>::Scalar Scalar;
92 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
93 typedef typename std::complex<RealScalar> ComplexScalar;
94 typedef typename internal::conditional<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>::type OutputScalar;
95 typedef OutputScalar CoeffReturnType;
96 typedef typename Eigen::internal::nested<TensorFFTOp>::type Nested;
97 typedef typename Eigen::internal::traits<TensorFFTOp>::StorageKind StorageKind;
98 typedef typename Eigen::internal::traits<TensorFFTOp>::Index Index;
99
100 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorFFTOp(const XprType& expr, const FFT& fft)
101 : m_xpr(expr), m_fft(fft) {}
102
103 EIGEN_DEVICE_FUNC
104 const FFT& fft() const { return m_fft; }
105
106 EIGEN_DEVICE_FUNC
107 const typename internal::remove_all<typename XprType::Nested>::type& expression() const {
108 return m_xpr;
109 }
110
111 protected:
112 typename XprType::Nested m_xpr;
113 const FFT m_fft;
114};
115
116// Eval as rvalue
117template <typename FFT, typename ArgType, typename Device, int FFTResultType, int FFTDir>
118struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, Device> {
119 typedef TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir> XprType;
120 typedef typename XprType::Index Index;
121 static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
122 typedef DSizes<Index, NumDims> Dimensions;
123 typedef typename XprType::Scalar Scalar;
124 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
125 typedef typename std::complex<RealScalar> ComplexScalar;
126 typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
127 typedef internal::traits<XprType> XprTraits;
128 typedef typename XprTraits::Scalar InputScalar;
129 typedef typename internal::conditional<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>::type OutputScalar;
130 typedef OutputScalar CoeffReturnType;
131 typedef typename PacketType<OutputScalar, Device>::type PacketReturnType;
132 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
133
134 enum {
135 IsAligned = false,
136 PacketAccess = true,
137 BlockAccess = false,
138 Layout = TensorEvaluator<ArgType, Device>::Layout,
139 CoordAccess = false,
140 RawAccess = false
141 };
142
143 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_fft(op.fft()), m_impl(op.expression(), device), m_data(NULL), m_device(device) {
144 const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
145 for (int i = 0; i < NumDims; ++i) {
146 eigen_assert(input_dims[i] > 0);
147 m_dimensions[i] = input_dims[i];
148 }
149
150 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
151 m_strides[0] = 1;
152 for (int i = 1; i < NumDims; ++i) {
153 m_strides[i] = m_strides[i - 1] * m_dimensions[i - 1];
154 }
155 } else {
156 m_strides[NumDims - 1] = 1;
157 for (int i = NumDims - 2; i >= 0; --i) {
158 m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1];
159 }
160 }
161 m_size = m_dimensions.TotalSize();
162 }
163
164 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
165 return m_dimensions;
166 }
167
168 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(OutputScalar* data) {
169 m_impl.evalSubExprsIfNeeded(NULL);
170 if (data) {
171 evalToBuf(data);
172 return false;
173 } else {
174 m_data = (CoeffReturnType*)m_device.allocate(sizeof(CoeffReturnType) * m_size);
175 evalToBuf(m_data);
176 return true;
177 }
178 }
179
180 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
181 if (m_data) {
182 m_device.deallocate(m_data);
183 m_data = NULL;
184 }
185 m_impl.cleanup();
186 }
187
188 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const {
189 return m_data[index];
190 }
191
192 template <int LoadMode>
193 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType
194 packet(Index index) const {
195 return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
196 }
197
198 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
199 costPerCoeff(bool vectorized) const {
200 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
201 }
202
203 EIGEN_DEVICE_FUNC Scalar* data() const { return m_data; }
204
205
206 private:
207 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalToBuf(OutputScalar* data) {
208 const bool write_to_out = internal::is_same<OutputScalar, ComplexScalar>::value;
209 ComplexScalar* buf = write_to_out ? (ComplexScalar*)data : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * m_size);
210
211 for (Index i = 0; i < m_size; ++i) {
212 buf[i] = MakeComplex<internal::is_same<InputScalar, RealScalar>::value>()(m_impl.coeff(i));
213 }
214
215 for (size_t i = 0; i < m_fft.size(); ++i) {
216 Index dim = m_fft[i];
217 eigen_assert(dim >= 0 && dim < NumDims);
218 Index line_len = m_dimensions[dim];
219 eigen_assert(line_len >= 1);
220 ComplexScalar* line_buf = (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * line_len);
221 const bool is_power_of_two = isPowerOfTwo(line_len);
222 const Index good_composite = is_power_of_two ? 0 : findGoodComposite(line_len);
223 const Index log_len = is_power_of_two ? getLog2(line_len) : getLog2(good_composite);
224
225 ComplexScalar* a = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * good_composite);
226 ComplexScalar* b = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * good_composite);
227 ComplexScalar* pos_j_base_powered = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * (line_len + 1));
228 if (!is_power_of_two) {
229 // Compute twiddle factors
230 // t_n = exp(sqrt(-1) * pi * n^2 / line_len)
231 // for n = 0, 1,..., line_len-1.
232 // For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
233 pos_j_base_powered[0] = ComplexScalar(1, 0);
234 if (line_len > 1) {
235 const RealScalar pi_over_len(EIGEN_PI / line_len);
236 const ComplexScalar pos_j_base = ComplexScalar(
237 std::cos(pi_over_len), std::sin(pi_over_len));
238 pos_j_base_powered[1] = pos_j_base;
239 if (line_len > 2) {
240 const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
241 for (int j = 2; j < line_len + 1; ++j) {
242 pos_j_base_powered[j] = pos_j_base_powered[j - 1] *
243 pos_j_base_powered[j - 1] /
244 pos_j_base_powered[j - 2] * pos_j_base_sq;
245 }
246 }
247 }
248 }
249
250 for (Index partial_index = 0; partial_index < m_size / line_len; ++partial_index) {
251 const Index base_offset = getBaseOffsetFromIndex(partial_index, dim);
252
253 // get data into line_buf
254 const Index stride = m_strides[dim];
255 if (stride == 1) {
256 memcpy(line_buf, &buf[base_offset], line_len*sizeof(ComplexScalar));
257 } else {
258 Index offset = base_offset;
259 for (int j = 0; j < line_len; ++j, offset += stride) {
260 line_buf[j] = buf[offset];
261 }
262 }
263
264 // processs the line
265 if (is_power_of_two) {
266 processDataLineCooleyTukey(line_buf, line_len, log_len);
267 }
268 else {
269 processDataLineBluestein(line_buf, line_len, good_composite, log_len, a, b, pos_j_base_powered);
270 }
271
272 // write back
273 if (FFTDir == FFT_FORWARD && stride == 1) {
274 memcpy(&buf[base_offset], line_buf, line_len*sizeof(ComplexScalar));
275 } else {
276 Index offset = base_offset;
277 const ComplexScalar div_factor = ComplexScalar(1.0 / line_len, 0);
278 for (int j = 0; j < line_len; ++j, offset += stride) {
279 buf[offset] = (FFTDir == FFT_FORWARD) ? line_buf[j] : line_buf[j] * div_factor;
280 }
281 }
282 }
283 m_device.deallocate(line_buf);
284 if (!is_power_of_two) {
285 m_device.deallocate(a);
286 m_device.deallocate(b);
287 m_device.deallocate(pos_j_base_powered);
288 }
289 }
290
291 if(!write_to_out) {
292 for (Index i = 0; i < m_size; ++i) {
293 data[i] = PartOf<FFTResultType>()(buf[i]);
294 }
295 m_device.deallocate(buf);
296 }
297 }
298
299 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static bool isPowerOfTwo(Index x) {
300 eigen_assert(x > 0);
301 return !(x & (x - 1));
302 }
303
304 // The composite number for padding, used in Bluestein's FFT algorithm
305 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Index findGoodComposite(Index n) {
306 Index i = 2;
307 while (i < 2 * n - 1) i *= 2;
308 return i;
309 }
310
311 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Index getLog2(Index m) {
312 Index log2m = 0;
313 while (m >>= 1) log2m++;
314 return log2m;
315 }
316
317 // Call Cooley Tukey algorithm directly, data length must be power of 2
318 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void processDataLineCooleyTukey(ComplexScalar* line_buf, Index line_len, Index log_len) {
319 eigen_assert(isPowerOfTwo(line_len));
320 scramble_FFT(line_buf, line_len);
321 compute_1D_Butterfly<FFTDir>(line_buf, line_len, log_len);
322 }
323
324 // Call Bluestein's FFT algorithm, m is a good composite number greater than (2 * n - 1), used as the padding length
325 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void processDataLineBluestein(ComplexScalar* line_buf, Index line_len, Index good_composite, Index log_len, ComplexScalar* a, ComplexScalar* b, const ComplexScalar* pos_j_base_powered) {
326 Index n = line_len;
327 Index m = good_composite;
328 ComplexScalar* data = line_buf;
329
330 for (Index i = 0; i < n; ++i) {
331 if(FFTDir == FFT_FORWARD) {
332 a[i] = data[i] * numext::conj(pos_j_base_powered[i]);
333 }
334 else {
335 a[i] = data[i] * pos_j_base_powered[i];
336 }
337 }
338 for (Index i = n; i < m; ++i) {
339 a[i] = ComplexScalar(0, 0);
340 }
341
342 for (Index i = 0; i < n; ++i) {
343 if(FFTDir == FFT_FORWARD) {
344 b[i] = pos_j_base_powered[i];
345 }
346 else {
347 b[i] = numext::conj(pos_j_base_powered[i]);
348 }
349 }
350 for (Index i = n; i < m - n; ++i) {
351 b[i] = ComplexScalar(0, 0);
352 }
353 for (Index i = m - n; i < m; ++i) {
354 if(FFTDir == FFT_FORWARD) {
355 b[i] = pos_j_base_powered[m-i];
356 }
357 else {
358 b[i] = numext::conj(pos_j_base_powered[m-i]);
359 }
360 }
361
362 scramble_FFT(a, m);
363 compute_1D_Butterfly<FFT_FORWARD>(a, m, log_len);
364
365 scramble_FFT(b, m);
366 compute_1D_Butterfly<FFT_FORWARD>(b, m, log_len);
367
368 for (Index i = 0; i < m; ++i) {
369 a[i] *= b[i];
370 }
371
372 scramble_FFT(a, m);
373 compute_1D_Butterfly<FFT_REVERSE>(a, m, log_len);
374
375 //Do the scaling after ifft
376 for (Index i = 0; i < m; ++i) {
377 a[i] /= m;
378 }
379
380 for (Index i = 0; i < n; ++i) {
381 if(FFTDir == FFT_FORWARD) {
382 data[i] = a[i] * numext::conj(pos_j_base_powered[i]);
383 }
384 else {
385 data[i] = a[i] * pos_j_base_powered[i];
386 }
387 }
388 }
389
390 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void scramble_FFT(ComplexScalar* data, Index n) {
391 eigen_assert(isPowerOfTwo(n));
392 Index j = 1;
393 for (Index i = 1; i < n; ++i){
394 if (j > i) {
395 std::swap(data[j-1], data[i-1]);
396 }
397 Index m = n >> 1;
398 while (m >= 2 && j > m) {
399 j -= m;
400 m >>= 1;
401 }
402 j += m;
403 }
404 }
405
406 template <int Dir>
407 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_2(ComplexScalar* data) {
408 ComplexScalar tmp = data[1];
409 data[1] = data[0] - data[1];
410 data[0] += tmp;
411 }
412
413 template <int Dir>
414 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_4(ComplexScalar* data) {
415 ComplexScalar tmp[4];
416 tmp[0] = data[0] + data[1];
417 tmp[1] = data[0] - data[1];
418 tmp[2] = data[2] + data[3];
419 if (Dir == FFT_FORWARD) {
420 tmp[3] = ComplexScalar(0.0, -1.0) * (data[2] - data[3]);
421 } else {
422 tmp[3] = ComplexScalar(0.0, 1.0) * (data[2] - data[3]);
423 }
424 data[0] = tmp[0] + tmp[2];
425 data[1] = tmp[1] + tmp[3];
426 data[2] = tmp[0] - tmp[2];
427 data[3] = tmp[1] - tmp[3];
428 }
429
430 template <int Dir>
431 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_8(ComplexScalar* data) {
432 ComplexScalar tmp_1[8];
433 ComplexScalar tmp_2[8];
434
435 tmp_1[0] = data[0] + data[1];
436 tmp_1[1] = data[0] - data[1];
437 tmp_1[2] = data[2] + data[3];
438 if (Dir == FFT_FORWARD) {
439 tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, -1);
440 } else {
441 tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, 1);
442 }
443 tmp_1[4] = data[4] + data[5];
444 tmp_1[5] = data[4] - data[5];
445 tmp_1[6] = data[6] + data[7];
446 if (Dir == FFT_FORWARD) {
447 tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, -1);
448 } else {
449 tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, 1);
450 }
451 tmp_2[0] = tmp_1[0] + tmp_1[2];
452 tmp_2[1] = tmp_1[1] + tmp_1[3];
453 tmp_2[2] = tmp_1[0] - tmp_1[2];
454 tmp_2[3] = tmp_1[1] - tmp_1[3];
455 tmp_2[4] = tmp_1[4] + tmp_1[6];
456// SQRT2DIV2 = sqrt(2)/2
457#define SQRT2DIV2 0.7071067811865476
458 if (Dir == FFT_FORWARD) {
459 tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, -SQRT2DIV2);
460 tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, -1);
461 tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, -SQRT2DIV2);
462 } else {
463 tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, SQRT2DIV2);
464 tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, 1);
465 tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, SQRT2DIV2);
466 }
467 data[0] = tmp_2[0] + tmp_2[4];
468 data[1] = tmp_2[1] + tmp_2[5];
469 data[2] = tmp_2[2] + tmp_2[6];
470 data[3] = tmp_2[3] + tmp_2[7];
471 data[4] = tmp_2[0] - tmp_2[4];
472 data[5] = tmp_2[1] - tmp_2[5];
473 data[6] = tmp_2[2] - tmp_2[6];
474 data[7] = tmp_2[3] - tmp_2[7];
475 }
476
477 template <int Dir>
478 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_1D_merge(
479 ComplexScalar* data, Index n, Index n_power_of_2) {
480 // Original code:
481 // RealScalar wtemp = std::sin(M_PI/n);
482 // RealScalar wpi = -std::sin(2 * M_PI/n);
483 const RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2];
484 const RealScalar wpi = (Dir == FFT_FORWARD)
485 ? m_minus_sin_2_PI_div_n_LUT[n_power_of_2]
486 : -m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
487
488 const ComplexScalar wp(wtemp, wpi);
489 const ComplexScalar wp_one = wp + ComplexScalar(1, 0);
490 const ComplexScalar wp_one_2 = wp_one * wp_one;
491 const ComplexScalar wp_one_3 = wp_one_2 * wp_one;
492 const ComplexScalar wp_one_4 = wp_one_3 * wp_one;
493 const Index n2 = n / 2;
494 ComplexScalar w(1.0, 0.0);
495 for (Index i = 0; i < n2; i += 4) {
496 ComplexScalar temp0(data[i + n2] * w);
497 ComplexScalar temp1(data[i + 1 + n2] * w * wp_one);
498 ComplexScalar temp2(data[i + 2 + n2] * w * wp_one_2);
499 ComplexScalar temp3(data[i + 3 + n2] * w * wp_one_3);
500 w = w * wp_one_4;
501
502 data[i + n2] = data[i] - temp0;
503 data[i] += temp0;
504
505 data[i + 1 + n2] = data[i + 1] - temp1;
506 data[i + 1] += temp1;
507
508 data[i + 2 + n2] = data[i + 2] - temp2;
509 data[i + 2] += temp2;
510
511 data[i + 3 + n2] = data[i + 3] - temp3;
512 data[i + 3] += temp3;
513 }
514 }
515
516 template <int Dir>
517 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_1D_Butterfly(
518 ComplexScalar* data, Index n, Index n_power_of_2) {
519 eigen_assert(isPowerOfTwo(n));
520 if (n > 8) {
521 compute_1D_Butterfly<Dir>(data, n / 2, n_power_of_2 - 1);
522 compute_1D_Butterfly<Dir>(data + n / 2, n / 2, n_power_of_2 - 1);
523 butterfly_1D_merge<Dir>(data, n, n_power_of_2);
524 } else if (n == 8) {
525 butterfly_8<Dir>(data);
526 } else if (n == 4) {
527 butterfly_4<Dir>(data);
528 } else if (n == 2) {
529 butterfly_2<Dir>(data);
530 }
531 }
532
533 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index getBaseOffsetFromIndex(Index index, Index omitted_dim) const {
534 Index result = 0;
535
536 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
537 for (int i = NumDims - 1; i > omitted_dim; --i) {
538 const Index partial_m_stride = m_strides[i] / m_dimensions[omitted_dim];
539 const Index idx = index / partial_m_stride;
540 index -= idx * partial_m_stride;
541 result += idx * m_strides[i];
542 }
543 result += index;
544 }
545 else {
546 for (Index i = 0; i < omitted_dim; ++i) {
547 const Index partial_m_stride = m_strides[i] / m_dimensions[omitted_dim];
548 const Index idx = index / partial_m_stride;
549 index -= idx * partial_m_stride;
550 result += idx * m_strides[i];
551 }
552 result += index;
553 }
554 // Value of index_coords[omitted_dim] is not determined to this step
555 return result;
556 }
557
558 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index getIndexFromOffset(Index base, Index omitted_dim, Index offset) const {
559 Index result = base + offset * m_strides[omitted_dim] ;
560 return result;
561 }
562
563 protected:
564 Index m_size;
565 const FFT& m_fft;
566 Dimensions m_dimensions;
567 array<Index, NumDims> m_strides;
568 TensorEvaluator<ArgType, Device> m_impl;
569 CoeffReturnType* m_data;
570 const Device& m_device;
571
572 // This will support a maximum FFT size of 2^32 for each dimension
573 // m_sin_PI_div_n_LUT[i] = (-2) * std::sin(M_PI / std::pow(2,i)) ^ 2;
574 const RealScalar m_sin_PI_div_n_LUT[32] = {
575 RealScalar(0.0),
576 RealScalar(-2),
577 RealScalar(-0.999999999999999),
578 RealScalar(-0.292893218813453),
579 RealScalar(-0.0761204674887130),
580 RealScalar(-0.0192147195967696),
581 RealScalar(-0.00481527332780311),
582 RealScalar(-0.00120454379482761),
583 RealScalar(-3.01181303795779e-04),
584 RealScalar(-7.52981608554592e-05),
585 RealScalar(-1.88247173988574e-05),
586 RealScalar(-4.70619042382852e-06),
587 RealScalar(-1.17654829809007e-06),
588 RealScalar(-2.94137117780840e-07),
589 RealScalar(-7.35342821488550e-08),
590 RealScalar(-1.83835707061916e-08),
591 RealScalar(-4.59589268710903e-09),
592 RealScalar(-1.14897317243732e-09),
593 RealScalar(-2.87243293150586e-10),
594 RealScalar( -7.18108232902250e-11),
595 RealScalar(-1.79527058227174e-11),
596 RealScalar(-4.48817645568941e-12),
597 RealScalar(-1.12204411392298e-12),
598 RealScalar(-2.80511028480785e-13),
599 RealScalar(-7.01277571201985e-14),
600 RealScalar(-1.75319392800498e-14),
601 RealScalar(-4.38298482001247e-15),
602 RealScalar(-1.09574620500312e-15),
603 RealScalar(-2.73936551250781e-16),
604 RealScalar(-6.84841378126949e-17),
605 RealScalar(-1.71210344531737e-17),
606 RealScalar(-4.28025861329343e-18)
607 };
608
609 // m_minus_sin_2_PI_div_n_LUT[i] = -std::sin(2 * M_PI / std::pow(2,i));
610 const RealScalar m_minus_sin_2_PI_div_n_LUT[32] = {
611 RealScalar(0.0),
612 RealScalar(0.0),
613 RealScalar(-1.00000000000000e+00),
614 RealScalar(-7.07106781186547e-01),
615 RealScalar(-3.82683432365090e-01),
616 RealScalar(-1.95090322016128e-01),
617 RealScalar(-9.80171403295606e-02),
618 RealScalar(-4.90676743274180e-02),
619 RealScalar(-2.45412285229123e-02),
620 RealScalar(-1.22715382857199e-02),
621 RealScalar(-6.13588464915448e-03),
622 RealScalar(-3.06795676296598e-03),
623 RealScalar(-1.53398018628477e-03),
624 RealScalar(-7.66990318742704e-04),
625 RealScalar(-3.83495187571396e-04),
626 RealScalar(-1.91747597310703e-04),
627 RealScalar(-9.58737990959773e-05),
628 RealScalar(-4.79368996030669e-05),
629 RealScalar(-2.39684498084182e-05),
630 RealScalar(-1.19842249050697e-05),
631 RealScalar(-5.99211245264243e-06),
632 RealScalar(-2.99605622633466e-06),
633 RealScalar(-1.49802811316901e-06),
634 RealScalar(-7.49014056584716e-07),
635 RealScalar(-3.74507028292384e-07),
636 RealScalar(-1.87253514146195e-07),
637 RealScalar(-9.36267570730981e-08),
638 RealScalar(-4.68133785365491e-08),
639 RealScalar(-2.34066892682746e-08),
640 RealScalar(-1.17033446341373e-08),
641 RealScalar(-5.85167231706864e-09),
642 RealScalar(-2.92583615853432e-09)
643 };
644};
645
646} // end namespace Eigen
647
#endif  // __cplusplus >= 201103L || EIGEN_COMP_MSVC >= 1900
649
650
651#endif // EIGEN_CXX11_TENSOR_TENSOR_FFT_H
Namespace containing all symbols from the Eigen library.
Definition: AdolcForward:45
const Device & device() const
required by sycl in order to construct sycl buffer from raw pointer
Definition: TensorEvaluator.h:114