3
// File: fft_kernelstring.cpp
7
// Disclaimer: IMPORTANT: This Apple software is supplied to you by Apple Inc. ("Apple")
8
// in consideration of your agreement to the following terms, and your use,
9
// installation, modification or redistribution of this Apple software
10
// constitutes acceptance of these terms. If you do not agree with these
11
// terms, please do not use, install, modify or redistribute this Apple
14
// In consideration of your agreement to abide by the following terms, and
15
// subject to these terms, Apple grants you a personal, non - exclusive
16
// license, under Apple's copyrights in this original Apple software ( the
17
// "Apple Software" ), to use, reproduce, modify and redistribute the Apple
18
// Software, with or without modifications, in source and / or binary forms;
19
// provided that if you redistribute the Apple Software in its entirety and
20
// without modifications, you must retain this notice and the following text
21
// and disclaimers in all such redistributions of the Apple Software. Neither
22
// the name, trademarks, service marks or logos of Apple Inc. may be used to
23
// endorse or promote products derived from the Apple Software without specific
24
// prior written permission from Apple. Except as expressly stated in this
25
// notice, no other rights or licenses, express or implied, are granted by
26
// Apple herein, including but not limited to any patent rights that may be
27
// infringed by your derivative works or by other works in which the Apple
28
// Software may be incorporated.
30
// The Apple Software is provided by Apple on an "AS IS" basis. APPLE MAKES NO
31
// WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED
32
// WARRANTIES OF NON - INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A
33
// PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND OPERATION
34
// ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
36
// IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR
37
// CONSEQUENTIAL DAMAGES ( INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
38
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
39
// INTERRUPTION ) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION
40
// AND / OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED AND WHETHER
41
// UNDER THEORY OF CONTRACT, TORT ( INCLUDING NEGLIGENCE ), STRICT LIABILITY OR
42
// OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
// Copyright ( C ) 2008 Apple Inc. All Rights Reserved.
46
////////////////////////////////////////////////////////////////////////////////////////////////////
56
#include "fft_internal.h"
61
#define max(A,B) ((A) > (B) ? (A) : (B))
62
#define min(A,B) ((A) < (B) ? (A) : (B))
68
sprintf(temp, "%d", num);
72
// For any n, this function decomposes n into factors for local memory transpose
73
// based fft. Factors (radices) are sorted such that the first one (radixArray[0])
74
// is the largest. This base radix determines the number of registers used by each
75
// work item, and the product of the remaining radices determines the size of the work group needed.
76
// To make things concrete with an example, suppose n = 1024. It is decomposed into
77
// 1024 = 16 x 16 x 4. Hence the kernel uses float2 a[16] for the local in-register fft and
78
// needs 16 x 4 = 64 work items per work group. So the kernel first performs 64 length
79
// 16 ffts (64 work items working in parallel), followed by a transpose using local
80
// memory, followed again by 64 length 16 ffts, followed by a transpose using local memory,
81
// followed by 256 length 4 ffts. For the last step, since the size of the work group is
82
// 64 and each work item holds an array of 16 values, 64 work items can compute 256 length
83
// 4 ffts by each work item computing 4 length 4 ffts.
84
// Similarly for n = 2048 = 8 x 8 x 8 x 4, each work group has 8 x 8 x 4 = 256 work
85
// items, which together compute 256 (in-parallel) length 8 ffts in-register, followed
86
// by transpose using local memory, followed by 256 length 8 in-register ffts, followed
87
// by transpose using local memory, followed by 256 length 8 in-register ffts, followed
88
// by transpose using local memory, followed by 512 length 4 in-register ffts. Again,
89
// for the last step, each work item computes two length 4 in-register ffts and thus
90
// 256 work items are needed to compute all 512 ffts.
91
// For n = 32 = 8 x 4, 4 work items first compute 4 in-register
92
// length 8 ffts, followed by a transpose using local memory, followed by 8 in-register
93
// length 4 ffts, where each work item computes two length 4 ffts thus 4 work items
94
// can compute 8 length 4 ffts. However, if a work group size of, say, 64 is chosen,
95
// each work group can compute 64/ 4 = 16 size 32 ffts (batched transform).
96
// Users can play with these parameters to figure out what gives the best performance on
97
// their particular device, e.g. some devices have less register space, thus using
98
// a smaller base radix can avoid spilling; some have small local memory, thus
99
// using a smaller work group size may be required, etc.
102
getRadixArray(unsigned int n, unsigned int *radixArray, unsigned int *numRadices, unsigned int maxRadix)
106
maxRadix = min(n, maxRadix);
107
unsigned int cnt = 0;
110
radixArray[cnt++] = maxRadix;
113
radixArray[cnt++] = n;
137
radixArray[0] = 8; radixArray[1] = 2;
142
radixArray[0] = 8; radixArray[1] = 4;
147
radixArray[0] = 8; radixArray[1] = 8;
152
radixArray[0] = 8; radixArray[1] = 4; radixArray[2] = 4;
157
radixArray[0] = 4; radixArray[1] = 4; radixArray[2] = 4; radixArray[3] = 4;
162
radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8;
167
radixArray[0] = 16; radixArray[1] = 16; radixArray[2] = 4;
171
radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; radixArray[3] = 4;
180
insertHeader(string &kernelString, string &kernelName, clFFT_DataFormat dataFormat)
182
if(dataFormat == clFFT_SplitComplexFormat)
183
kernelString += string("__kernel void ") + kernelName + string("(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n");
185
kernelString += string("__kernel void ") + kernelName + string("(__global float2 *in, __global float2 *out, int dir, int S)\n");
189
insertVariables(string &kStream, int maxRadix)
191
kStream += string(" int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n");
192
kStream += string(" int s, ii, jj, offset;\n");
193
kStream += string(" float2 w;\n");
194
kStream += string(" float ang, angf, ang1;\n");
195
kStream += string(" __local float *lMemStore, *lMemLoad;\n");
196
kStream += string(" float2 a[") + num2str(maxRadix) + string("];\n");
197
kStream += string(" int lId = get_local_id( 0 );\n");
198
kStream += string(" int groupId = get_group_id( 0 );\n");
202
formattedLoad(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
204
if(dataFormat == clFFT_InterleavedComplexFormat)
205
kernelString += string(" a[") + num2str(aIndex) + string("] = in[") + num2str(gIndex) + string("];\n");
208
kernelString += string(" a[") + num2str(aIndex) + string("].x = in_real[") + num2str(gIndex) + string("];\n");
209
kernelString += string(" a[") + num2str(aIndex) + string("].y = in_imag[") + num2str(gIndex) + string("];\n");
214
formattedStore(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
216
if(dataFormat == clFFT_InterleavedComplexFormat)
217
kernelString += string(" out[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("];\n");
220
kernelString += string(" out_real[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].x;\n");
221
kernelString += string(" out_imag[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].y;\n");
226
insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXForm, int numXFormsPerWG, int R0, int mem_coalesce_width, clFFT_DataFormat dataFormat)
228
int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm);
229
int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
233
if(numXFormsPerWG > 1)
234
kernelString += string(" s = S & ") + num2str(numXFormsPerWG - 1) + string(";\n");
236
if(numWorkItemsPerXForm >= mem_coalesce_width)
238
if(numXFormsPerWG > 1)
240
kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm-1) + string(";\n");
241
kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
242
kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
243
kernelString += string(" offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n");
244
if(dataFormat == clFFT_InterleavedComplexFormat)
246
kernelString += string(" in += offset;\n");
247
kernelString += string(" out += offset;\n");
251
kernelString += string(" in_real += offset;\n");
252
kernelString += string(" in_imag += offset;\n");
253
kernelString += string(" out_real += offset;\n");
254
kernelString += string(" out_imag += offset;\n");
256
for(i = 0; i < R0; i++)
257
formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat);
258
kernelString += string(" }\n");
262
kernelString += string(" ii = lId;\n");
263
kernelString += string(" jj = 0;\n");
264
kernelString += string(" offset = mad24(groupId, ") + num2str(N) + string(", ii);\n");
265
if(dataFormat == clFFT_InterleavedComplexFormat)
267
kernelString += string(" in += offset;\n");
268
kernelString += string(" out += offset;\n");
272
kernelString += string(" in_real += offset;\n");
273
kernelString += string(" in_imag += offset;\n");
274
kernelString += string(" out_real += offset;\n");
275
kernelString += string(" out_imag += offset;\n");
277
for(i = 0; i < R0; i++)
278
formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat);
281
else if( N >= mem_coalesce_width )
283
int numInnerIter = N / mem_coalesce_width;
284
int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
286
kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
287
kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
288
kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
289
kernelString += string(" offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n");
290
kernelString += string(" offset = mad24( offset, ") + num2str(N) + string(", ii );\n");
291
if(dataFormat == clFFT_InterleavedComplexFormat)
293
kernelString += string(" in += offset;\n");
294
kernelString += string(" out += offset;\n");
298
kernelString += string(" in_real += offset;\n");
299
kernelString += string(" in_imag += offset;\n");
300
kernelString += string(" out_real += offset;\n");
301
kernelString += string(" out_imag += offset;\n");
304
kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
305
for(i = 0; i < numOuterIter; i++ )
307
kernelString += string(" if( jj < s ) {\n");
308
for(j = 0; j < numInnerIter; j++ )
309
formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat);
310
kernelString += string(" }\n");
311
if(i != numOuterIter - 1)
312
kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");
314
kernelString += string("}\n ");
315
kernelString += string("else {\n");
316
for(i = 0; i < numOuterIter; i++ )
318
for(j = 0; j < numInnerIter; j++ )
319
formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat);
321
kernelString += string("}\n");
323
kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
324
kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
325
kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n");
327
for( i = 0; i < numOuterIter; i++ )
329
for( j = 0; j < numInnerIter; j++ )
331
kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") +
332
num2str(i * numInnerIter + j) + string("].x;\n");
335
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
337
for( i = 0; i < R0; i++ )
338
kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
339
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
341
for( i = 0; i < numOuterIter; i++ )
343
for( j = 0; j < numInnerIter; j++ )
345
kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") +
346
num2str(i * numInnerIter + j) + string("].y;\n");
349
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
351
for( i = 0; i < R0; i++ )
352
kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
353
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
355
lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
359
kernelString += string(" offset = mad24( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n");
360
if(dataFormat == clFFT_InterleavedComplexFormat)
362
kernelString += string(" in += offset;\n");
363
kernelString += string(" out += offset;\n");
367
kernelString += string(" in_real += offset;\n");
368
kernelString += string(" in_imag += offset;\n");
369
kernelString += string(" out_real += offset;\n");
370
kernelString += string(" out_imag += offset;\n");
373
kernelString += string(" ii = lId & ") + num2str(N-1) + string(";\n");
374
kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n");
375
kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
377
kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
378
for( i = 0; i < R0; i++ )
380
kernelString += string(" if(jj < s )\n");
381
formattedLoad(kernelString, i, i*groupSize, dataFormat);
383
kernelString += string(" jj += ") + num2str(groupSize / N) + string(";\n");
385
kernelString += string("}\n");
386
kernelString += string("else {\n");
387
for( i = 0; i < R0; i++ )
389
formattedLoad(kernelString, i, i*groupSize, dataFormat);
391
kernelString += string("}\n");
393
if(numWorkItemsPerXForm > 1)
395
kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
396
kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
397
kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
401
kernelString += string(" ii = 0;\n");
402
kernelString += string(" jj = lId;\n");
403
kernelString += string(" lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n");
407
for( i = 0; i < R0; i++ )
408
kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].x;\n");
409
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
411
for( i = 0; i < R0; i++ )
412
kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
413
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
415
for( i = 0; i < R0; i++ )
416
kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].y;\n");
417
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
419
for( i = 0; i < R0; i++ )
420
kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
421
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
423
lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
430
insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr, int numWorkItemsPerXForm, int numXFormsPerWG, int mem_coalesce_width, clFFT_DataFormat dataFormat)
432
int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
435
int numIter = maxRadix / Nr;
436
string indent = string("");
438
if( numWorkItemsPerXForm >= mem_coalesce_width )
440
if(numXFormsPerWG > 1)
442
kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
443
indent = string(" ");
445
for(i = 0; i < maxRadix; i++)
450
formattedStore(kernelString, ind, i*numWorkItemsPerXForm, dataFormat);
452
if(numXFormsPerWG > 1)
453
kernelString += string(" }\n");
455
else if( N >= mem_coalesce_width )
457
int numInnerIter = N / mem_coalesce_width;
458
int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
460
kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
461
kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
462
kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
463
kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
465
for( i = 0; i < maxRadix; i++ )
470
kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
472
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
474
for( i = 0; i < numOuterIter; i++ )
475
for( j = 0; j < numInnerIter; j++ )
476
kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].x = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n");
477
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
479
for( i = 0; i < maxRadix; i++ )
484
kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
486
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
488
for( i = 0; i < numOuterIter; i++ )
489
for( j = 0; j < numInnerIter; j++ )
490
kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].y = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n");
491
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
493
kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
494
for(i = 0; i < numOuterIter; i++ )
496
kernelString += string(" if( jj < s ) {\n");
497
for(j = 0; j < numInnerIter; j++ )
498
formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
499
kernelString += string(" }\n");
500
if(i != numOuterIter - 1)
501
kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");
503
kernelString += string("}\n");
504
kernelString += string("else {\n");
505
for(i = 0; i < numOuterIter; i++ )
507
for(j = 0; j < numInnerIter; j++ )
508
formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
510
kernelString += string("}\n");
512
lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
516
kernelString += string(" lMemLoad = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
518
kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n");
519
kernelString += string(" jj = lId >> ") + num2str((int) log2(N)) + string(";\n");
520
kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
522
for( i = 0; i < maxRadix; i++ )
527
kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
529
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
531
for( i = 0; i < maxRadix; i++ )
532
kernelString += string(" a[") + num2str(i) + string("].x = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n");
533
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
535
for( i = 0; i < maxRadix; i++ )
540
kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
542
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
544
for( i = 0; i < maxRadix; i++ )
545
kernelString += string(" a[") + num2str(i) + string("].y = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n");
546
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
548
kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
549
for( i = 0; i < maxRadix; i++ )
551
kernelString += string(" if(jj < s ) {\n");
552
formattedStore(kernelString, i, i*groupSize, dataFormat);
553
kernelString += string(" }\n");
554
if( i != maxRadix - 1)
555
kernelString += string(" jj +=") + num2str(groupSize / N) + string(";\n");
557
kernelString += string("}\n");
558
kernelString += string("else {\n");
559
for( i = 0; i < maxRadix; i++ )
561
formattedStore(kernelString, i, i*groupSize, dataFormat);
563
kernelString += string("}\n");
565
lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
572
insertfftKernel(string &kernelString, int Nr, int numIter)
575
for(i = 0; i < numIter; i++)
577
kernelString += string(" fftKernel") + num2str(Nr) + string("(a+") + num2str(i*Nr) + string(", dir);\n");
582
insertTwiddleKernel(string &kernelString, int Nr, int numIter, int Nprev, int len, int numWorkItemsPerXForm)
585
int logNPrev = log2(Nprev);
587
for(z = 0; z < numIter; z++)
592
kernelString += string(" angf = (float) (ii >> ") + num2str(logNPrev) + string(");\n");
594
kernelString += string(" angf = (float) ii;\n");
599
kernelString += string(" angf = (float) ((") + num2str(z*numWorkItemsPerXForm) + string(" + ii) >>") + num2str(logNPrev) + string(");\n");
601
kernelString += string(" angf = (float) (") + num2str(z*numWorkItemsPerXForm) + string(" + ii);\n");
604
for(k = 1; k < Nr; k++) {
606
//float fac = (float) (2.0 * M_PI * (double) k / (double) len);
607
kernelString += string(" ang = dir * ( 2.0f * M_PI * ") + num2str(k) + string(".0f / ") + num2str(len) + string(".0f )") + string(" * angf;\n");
608
kernelString += string(" w = (float2)(native_cos(ang), native_sin(ang));\n");
609
kernelString += string(" a[") + num2str(ind) + string("] = complexMul(a[") + num2str(ind) + string("], w);\n");
615
getPadding(int numWorkItemsPerXForm, int Nprev, int numWorkItemsReq, int numXFormsPerWG, int Nr, int numBanks, int *offset, int *midPad)
617
if((numWorkItemsPerXForm <= Nprev) || (Nprev >= numBanks))
620
int numRowsReq = ((numWorkItemsPerXForm < numBanks) ? numWorkItemsPerXForm : numBanks) / Nprev;
623
numColsReq = numRowsReq / Nr;
624
numColsReq = Nprev * numColsReq;
625
*offset = numColsReq;
628
if(numWorkItemsPerXForm >= numBanks || numXFormsPerWG == 1)
631
int bankNum = ( (numWorkItemsReq + *offset) * Nr ) & (numBanks - 1);
632
if( bankNum >= numWorkItemsPerXForm )
635
*midPad = numWorkItemsPerXForm - bankNum;
638
int lMemSize = ( numWorkItemsReq + *offset) * Nr * numXFormsPerWG + *midPad * (numXFormsPerWG - 1);
644
insertLocalStores(string &kernelString, int numIter, int Nr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
648
for(z = 0; z < numIter; z++) {
649
for(k = 0; k < Nr; k++) {
650
int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm;
651
kernelString += string(" lMemStore[") + num2str(index) + string("] = a[") + num2str(z*Nr + k) + string("].") + comp + string(";\n");
654
kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n");
658
insertLocalLoads(string &kernelString, int n, int Nr, int Nrn, int Nprev, int Ncurr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
660
int numWorkItemsReqN = n / Nrn;
661
int interBlockHNum = max( Nprev / numWorkItemsPerXForm, 1 );
662
int interBlockHStride = numWorkItemsPerXForm;
663
int vertWidth = max(numWorkItemsPerXForm / Nprev, 1);
664
vertWidth = min( vertWidth, Nr);
665
int vertNum = Nr / vertWidth;
666
int vertStride = ( n / Nr + offset ) * vertWidth;
667
int iter = max( numWorkItemsReqN / numWorkItemsPerXForm, 1);
668
int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1;
669
intraBlockHStride *= Nprev;
671
int stride = numWorkItemsReq / Nrn;
673
for(i = 0; i < iter; i++) {
674
int ii = i / (interBlockHNum * vertNum);
675
int zz = i % (interBlockHNum * vertNum);
676
int jj = zz % interBlockHNum;
677
int kk = zz / interBlockHNum;
679
for(z = 0; z < Nrn; z++) {
680
int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride;
681
kernelString += string(" a[") + num2str(i*Nrn + z) + string("].") + comp + string(" = lMemLoad[") + num2str(st) + string("];\n");
684
kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n");
688
insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numWorkItemsReq, int numWorkItemsPerXForm, int numXFormsPerWG, int offset, int midPad)
690
int Ncurr = Nprev * Nr;
691
int logNcurr = log2(Ncurr);
692
int logNprev = log2(Nprev);
693
int incr = (numWorkItemsReq + offset) * Nr + midPad;
695
if(Ncurr < numWorkItemsPerXForm)
698
kernelString += string(" j = ii & ") + num2str(Ncurr - 1) + string(";\n");
700
kernelString += string(" j = (ii & ") + num2str(Ncurr - 1) + string(") >> ") + num2str(logNprev) + string(";\n");
703
kernelString += string(" i = ii >> ") + num2str(logNcurr) + string(";\n");
705
kernelString += string(" i = mad24(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n");
710
kernelString += string(" j = ii;\n");
712
kernelString += string(" j = ii >> ") + num2str(logNprev) + string(";\n");
714
kernelString += string(" i = 0;\n");
716
kernelString += string(" i = ii & ") + num2str(Nprev - 1) + string(";\n");
719
if(numXFormsPerWG > 1)
720
kernelString += string(" i = mad24(jj, ") + num2str(incr) + string(", i);\n");
722
kernelString += string(" lMemLoad = sMem + mad24(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n");
726
insertLocalStoreIndexArithmatic(string &kernelString, int numWorkItemsReq, int numXFormsPerWG, int Nr, int offset, int midPad)
728
if(numXFormsPerWG == 1) {
729
kernelString += string(" lMemStore = sMem + ii;\n");
732
kernelString += string(" lMemStore = sMem + mad24(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n");
738
createLocalMemfftKernelString(cl_fft_plan *plan)
740
unsigned int radixArray[10];
741
unsigned int numRadix;
743
unsigned int n = plan->n.x;
745
assert(n <= plan->max_work_item_per_workgroup * plan->max_radix && "signal lenght too big for local mem fft\n");
747
getRadixArray(n, radixArray, &numRadix, 0);
748
assert(numRadix > 0 && "no radix array supplied\n");
750
if(n/radixArray[0] > plan->max_work_item_per_workgroup)
751
getRadixArray(n, radixArray, &numRadix, plan->max_radix);
753
assert(radixArray[0] <= plan->max_radix && "max radix choosen is greater than allowed\n");
754
assert(n/radixArray[0] <= plan->max_work_item_per_workgroup && "required work items per xform greater than maximum work items allowed per work group for local mem fft\n");
756
unsigned int tmpLen = 1;
758
for(i = 0; i < numRadix; i++)
760
assert( radixArray[i] && !( (radixArray[i] - 1) & radixArray[i] ) );
761
tmpLen *= radixArray[i];
763
assert(tmpLen == n && "product of radices choosen doesnt match the length of signal\n");
766
string localString(""), kernelName("");
768
clFFT_DataFormat dataFormat = plan->format;
769
string *kernelString = plan->kernel_string;
772
cl_fft_kernel_info **kInfo = &plan->kernel_info;
777
kInfo = &(*kInfo)->next;
781
kernelName = string("fft") + num2str(kCount);
783
*kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info));
784
(*kInfo)->kernel = 0;
785
(*kInfo)->lmem_size = 0;
786
(*kInfo)->num_workgroups = 0;
787
(*kInfo)->num_workitems_per_workgroup = 0;
788
(*kInfo)->dir = cl_fft_kernel_x;
789
(*kInfo)->in_place_possible = 1;
790
(*kInfo)->next = NULL;
791
(*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1));
792
strcpy((*kInfo)->kernel_name, kernelName.c_str());
794
unsigned int numWorkItemsPerXForm = n / radixArray[0];
795
unsigned int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm;
796
assert(numWorkItemsPerWG <= plan->max_work_item_per_workgroup);
797
int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm;
798
(*kInfo)->num_workgroups = 1;
799
(*kInfo)->num_xforms_per_workgroup = numXFormsPerWG;
800
(*kInfo)->num_workitems_per_workgroup = numWorkItemsPerWG;
802
unsigned int *N = radixArray;
803
unsigned int maxRadix = N[0];
804
unsigned int lMemSize = 0;
806
insertVariables(localString, maxRadix);
808
lMemSize = insertGlobalLoadsAndTranspose(localString, n, numWorkItemsPerXForm, numXFormsPerWG, maxRadix, plan->min_mem_coalesce_width, dataFormat);
809
(*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
811
string xcomp = string("x");
812
string ycomp = string("y");
814
unsigned int Nprev = 1;
815
unsigned int len = n;
817
for(r = 0; r < numRadix; r++)
819
int numIter = N[0] / N[r];
820
int numWorkItemsReq = n / N[r];
821
int Ncurr = Nprev * N[r];
822
insertfftKernel(localString, N[r], numIter);
824
if(r < (numRadix - 1)) {
825
insertTwiddleKernel(localString, N[r], numIter, Nprev, len, numWorkItemsPerXForm);
826
lMemSize = getPadding(numWorkItemsPerXForm, Nprev, numWorkItemsReq, numXFormsPerWG, N[r], plan->num_local_mem_banks, &offset, &midPad);
827
(*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
828
insertLocalStoreIndexArithmatic(localString, numWorkItemsReq, numXFormsPerWG, N[r], offset, midPad);
829
insertLocalLoadIndexArithmatic(localString, Nprev, N[r], numWorkItemsReq, numWorkItemsPerXForm, numXFormsPerWG, offset, midPad);
830
insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
831
insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
832
insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
833
insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
839
lMemSize = insertGlobalStoresAndTranspose(localString, n, maxRadix, N[numRadix - 1], numWorkItemsPerXForm, numXFormsPerWG, plan->min_mem_coalesce_width, dataFormat);
840
(*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
842
insertHeader(*kernelString, kernelName, dataFormat);
843
*kernelString += string("{\n");
844
if((*kInfo)->lmem_size)
845
*kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
846
*kernelString += localString;
847
*kernelString += string("}\n");
850
// For n larger than what can be computed using local memory fft, global transposes
851
// and multiple kernel launches are needed. For these sizes, n can be decomposed using
852
// much larger base radices, e.g. n = 262144 = 128 x 64 x 32. Thus three kernel
853
// launches will be needed, first computing 64 x 32, length 128 ffts, second computing
854
// 128 x 32 length 64 ffts, and finally a kernel computing 128 x 64 length 32 ffts.
855
// Each of these base radices can further be divided into factors so that each of these
856
// base ffts can be computed within one kernel launch using in-register ffts and local
857
// memory transposes, e.g. for the first kernel above, which computes 64 x 32 ffts of length
858
// 128, 128 can be decomposed into 128 = 16 x 8 i.e. 8 work items can compute 8 length
859
// 16 ffts followed by transpose using local memory followed by each of these eight
860
// work items computing 2 length 8 ffts thus computing 16 length 8 ffts in total. This
861
// means only 8 work items are needed for computing one length 128 fft. If we choose
862
// work group size of say 64, we can compute 64/8 = 8 length 128 ffts within one
863
// work group. Since we need to compute 64 x 32 length 128 ffts in first kernel, this
864
// means we need to launch 64 x 32 / 8 = 256 work groups with 64 work items in each
865
// work group where each work group is computing 8 length 128 ffts where each length
866
// 128 fft is computed by 8 work items. Same logic can be applied to other two kernels
867
// in this example. Users can play with different base radices and different
868
// decompositions of base radices to generate different kernels and see which gives
869
// the best performance. The following function simply fixes the base radix at 128.
872
getGlobalRadixInfo(int n, int *radix, int *R1, int *R2, int *numRadices)
874
int baseRadix = min(n, 128);
884
for(int i = 0; i < numR; i++)
885
radix[i] = baseRadix;
891
for(int i = 0; i < numR; i++)
914
createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir dir, int vertBS)
917
int radixArr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
918
int R1Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
919
int R2Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
923
int maxThreadsPerBlock = plan->max_work_item_per_workgroup;
924
int maxArrayLen = plan->max_radix;
925
int batchSize = plan->min_mem_coalesce_width;
926
clFFT_DataFormat dataFormat = plan->format;
927
int vertical = (dir == cl_fft_kernel_x) ? 0 : 1;
929
getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices);
931
int numPasses = numRadices;
933
string localString(""), kernelName("");
934
string *kernelString = plan->kernel_string;
935
cl_fft_kernel_info **kInfo = &plan->kernel_info;
940
kInfo = &(*kInfo)->next;
945
int m = (int)log2(n);
946
int Rinit = vertical ? BS : 1;
947
batchSize = vertical ? min(BS, batchSize) : batchSize;
950
for(passNum = 0; passNum < numPasses; passNum++)
956
radix = radixArr[passNum];
961
for(i = 0; i < numPasses; i++)
963
strideI *= radixArr[i];
966
for(i = 0; i < passNum; i++)
967
strideO *= radixArr[i];
969
int threadsPerXForm = R2;
970
batchSize = R2 == 1 ? plan->max_work_item_per_workgroup : batchSize;
971
batchSize = min(batchSize, strideI);
972
int threadsPerBlock = batchSize * threadsPerXForm;
973
threadsPerBlock = min(threadsPerBlock, maxThreadsPerBlock);
974
batchSize = threadsPerBlock / threadsPerXForm;
976
assert(R1*R2 == radix);
977
assert(R1 <= maxArrayLen);
978
assert(threadsPerBlock <= maxThreadsPerBlock);
980
int numIter = R1 / R2;
981
int gInInc = threadsPerBlock / batchSize;
984
int lgStrideO = log2(strideO);
985
int numBlocksPerXForm = strideI / batchSize;
986
int numBlocks = numBlocksPerXForm;
992
kernelName = string("fft") + num2str(kCount);
993
*kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info));
994
(*kInfo)->kernel = 0;
996
(*kInfo)->lmem_size = 0;
1000
(*kInfo)->lmem_size = (radix + 1)*batchSize;
1002
(*kInfo)->lmem_size = threadsPerBlock*R1;
1004
(*kInfo)->num_workgroups = numBlocks;
1005
(*kInfo)->num_xforms_per_workgroup = 1;
1006
(*kInfo)->num_workitems_per_workgroup = threadsPerBlock;
1007
(*kInfo)->dir = dir;
1008
if( (passNum == (numPasses - 1)) && (numPasses & 1) )
1009
(*kInfo)->in_place_possible = 1;
1011
(*kInfo)->in_place_possible = 0;
1012
(*kInfo)->next = NULL;
1013
(*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1));
1014
strcpy((*kInfo)->kernel_name, kernelName.c_str());
1016
insertVariables(localString, R1);
1020
localString += string("xNum = groupId >> ") + num2str((int)log2(numBlocksPerXForm)) + string(";\n");
1021
localString += string("groupId = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
1022
localString += string("indexIn = mad24(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n*BS)) + string(");\n");
1023
localString += string("tid = mul24(groupId, ") + num2str(batchSize) + string(");\n");
1024
localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
1025
localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
1026
int stride = radix*Rinit;
1027
for(i = 0; i < passNum; i++)
1028
stride *= radixArr[i];
1029
localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int) log2(n*BS)) + string("));\n");
1030
localString += string("bNum = groupId;\n");
1034
int lgNumBlocksPerXForm = log2(numBlocksPerXForm);
1035
localString += string("bNum = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
1036
localString += string("xNum = groupId >> ") + num2str(lgNumBlocksPerXForm) + string(";\n");
1037
localString += string("indexIn = mul24(bNum, ") + num2str(batchSize) + string(");\n");
1038
localString += string("tid = indexIn;\n");
1039
localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
1040
localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
1041
int stride = radix*Rinit;
1042
for(i = 0; i < passNum; i++)
1043
stride *= radixArr[i];
1044
localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j);\n");
1045
localString += string("indexIn += (xNum << ") + num2str(m) + string(");\n");
1046
localString += string("indexOut += (xNum << ") + num2str(m) + string(");\n");
1050
int lgBatchSize = log2(batchSize);
1051
localString += string("tid = lId;\n");
1052
localString += string("i = tid & ") + num2str(batchSize - 1) + string(";\n");
1053
localString += string("j = tid >> ") + num2str(lgBatchSize) + string(";\n");
1054
localString += string("indexIn += mad24(j, ") + num2str(strideI) + string(", i);\n");
1056
if(dataFormat == clFFT_SplitComplexFormat)
1058
localString += string("in_real += indexIn;\n");
1059
localString += string("in_imag += indexIn;\n");
1060
for(j = 0; j < R1; j++)
1061
localString += string("a[") + num2str(j) + string("].x = in_real[") + num2str(j*gInInc*strideI) + string("];\n");
1062
for(j = 0; j < R1; j++)
1063
localString += string("a[") + num2str(j) + string("].y = in_imag[") + num2str(j*gInInc*strideI) + string("];\n");
1067
localString += string("in += indexIn;\n");
1068
for(j = 0; j < R1; j++)
1069
localString += string("a[") + num2str(j) + string("] = in[") + num2str(j*gInInc*strideI) + string("];\n");
1072
localString += string("fftKernel") + num2str(R1) + string("(a, dir);\n");
1077
for(k = 1; k < R1; k++)
1079
localString += string("ang = dir*(2.0f*M_PI*") + num2str(k) + string("/") + num2str(radix) + string(")*j;\n");
1080
localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
1081
localString += string("a[") + num2str(k) + string("] = complexMul(a[") + num2str(k) + string("], w);\n");
1086
localString += string("indexIn = mad24(j, ") + num2str(threadsPerBlock*numIter) + string(", i);\n");
1087
localString += string("lMemStore = sMem + tid;\n");
1088
localString += string("lMemLoad = sMem + indexIn;\n");
1089
for(k = 0; k < R1; k++)
1090
localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
1091
localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
1092
for(k = 0; k < numIter; k++)
1093
for(t = 0; t < R2; t++)
1094
localString += string("a[") + num2str(k*R2+t) + string("].x = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n");
1095
localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
1096
for(k = 0; k < R1; k++)
1097
localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
1098
localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
1099
for(k = 0; k < numIter; k++)
1100
for(t = 0; t < R2; t++)
1101
localString += string("a[") + num2str(k*R2+t) + string("].y = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n");
1102
localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
1104
for(j = 0; j < numIter; j++)
1105
localString += string("fftKernel") + num2str(R2) + string("(a + ") + num2str(j*R2) + string(", dir);\n");
1109
if(passNum < (numPasses - 1))
1111
localString += string("l = ((bNum << ") + num2str(lgBatchSize) + string(") + i) >> ") + num2str(lgStrideO) + string(";\n");
1112
localString += string("k = j << ") + num2str((int)log2(R1/R2)) + string(";\n");
1113
localString += string("ang1 = dir*(2.0f*M_PI/") + num2str(N) + string(")*l;\n");
1114
for(t = 0; t < R1; t++)
1116
localString += string("ang = ang1*(k + ") + num2str((t%R2)*R1 + (t/R2)) + string(");\n");
1117
localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
1118
localString += string("a[") + num2str(t) + string("] = complexMul(a[") + num2str(t) + string("], w);\n");
1126
localString += string("lMemStore = sMem + mad24(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1/R2)) + string(");\n");
1127
localString += string("lMemLoad = sMem + mad24(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix+1) + string(", tid & ") + num2str(radix-1) + string(");\n");
1129
for(int i = 0; i < R1/R2; i++)
1130
for(int j = 0; j < R2; j++)
1131
localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].x;\n");
1132
localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
1133
if(threadsPerBlock >= radix)
1135
for(int i = 0; i < R1; i++)
1136
localString += string("a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n");
1140
int innerIter = radix/threadsPerBlock;
1141
int outerIter = R1/innerIter;
1142
for(int i = 0; i < outerIter; i++)
1143
for(int j = 0; j < innerIter; j++)
1144
localString += string("a[") + num2str(i*innerIter+j) + string("].x = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n");
1146
localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
1148
for(int i = 0; i < R1/R2; i++)
1149
for(int j = 0; j < R2; j++)
1150
localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].y;\n");
1151
localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
1152
if(threadsPerBlock >= radix)
1154
for(int i = 0; i < R1; i++)
1155
localString += string("a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n");
1159
int innerIter = radix/threadsPerBlock;
1160
int outerIter = R1/innerIter;
1161
for(int i = 0; i < outerIter; i++)
1162
for(int j = 0; j < innerIter; j++)
1163
localString += string("a[") + num2str(i*innerIter+j) + string("].y = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n");
1165
localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
1167
localString += string("indexOut += tid;\n");
1168
if(dataFormat == clFFT_SplitComplexFormat) {
1169
localString += string("out_real += indexOut;\n");
1170
localString += string("out_imag += indexOut;\n");
1171
for(k = 0; k < R1; k++)
1172
localString += string("out_real[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
1173
for(k = 0; k < R1; k++)
1174
localString += string("out_imag[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
1177
localString += string("out += indexOut;\n");
1178
for(k = 0; k < R1; k++)
1179
localString += string("out[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("];\n");
1185
localString += string("indexOut += mad24(j, ") + num2str(numIter*strideO) + string(", i);\n");
1186
if(dataFormat == clFFT_SplitComplexFormat) {
1187
localString += string("out_real += indexOut;\n");
1188
localString += string("out_imag += indexOut;\n");
1189
for(k = 0; k < R1; k++)
1190
localString += string("out_real[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].x;\n");
1191
for(k = 0; k < R1; k++)
1192
localString += string("out_imag[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].y;\n");
1195
localString += string("out += indexOut;\n");
1196
for(k = 0; k < R1; k++)
1197
localString += string("out[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("];\n");
1201
insertHeader(*kernelString, kernelName, dataFormat);
1202
*kernelString += string("{\n");
1203
if((*kInfo)->lmem_size)
1204
*kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
1205
*kernelString += localString;
1206
*kernelString += string("}\n");
1209
kInfo = &(*kInfo)->next;
1214
void FFT1D(cl_fft_plan *plan, cl_fft_kernel_dir dir)
1216
unsigned int radixArray[10];
1217
unsigned int numRadix;
1221
case cl_fft_kernel_x:
1222
if(plan->n.x > plan->max_localmem_fft_size)
1224
createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
1226
else if(plan->n.x > 1)
1228
getRadixArray(plan->n.x, radixArray, &numRadix, 0);
1229
if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup)
1231
createLocalMemfftKernelString(plan);
1235
getRadixArray(plan->n.x, radixArray, &numRadix, plan->max_radix);
1236
if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup)
1237
createLocalMemfftKernelString(plan);
1239
createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
1244
case cl_fft_kernel_y:
1246
createGlobalFFTKernelString(plan, plan->n.y, plan->n.x, cl_fft_kernel_y, 1);
1249
case cl_fft_kernel_z:
1251
createGlobalFFTKernelString(plan, plan->n.z, plan->n.x*plan->n.y, cl_fft_kernel_z, 1);