/opencl/oclfft

To get this branch, use:
bzr branch http://darksoft.org/webbzr/opencl/oclfft

« back to all changes in this revision

Viewing changes to fft_kernelstring.cpp

  • Committer: Matthias Vogelgesang
  • Date: 2011-01-31 09:18:47 UTC
  • Revision ID: git-v1:418c612a670678194837191e7c580047d8d58c28
Initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
 
 
2
//
 
3
// File:       fft_kernelstring.cpp
 
4
//
 
5
// Version:    <1.0>
 
6
//
 
7
// Disclaimer: IMPORTANT:  This Apple software is supplied to you by Apple Inc. ("Apple")
 
8
//             in consideration of your agreement to the following terms, and your use,
 
9
//             installation, modification or redistribution of this Apple software
 
10
//             constitutes acceptance of these terms.  If you do not agree with these
 
11
//             terms, please do not use, install, modify or redistribute this Apple
 
12
//             software.
 
13
//
 
14
//             In consideration of your agreement to abide by the following terms, and
 
15
//             subject to these terms, Apple grants you a personal, non - exclusive
 
16
//             license, under Apple's copyrights in this original Apple software ( the
 
17
//             "Apple Software" ), to use, reproduce, modify and redistribute the Apple
 
18
//             Software, with or without modifications, in source and / or binary forms;
 
19
//             provided that if you redistribute the Apple Software in its entirety and
 
20
//             without modifications, you must retain this notice and the following text
 
21
//             and disclaimers in all such redistributions of the Apple Software. Neither
 
22
//             the name, trademarks, service marks or logos of Apple Inc. may be used to
 
23
//             endorse or promote products derived from the Apple Software without specific
 
24
//             prior written permission from Apple.  Except as expressly stated in this
 
25
//             notice, no other rights or licenses, express or implied, are granted by
 
26
//             Apple herein, including but not limited to any patent rights that may be
 
27
//             infringed by your derivative works or by other works in which the Apple
 
28
//             Software may be incorporated.
 
29
//
 
30
//             The Apple Software is provided by Apple on an "AS IS" basis.  APPLE MAKES NO
 
31
//             WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED
 
32
//             WARRANTIES OF NON - INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A
 
33
//             PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND OPERATION
 
34
//             ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
 
35
//
 
36
//             IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR
 
37
//             CONSEQUENTIAL DAMAGES ( INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 
38
//             SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 
39
//             INTERRUPTION ) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION
 
40
//             AND / OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED AND WHETHER
 
41
//             UNDER THEORY OF CONTRACT, TORT ( INCLUDING NEGLIGENCE ), STRICT LIABILITY OR
 
42
//             OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
43
//
 
44
// Copyright ( C ) 2008 Apple Inc. All Rights Reserved.
 
45
//
 
46
////////////////////////////////////////////////////////////////////////////////////////////////////
 
47
 
 
48
 
 
49
#include <stdio.h>
 
50
#include <stdlib.h>
 
51
#include <math.h>
 
52
#include <iostream>
 
53
#include <sstream>
 
54
#include <string.h>
 
55
#include <assert.h>
 
56
#include "fft_internal.h"
 
57
#include "clFFT.h"
 
58
 
 
59
using namespace std;
 
60
 
 
61
#define max(A,B) ((A) > (B) ? (A) : (B))
 
62
#define min(A,B) ((A) < (B) ? (A) : (B))
 
63
 
 
64
static string 
 
65
num2str(int num)
 
66
{
 
67
        char temp[200];
 
68
        sprintf(temp, "%d", num);
 
69
        return string(temp);
 
70
}
 
71
 
 
72
// For any n, this function decomposes n into factors for loacal memory tranpose 
 
73
// based fft. Factors (radices) are sorted such that the first one (radixArray[0])
 
74
// is the largest. This base radix determines the number of registers used by each
 
75
// work item and product of remaining radices determine the size of work group needed.
 
76
// To make things concrete with and example, suppose n = 1024. It is decomposed into
 
77
// 1024 = 16 x 16 x 4. Hence kernel uses float2 a[16], for local in-register fft and 
 
78
// needs 16 x 4 = 64 work items per work group. So kernel first performance 64 length
 
79
// 16 ffts (64 work items working in parallel) following by transpose using local 
 
80
// memory followed by again 64 length 16 ffts followed by transpose using local memory
 
81
// followed by 256 length 4 ffts. For the last step since with size of work group is 
 
82
// 64 and each work item can array for 16 values, 64 work items can compute 256 length
 
83
// 4 ffts by each work item computing 4 length 4 ffts. 
 
84
// Similarly for n = 2048 = 8 x 8 x 8 x 4, each work group has 8 x 8 x 4 = 256 work
 
85
// iterms which each computes 256 (in-parallel) length 8 ffts in-register, followed
 
86
// by transpose using local memory, followed by 256 length 8 in-register ffts, followed
 
87
// by transpose using local memory, followed by 256 length 8 in-register ffts, followed
 
88
// by transpose using local memory, followed by 512 length 4 in-register ffts. Again,
 
89
// for the last step, each work item computes two length 4 in-register ffts and thus
 
90
// 256 work items are needed to compute all 512 ffts. 
 
91
// For n = 32 = 8 x 4, 4 work items first compute 4 in-register 
 
92
// lenth 8 ffts, followed by transpose using local memory followed by 8 in-register
 
93
// length 4 ffts, where each work item computes two length 4 ffts thus 4 work items
 
94
// can compute 8 length 4 ffts. However if work group size of say 64 is choosen, 
 
95
// each work group can compute 64/ 4 = 16 size 32 ffts (batched transform). 
 
96
// Users can play with these parameters to figure what gives best performance on
 
97
// their particular device i.e. some device have less register space thus using
 
98
// smaller base radix can avoid spilling ... some has small local memory thus 
 
99
// using smaller work group size may be required etc
 
100
 
 
101
static void 
 
102
getRadixArray(unsigned int n, unsigned int *radixArray, unsigned int *numRadices, unsigned int maxRadix)
 
103
{
 
104
    if(maxRadix > 1)
 
105
    {
 
106
        maxRadix = min(n, maxRadix);
 
107
        unsigned int cnt = 0;
 
108
        while(n > maxRadix)
 
109
        {
 
110
            radixArray[cnt++] = maxRadix;
 
111
            n /= maxRadix;
 
112
        }
 
113
        radixArray[cnt++] = n;
 
114
        *numRadices = cnt;
 
115
        return;
 
116
    }
 
117
 
 
118
        switch(n) 
 
119
        {
 
120
                case 2:
 
121
                        *numRadices = 1;
 
122
                        radixArray[0] = 2;
 
123
                        break;
 
124
                        
 
125
                case 4:
 
126
                        *numRadices = 1;
 
127
                        radixArray[0] = 4;
 
128
                        break;
 
129
                        
 
130
                case 8:
 
131
                        *numRadices = 1;
 
132
                        radixArray[0] = 8;
 
133
                        break;
 
134
                        
 
135
                case 16:
 
136
                        *numRadices = 2;
 
137
                        radixArray[0] = 8; radixArray[1] = 2; 
 
138
                        break;
 
139
                        
 
140
                case 32:
 
141
                        *numRadices = 2;
 
142
                        radixArray[0] = 8; radixArray[1] = 4;
 
143
                        break;
 
144
                        
 
145
                case 64:
 
146
                        *numRadices = 2;
 
147
                        radixArray[0] = 8; radixArray[1] = 8;
 
148
                        break;
 
149
                        
 
150
                case 128:
 
151
                        *numRadices = 3;
 
152
                        radixArray[0] = 8; radixArray[1] = 4; radixArray[2] = 4;
 
153
                        break;
 
154
                        
 
155
                case 256:
 
156
                        *numRadices = 4;
 
157
                        radixArray[0] = 4; radixArray[1] = 4; radixArray[2] = 4; radixArray[3] = 4;
 
158
                        break;
 
159
                        
 
160
                case 512:
 
161
                        *numRadices = 3;
 
162
                        radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8;
 
163
                        break;                  
 
164
                        
 
165
                case 1024:
 
166
                        *numRadices = 3;
 
167
                        radixArray[0] = 16; radixArray[1] = 16; radixArray[2] = 4;
 
168
                        break;  
 
169
                case 2048:
 
170
                        *numRadices = 4;
 
171
                        radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; radixArray[3] = 4;
 
172
                        break;
 
173
                default:
 
174
                        *numRadices = 0;
 
175
                        return;
 
176
        }
 
177
}
 
178
 
 
179
static void
 
180
insertHeader(string &kernelString, string &kernelName, clFFT_DataFormat dataFormat)
 
181
{
 
182
        if(dataFormat == clFFT_SplitComplexFormat) 
 
183
                kernelString += string("__kernel void ") + kernelName + string("(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n");
 
184
        else 
 
185
                kernelString += string("__kernel void ") + kernelName + string("(__global float2 *in, __global float2 *out, int dir, int S)\n");
 
186
}
 
187
 
 
188
static void 
 
189
insertVariables(string &kStream, int maxRadix)
 
190
{
 
191
        kStream += string("    int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n");
 
192
    kStream += string("    int s, ii, jj, offset;\n");
 
193
        kStream += string("    float2 w;\n");
 
194
        kStream += string("    float ang, angf, ang1;\n");
 
195
    kStream += string("    __local float *lMemStore, *lMemLoad;\n");
 
196
    kStream += string("    float2 a[") +  num2str(maxRadix) + string("];\n");
 
197
    kStream += string("    int lId = get_local_id( 0 );\n");
 
198
    kStream += string("    int groupId = get_group_id( 0 );\n");
 
199
}
 
200
 
 
201
static void
 
202
formattedLoad(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
 
203
{
 
204
        if(dataFormat == clFFT_InterleavedComplexFormat)
 
205
                kernelString += string("        a[") + num2str(aIndex) + string("] = in[") + num2str(gIndex) + string("];\n");
 
206
        else
 
207
        {
 
208
                kernelString += string("        a[") + num2str(aIndex) + string("].x = in_real[") + num2str(gIndex) + string("];\n");
 
209
                kernelString += string("        a[") + num2str(aIndex) + string("].y = in_imag[") + num2str(gIndex) + string("];\n");
 
210
        }
 
211
}
 
212
 
 
213
static void
 
214
formattedStore(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
 
215
{
 
216
        if(dataFormat == clFFT_InterleavedComplexFormat)
 
217
                kernelString += string("        out[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("];\n");
 
218
        else
 
219
        {
 
220
                kernelString += string("        out_real[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].x;\n");
 
221
                kernelString += string("        out_imag[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].y;\n");
 
222
        }
 
223
}
 
224
 
 
225
static int
 
226
insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXForm, int numXFormsPerWG, int R0, int mem_coalesce_width, clFFT_DataFormat dataFormat)
 
227
{
 
228
        int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm);
 
229
        int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
 
230
        int i, j;
 
231
        int lMemSize = 0;
 
232
        
 
233
        if(numXFormsPerWG > 1)
 
234
            kernelString += string("        s = S & ") + num2str(numXFormsPerWG - 1) + string(";\n");
 
235
        
 
236
    if(numWorkItemsPerXForm >= mem_coalesce_width)
 
237
    {                   
 
238
                if(numXFormsPerWG > 1)
 
239
                {
 
240
            kernelString += string("    ii = lId & ") + num2str(numWorkItemsPerXForm-1) + string(";\n");
 
241
            kernelString += string("    jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
 
242
            kernelString += string("    if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
 
243
                        kernelString += string("        offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n");
 
244
                        if(dataFormat == clFFT_InterleavedComplexFormat)
 
245
                        {
 
246
                            kernelString += string("        in += offset;\n");
 
247
                            kernelString += string("        out += offset;\n");
 
248
                        }
 
249
                        else
 
250
                        {
 
251
                            kernelString += string("        in_real += offset;\n");
 
252
                                kernelString += string("        in_imag += offset;\n");
 
253
                            kernelString += string("        out_real += offset;\n");
 
254
                                kernelString += string("        out_imag += offset;\n");
 
255
                        }
 
256
                        for(i = 0; i < R0; i++)
 
257
                                formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat);
 
258
                        kernelString += string("    }\n");
 
259
                }
 
260
                else
 
261
                {
 
262
                        kernelString += string("    ii = lId;\n");
 
263
                        kernelString += string("    jj = 0;\n");
 
264
                        kernelString += string("    offset =  mad24(groupId, ") + num2str(N) + string(", ii);\n");
 
265
                        if(dataFormat == clFFT_InterleavedComplexFormat)
 
266
                        {
 
267
                            kernelString += string("        in += offset;\n");
 
268
                            kernelString += string("        out += offset;\n");
 
269
                        }
 
270
                        else
 
271
                        {
 
272
                            kernelString += string("        in_real += offset;\n");
 
273
                                kernelString += string("        in_imag += offset;\n");
 
274
                            kernelString += string("        out_real += offset;\n");
 
275
                                kernelString += string("        out_imag += offset;\n");
 
276
                        }
 
277
                        for(i = 0; i < R0; i++)
 
278
                                formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat);
 
279
                }
 
280
    }
 
281
    else if( N >= mem_coalesce_width )
 
282
    {
 
283
        int numInnerIter = N / mem_coalesce_width;
 
284
        int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
 
285
                
 
286
        kernelString += string("    ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
 
287
        kernelString += string("    jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
 
288
        kernelString += string("    lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
 
289
        kernelString += string("    offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n");
 
290
        kernelString += string("    offset = mad24( offset, ") + num2str(N) + string(", ii );\n");
 
291
                if(dataFormat == clFFT_InterleavedComplexFormat)
 
292
                {
 
293
                        kernelString += string("        in += offset;\n");
 
294
                        kernelString += string("        out += offset;\n");
 
295
                }
 
296
                else
 
297
                {
 
298
                        kernelString += string("        in_real += offset;\n");
 
299
                        kernelString += string("        in_imag += offset;\n");
 
300
                        kernelString += string("        out_real += offset;\n");
 
301
                        kernelString += string("        out_imag += offset;\n");
 
302
                }
 
303
        
 
304
                kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
 
305
        for(i = 0; i < numOuterIter; i++ )
 
306
        {
 
307
            kernelString += string("    if( jj < s ) {\n");
 
308
                        for(j = 0; j < numInnerIter; j++ ) 
 
309
                                formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat);
 
310
                        kernelString += string("    }\n"); 
 
311
                        if(i != numOuterIter - 1)
 
312
                            kernelString += string("    jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");                      
 
313
        }
 
314
                kernelString += string("}\n ");
 
315
                kernelString += string("else {\n");
 
316
        for(i = 0; i < numOuterIter; i++ )
 
317
        {
 
318
                        for(j = 0; j < numInnerIter; j++ ) 
 
319
                                formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat);                     
 
320
        }               
 
321
                kernelString += string("}\n");
 
322
        
 
323
                kernelString += string("    ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
 
324
                kernelString += string("    jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
 
325
        kernelString += string("    lMemLoad  = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n");  
 
326
                
 
327
        for( i = 0; i < numOuterIter; i++ )
 
328
                {
 
329
                        for( j = 0; j < numInnerIter; j++ )
 
330
                        {       
 
331
                                kernelString += string("    lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") + 
 
332
                                                num2str(i * numInnerIter + j) + string("].x;\n");
 
333
                        }
 
334
                }       
 
335
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");
 
336
        
 
337
        for( i = 0; i < R0; i++ )
 
338
                        kernelString += string("    a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");            
 
339
                kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");  
 
340
 
 
341
            for( i = 0; i < numOuterIter; i++ )
 
342
                {
 
343
                        for( j = 0; j < numInnerIter; j++ )
 
344
                        {       
 
345
                                kernelString += string("    lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") + 
 
346
                                                                num2str(i * numInnerIter + j) + string("].y;\n");
 
347
                        }
 
348
            }   
 
349
                kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");
 
350
                                                                                                                                                                                   
 
351
                for( i = 0; i < R0; i++ )
 
352
                        kernelString += string("    a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");            
 
353
                kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");  
 
354
                
 
355
                lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
 
356
    }  
 
357
    else
 
358
    {
 
359
        kernelString += string("    offset = mad24( groupId,  ") + num2str(N * numXFormsPerWG) + string(", lId );\n");
 
360
                if(dataFormat == clFFT_InterleavedComplexFormat)
 
361
                {
 
362
                        kernelString += string("        in += offset;\n");
 
363
                        kernelString += string("        out += offset;\n");
 
364
                }
 
365
                else
 
366
                {
 
367
                        kernelString += string("        in_real += offset;\n");
 
368
                        kernelString += string("        in_imag += offset;\n");
 
369
                        kernelString += string("        out_real += offset;\n");
 
370
                        kernelString += string("        out_imag += offset;\n");
 
371
                }
 
372
        
 
373
        kernelString += string("    ii = lId & ") + num2str(N-1) + string(";\n");
 
374
        kernelString += string("    jj = lId >> ") + num2str((int)log2(N)) + string(";\n");
 
375
        kernelString += string("    lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
 
376
        
 
377
                kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
 
378
        for( i = 0; i < R0; i++ )
 
379
        {
 
380
            kernelString += string("    if(jj < s )\n");
 
381
                        formattedLoad(kernelString, i, i*groupSize, dataFormat);
 
382
                        if(i != R0 - 1)
 
383
                            kernelString += string("    jj += ") + num2str(groupSize / N) + string(";\n");
 
384
        }
 
385
                kernelString += string("}\n");
 
386
                kernelString += string("else {\n");
 
387
        for( i = 0; i < R0; i++ )
 
388
        {
 
389
                        formattedLoad(kernelString, i, i*groupSize, dataFormat);
 
390
        }               
 
391
                kernelString += string("}\n");
 
392
        
 
393
                if(numWorkItemsPerXForm > 1)
 
394
                {
 
395
            kernelString += string("    ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
 
396
            kernelString += string("    jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
 
397
            kernelString += string("    lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); 
 
398
                }
 
399
                else 
 
400
                {
 
401
            kernelString += string("    ii = 0;\n");
 
402
            kernelString += string("    jj = lId;\n");
 
403
            kernelString += string("    lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n");                   
 
404
                }
 
405
 
 
406
                
 
407
        for( i = 0; i < R0; i++ )
 
408
            kernelString += string("    lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].x;\n"); 
 
409
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n"); 
 
410
        
 
411
        for( i = 0; i < R0; i++ )
 
412
            kernelString += string("    a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
 
413
                kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");
 
414
        
 
415
        for( i = 0; i < R0; i++ )
 
416
            kernelString += string("    lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].y;\n"); 
 
417
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n"); 
 
418
        
 
419
        for( i = 0; i < R0; i++ )
 
420
            kernelString += string("    a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
 
421
                kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");
 
422
                
 
423
                lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
 
424
    }
 
425
        
 
426
        return lMemSize;
 
427
}
 
428
 
 
429
static int
 
430
insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr, int numWorkItemsPerXForm, int numXFormsPerWG, int mem_coalesce_width, clFFT_DataFormat dataFormat)
 
431
{
 
432
        int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
 
433
        int i, j, k, ind;
 
434
        int lMemSize = 0;
 
435
        int numIter = maxRadix / Nr;
 
436
        string indent = string("");
 
437
        
 
438
    if( numWorkItemsPerXForm >= mem_coalesce_width )
 
439
    {   
 
440
                if(numXFormsPerWG > 1)
 
441
                {
 
442
            kernelString += string("    if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
 
443
                        indent = string("    ");
 
444
                }       
 
445
                for(i = 0; i < maxRadix; i++) 
 
446
                {
 
447
                        j = i % numIter;
 
448
                        k = i / numIter;
 
449
                        ind = j * Nr + k;
 
450
                        formattedStore(kernelString, ind, i*numWorkItemsPerXForm, dataFormat);
 
451
                }
 
452
                if(numXFormsPerWG > 1)
 
453
                    kernelString += string("    }\n");
 
454
    }
 
455
    else if( N >= mem_coalesce_width )
 
456
    {
 
457
        int numInnerIter = N / mem_coalesce_width;
 
458
        int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
 
459
                
 
460
        kernelString += string("    lMemLoad  = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");  
 
461
        kernelString += string("    ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
 
462
        kernelString += string("    jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
 
463
        kernelString += string("    lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
 
464
                
 
465
        for( i = 0; i < maxRadix; i++ )
 
466
                {
 
467
                        j = i % numIter;
 
468
                        k = i / numIter;
 
469
                        ind = j * Nr + k;
 
470
            kernelString += string("    lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");            
 
471
                }       
 
472
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");         
 
473
                
 
474
        for( i = 0; i < numOuterIter; i++ )
 
475
                        for( j = 0; j < numInnerIter; j++ )
 
476
                                kernelString += string("    a[") + num2str(i*numInnerIter + j) + string("].x = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n");
 
477
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");
 
478
                
 
479
        for( i = 0; i < maxRadix; i++ )
 
480
                {
 
481
                        j = i % numIter;
 
482
                        k = i / numIter;
 
483
                        ind = j * Nr + k;
 
484
            kernelString += string("    lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");            
 
485
                }       
 
486
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");         
 
487
                
 
488
        for( i = 0; i < numOuterIter; i++ )
 
489
                        for( j = 0; j < numInnerIter; j++ )
 
490
                                kernelString += string("    a[") + num2str(i*numInnerIter + j) + string("].y = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n");
 
491
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n"); 
 
492
                
 
493
                kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
 
494
                for(i = 0; i < numOuterIter; i++ )
 
495
        {
 
496
            kernelString += string("    if( jj < s ) {\n");
 
497
                        for(j = 0; j < numInnerIter; j++ ) 
 
498
                                formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat); 
 
499
                        kernelString += string("    }\n"); 
 
500
                        if(i != numOuterIter - 1)
 
501
                            kernelString += string("    jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");                      
 
502
        }
 
503
                kernelString += string("}\n");
 
504
                kernelString += string("else {\n");
 
505
                for(i = 0; i < numOuterIter; i++ )
 
506
        {
 
507
                        for(j = 0; j < numInnerIter; j++ ) 
 
508
                                formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat); 
 
509
        }               
 
510
                kernelString += string("}\n");
 
511
                
 
512
                lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
 
513
        }       
 
514
    else
 
515
    {   
 
516
        kernelString += string("    lMemLoad  = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");  
 
517
        
 
518
                kernelString += string("    ii = lId & ") + num2str(N - 1) + string(";\n");
 
519
        kernelString += string("    jj = lId >> ") + num2str((int) log2(N)) + string(";\n");
 
520
        kernelString += string("    lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
 
521
        
 
522
        for( i = 0; i < maxRadix; i++ )
 
523
                {
 
524
                        j = i % numIter;
 
525
                        k = i / numIter;
 
526
                        ind = j * Nr + k;
 
527
            kernelString += string("    lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
 
528
                }       
 
529
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");
 
530
        
 
531
        for( i = 0; i < maxRadix; i++ )
 
532
            kernelString += string("    a[") + num2str(i) + string("].x = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n"); 
 
533
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n"); 
 
534
        
 
535
        for( i = 0; i < maxRadix; i++ )
 
536
                {
 
537
                        j = i % numIter;
 
538
                        k = i / numIter;
 
539
                        ind = j * Nr + k;
 
540
            kernelString += string("    lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
 
541
                }       
 
542
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n");
 
543
        
 
544
        for( i = 0; i < maxRadix; i++ )
 
545
            kernelString += string("    a[") + num2str(i) + string("].y = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n"); 
 
546
        kernelString += string("    barrier( CLK_LOCAL_MEM_FENCE );\n"); 
 
547
        
 
548
                kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
 
549
                for( i = 0; i < maxRadix; i++ )
 
550
        {
 
551
            kernelString += string("    if(jj < s ) {\n");
 
552
                        formattedStore(kernelString, i, i*groupSize, dataFormat);
 
553
                        kernelString += string("    }\n");
 
554
                        if( i != maxRadix - 1)
 
555
                                kernelString += string("    jj +=") + num2str(groupSize / N) + string(";\n");
 
556
        } 
 
557
                kernelString += string("}\n");
 
558
                kernelString += string("else {\n");
 
559
                for( i = 0; i < maxRadix; i++ )
 
560
        {
 
561
                        formattedStore(kernelString, i, i*groupSize, dataFormat);
 
562
        }               
 
563
                kernelString += string("}\n");
 
564
                
 
565
                lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
 
566
    }
 
567
        
 
568
        return lMemSize;
 
569
}
 
570
 
 
571
static void 
 
572
insertfftKernel(string &kernelString, int Nr, int numIter)
 
573
{
 
574
        int i;
 
575
        for(i = 0; i < numIter; i++) 
 
576
        {
 
577
                kernelString += string("    fftKernel") + num2str(Nr) + string("(a+") + num2str(i*Nr) + string(", dir);\n");
 
578
        }
 
579
}
 
580
 
 
581
static void
 
582
insertTwiddleKernel(string &kernelString, int Nr, int numIter, int Nprev, int len, int numWorkItemsPerXForm)
 
583
{
 
584
        int z, k;
 
585
        int logNPrev = log2(Nprev);
 
586
        
 
587
        for(z = 0; z < numIter; z++) 
 
588
        {
 
589
                if(z == 0)
 
590
                {
 
591
                        if(Nprev > 1)
 
592
                            kernelString += string("    angf = (float) (ii >> ") + num2str(logNPrev) + string(");\n");
 
593
                        else
 
594
                                kernelString += string("    angf = (float) ii;\n");
 
595
                }       
 
596
                else
 
597
                {
 
598
                        if(Nprev > 1)
 
599
                            kernelString += string("    angf = (float) ((") + num2str(z*numWorkItemsPerXForm) + string(" + ii) >>") + num2str(logNPrev) + string(");\n"); 
 
600
                        else
 
601
                                kernelString += string("    angf = (float) (") + num2str(z*numWorkItemsPerXForm) + string(" + ii);\n");
 
602
                }       
 
603
        
 
604
                for(k = 1; k < Nr; k++) {
 
605
                        int ind = z*Nr + k;
 
606
                        //float fac =  (float) (2.0 * M_PI * (double) k / (double) len);
 
607
                        kernelString += string("    ang = dir * ( 2.0f * M_PI * ") + num2str(k) + string(".0f / ") + num2str(len) + string(".0f )") + string(" * angf;\n");
 
608
                        kernelString += string("    w = (float2)(native_cos(ang), native_sin(ang));\n");
 
609
                        kernelString += string("    a[") + num2str(ind) + string("] = complexMul(a[") + num2str(ind) + string("], w);\n");
 
610
                }
 
611
        }
 
612
}
 
613
 
 
614
static int
 
615
getPadding(int numWorkItemsPerXForm, int Nprev, int numWorkItemsReq, int numXFormsPerWG, int Nr, int numBanks, int *offset, int *midPad)
 
616
{
 
617
        if((numWorkItemsPerXForm <= Nprev) || (Nprev >= numBanks))
 
618
                *offset = 0;
 
619
        else {
 
620
                int numRowsReq = ((numWorkItemsPerXForm < numBanks) ? numWorkItemsPerXForm : numBanks) / Nprev;
 
621
                int numColsReq = 1;
 
622
                if(numRowsReq > Nr)
 
623
                        numColsReq = numRowsReq / Nr;
 
624
                numColsReq = Nprev * numColsReq;
 
625
                *offset = numColsReq;
 
626
        }
 
627
        
 
628
        if(numWorkItemsPerXForm >= numBanks || numXFormsPerWG == 1)
 
629
                *midPad = 0;
 
630
        else {
 
631
                int bankNum = ( (numWorkItemsReq + *offset) * Nr ) & (numBanks - 1);
 
632
                if( bankNum >= numWorkItemsPerXForm )
 
633
                        *midPad = 0;
 
634
                else
 
635
                        *midPad = numWorkItemsPerXForm - bankNum;
 
636
        }
 
637
        
 
638
        int lMemSize = ( numWorkItemsReq + *offset) * Nr * numXFormsPerWG + *midPad * (numXFormsPerWG - 1);
 
639
        return lMemSize;
 
640
}
 
641
 
 
642
 
 
643
static void 
 
644
insertLocalStores(string &kernelString, int numIter, int Nr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
 
645
{
 
646
        int z, k;
 
647
 
 
648
        for(z = 0; z < numIter; z++) {
 
649
                for(k = 0; k < Nr; k++) {
 
650
                        int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm;
 
651
                        kernelString += string("    lMemStore[") + num2str(index) + string("] = a[") + num2str(z*Nr + k) + string("].") + comp + string(";\n");
 
652
                }
 
653
        }
 
654
        kernelString += string("    barrier(CLK_LOCAL_MEM_FENCE);\n");
 
655
}
 
656
 
 
657
static void 
 
658
insertLocalLoads(string &kernelString, int n, int Nr, int Nrn, int Nprev, int Ncurr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
 
659
{
 
660
        int numWorkItemsReqN = n / Nrn;                                                                         
 
661
        int interBlockHNum = max( Nprev / numWorkItemsPerXForm, 1 );                    
 
662
        int interBlockHStride = numWorkItemsPerXForm;                                                   
 
663
        int vertWidth = max(numWorkItemsPerXForm / Nprev, 1);                                   
 
664
        vertWidth = min( vertWidth, Nr);                                                                        
 
665
        int vertNum = Nr / vertWidth;                                                                           
 
666
        int vertStride = ( n / Nr + offset ) * vertWidth;                                       
 
667
        int iter = max( numWorkItemsReqN / numWorkItemsPerXForm, 1);
 
668
        int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1;
 
669
        intraBlockHStride *= Nprev;
 
670
        
 
671
        int stride = numWorkItemsReq / Nrn;                                                                     
 
672
        int i;
 
673
        for(i = 0; i < iter; i++) {
 
674
                int ii = i / (interBlockHNum * vertNum);
 
675
                int zz = i % (interBlockHNum * vertNum);
 
676
                int jj = zz % interBlockHNum;
 
677
                int kk = zz / interBlockHNum;
 
678
                int z;
 
679
                for(z = 0; z < Nrn; z++) {
 
680
                        int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride;
 
681
                        kernelString += string("    a[") + num2str(i*Nrn + z) + string("].") + comp + string(" = lMemLoad[") + num2str(st) + string("];\n");
 
682
                }
 
683
        }
 
684
        kernelString += string("    barrier(CLK_LOCAL_MEM_FENCE);\n");
 
685
}
 
686
 
 
687
static void
 
688
insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numWorkItemsReq, int numWorkItemsPerXForm, int numXFormsPerWG, int offset, int midPad)
 
689
{       
 
690
        int Ncurr = Nprev * Nr;
 
691
        int logNcurr = log2(Ncurr);
 
692
        int logNprev = log2(Nprev);
 
693
        int incr = (numWorkItemsReq + offset) * Nr + midPad;
 
694
        
 
695
        if(Ncurr < numWorkItemsPerXForm) 
 
696
        {
 
697
                if(Nprev == 1)
 
698
                    kernelString += string("    j = ii & ") + num2str(Ncurr - 1) + string(";\n");
 
699
                else
 
700
                        kernelString += string("    j = (ii & ") + num2str(Ncurr - 1) + string(") >> ") + num2str(logNprev) + string(";\n");
 
701
                
 
702
                if(Nprev == 1) 
 
703
                        kernelString += string("    i = ii >> ") + num2str(logNcurr) + string(";\n");
 
704
                else 
 
705
                        kernelString += string("    i = mad24(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n"); 
 
706
        }       
 
707
        else 
 
708
        {
 
709
                if(Nprev == 1)
 
710
                    kernelString += string("    j = ii;\n");
 
711
                else
 
712
                        kernelString += string("    j = ii >> ") + num2str(logNprev) + string(";\n");
 
713
                if(Nprev == 1) 
 
714
                        kernelString += string("    i = 0;\n"); 
 
715
                else 
 
716
                        kernelString += string("    i = ii & ") + num2str(Nprev - 1) + string(";\n");
 
717
        }
 
718
 
 
719
    if(numXFormsPerWG > 1)
 
720
        kernelString += string("    i = mad24(jj, ") + num2str(incr) + string(", i);\n");               
 
721
 
 
722
    kernelString += string("    lMemLoad = sMem + mad24(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n"); 
 
723
}
 
724
 
 
725
static void
 
726
insertLocalStoreIndexArithmatic(string &kernelString, int numWorkItemsReq, int numXFormsPerWG, int Nr, int offset, int midPad)
 
727
{
 
728
        if(numXFormsPerWG == 1) {
 
729
                kernelString += string("    lMemStore = sMem + ii;\n");         
 
730
        }
 
731
        else {
 
732
                kernelString += string("    lMemStore = sMem + mad24(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n");     
 
733
        }
 
734
}
 
735
 
 
736
 
 
737
static void
 
738
createLocalMemfftKernelString(cl_fft_plan *plan)
 
739
{
 
740
        unsigned int radixArray[10];
 
741
        unsigned int numRadix;
 
742
         
 
743
        unsigned int n = plan->n.x;
 
744
        
 
745
        assert(n <= plan->max_work_item_per_workgroup * plan->max_radix && "signal lenght too big for local mem fft\n");
 
746
        
 
747
        getRadixArray(n, radixArray, &numRadix, 0);
 
748
        assert(numRadix > 0 && "no radix array supplied\n");
 
749
        
 
750
        if(n/radixArray[0] > plan->max_work_item_per_workgroup)
 
751
            getRadixArray(n, radixArray, &numRadix, plan->max_radix);
 
752
 
 
753
        assert(radixArray[0] <= plan->max_radix && "max radix choosen is greater than allowed\n");
 
754
        assert(n/radixArray[0] <= plan->max_work_item_per_workgroup && "required work items per xform greater than maximum work items allowed per work group for local mem fft\n");
 
755
        
 
756
        unsigned int tmpLen = 1;
 
757
        unsigned int i;
 
758
        for(i = 0; i < numRadix; i++)
 
759
        {       
 
760
                assert( radixArray[i] && !( (radixArray[i] - 1) & radixArray[i] ) );
 
761
            tmpLen *= radixArray[i];
 
762
        }
 
763
        assert(tmpLen == n && "product of radices choosen doesnt match the length of signal\n");
 
764
        
 
765
        int offset, midPad;
 
766
        string localString(""), kernelName("");
 
767
        
 
768
        clFFT_DataFormat dataFormat = plan->format;
 
769
        string *kernelString = plan->kernel_string;
 
770
        
 
771
        
 
772
        cl_fft_kernel_info **kInfo = &plan->kernel_info;
 
773
        int kCount = 0;
 
774
        
 
775
        while(*kInfo)
 
776
        {
 
777
                kInfo = &(*kInfo)->next;
 
778
                kCount++;
 
779
        }
 
780
        
 
781
        kernelName = string("fft") + num2str(kCount);
 
782
        
 
783
        *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info));
 
784
        (*kInfo)->kernel = 0;
 
785
        (*kInfo)->lmem_size = 0;
 
786
        (*kInfo)->num_workgroups = 0;
 
787
        (*kInfo)->num_workitems_per_workgroup = 0;
 
788
        (*kInfo)->dir = cl_fft_kernel_x;
 
789
        (*kInfo)->in_place_possible = 1;
 
790
        (*kInfo)->next = NULL;
 
791
        (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1));
 
792
        strcpy((*kInfo)->kernel_name, kernelName.c_str());
 
793
        
 
794
        unsigned int numWorkItemsPerXForm = n / radixArray[0];
 
795
        unsigned int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm; 
 
796
        assert(numWorkItemsPerWG <= plan->max_work_item_per_workgroup);
 
797
        int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm;
 
798
        (*kInfo)->num_workgroups = 1;
 
799
    (*kInfo)->num_xforms_per_workgroup = numXFormsPerWG;
 
800
        (*kInfo)->num_workitems_per_workgroup = numWorkItemsPerWG;
 
801
        
 
802
        unsigned int *N = radixArray;
 
803
        unsigned int maxRadix = N[0];
 
804
        unsigned int lMemSize = 0;
 
805
                
 
806
        insertVariables(localString, maxRadix);
 
807
        
 
808
        lMemSize = insertGlobalLoadsAndTranspose(localString, n, numWorkItemsPerXForm, numXFormsPerWG, maxRadix, plan->min_mem_coalesce_width, dataFormat);
 
809
        (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
 
810
        
 
811
        string xcomp = string("x");
 
812
        string ycomp = string("y");
 
813
        
 
814
        unsigned int Nprev = 1;
 
815
        unsigned int len = n;
 
816
        unsigned int r;
 
817
        for(r = 0; r < numRadix; r++) 
 
818
        {
 
819
                int numIter = N[0] / N[r];
 
820
                int numWorkItemsReq = n / N[r];
 
821
                int Ncurr = Nprev * N[r];
 
822
                insertfftKernel(localString, N[r], numIter);
 
823
                
 
824
                if(r < (numRadix - 1)) {
 
825
                        insertTwiddleKernel(localString, N[r], numIter, Nprev, len, numWorkItemsPerXForm);
 
826
                        lMemSize = getPadding(numWorkItemsPerXForm, Nprev, numWorkItemsReq, numXFormsPerWG, N[r], plan->num_local_mem_banks, &offset, &midPad);
 
827
                        (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
 
828
                        insertLocalStoreIndexArithmatic(localString, numWorkItemsReq, numXFormsPerWG, N[r], offset, midPad);
 
829
                        insertLocalLoadIndexArithmatic(localString, Nprev, N[r], numWorkItemsReq, numWorkItemsPerXForm, numXFormsPerWG, offset, midPad);
 
830
                        insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
 
831
                        insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
 
832
                        insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
 
833
                        insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
 
834
                        Nprev = Ncurr;
 
835
                        len = len / N[r];
 
836
                }
 
837
        }
 
838
        
 
839
        lMemSize = insertGlobalStoresAndTranspose(localString, n, maxRadix, N[numRadix - 1], numWorkItemsPerXForm, numXFormsPerWG, plan->min_mem_coalesce_width, dataFormat);
 
840
        (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
 
841
        
 
842
        insertHeader(*kernelString, kernelName, dataFormat);
 
843
        *kernelString += string("{\n");
 
844
        if((*kInfo)->lmem_size)
 
845
        *kernelString += string("    __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
 
846
        *kernelString += localString;
 
847
        *kernelString += string("}\n");
 
848
}
 
849
 
 
850
// For n larger than what can be computed using local memory fft, global transposes
 
851
// multiple kernel launces is needed. For these sizes, n can be decomposed using
 
852
// much larger base radices i.e. say n = 262144 = 128 x 64 x 32. Thus three kernel
 
853
// launches will be needed, first computing 64 x 32, length 128 ffts, second computing
 
854
// 128 x 32 length 64 ffts, and finally a kernel computing 128 x 64 length 32 ffts. 
 
855
// Each of these base radices can futher be divided into factors so that each of these 
 
856
// base ffts can be computed within one kernel launch using in-register ffts and local 
 
857
// memory transposes i.e for the first kernel above which computes 64 x 32 ffts on length 
 
858
// 128, 128 can be decomposed into 128 = 16 x 8 i.e. 8 work items can compute 8 length 
 
859
// 16 ffts followed by transpose using local memory followed by each of these eight 
 
860
// work items computing 2 length 8 ffts thus computing 16 length 8 ffts in total. This 
 
861
// means only 8 work items are needed for computing one length 128 fft. If we choose
 
862
// work group size of say 64, we can compute 64/8 = 8 length 128 ffts within one
 
863
// work group. Since we need to compute 64 x 32 length 128 ffts in first kernel, this 
 
864
// means we need to launch 64 x 32 / 8 = 256 work groups with 64 work items in each 
 
865
// work group where each work group is computing 8 length 128 ffts where each length
 
866
// 128 fft is computed by 8 work items. Same logic can be applied to other two kernels
 
867
// in this example. Users can play with difference base radices and difference 
 
868
// decompositions of base radices to generates different kernels and see which gives
 
869
// best performance. Following function is just fixed to use 128 as base radix
 
870
 
 
871
void
 
872
getGlobalRadixInfo(int n, int *radix, int *R1, int *R2, int *numRadices)
 
873
{
 
874
        int baseRadix = min(n, 128);
 
875
        
 
876
        int numR = 0;
 
877
        int N = n;
 
878
        while(N > baseRadix) 
 
879
        {
 
880
                N /= baseRadix;
 
881
                numR++;
 
882
        }
 
883
        
 
884
        for(int i = 0; i < numR; i++)
 
885
                radix[i] = baseRadix;
 
886
        
 
887
        radix[numR] = N;
 
888
        numR++;
 
889
        *numRadices = numR;
 
890
                
 
891
        for(int i = 0; i < numR; i++)
 
892
        {
 
893
                int B = radix[i];
 
894
                if(B <= 8)
 
895
                {
 
896
                        R1[i] = B;
 
897
                        R2[i] = 1;
 
898
                        continue;
 
899
                }
 
900
                
 
901
                int r1 = 2; 
 
902
                int r2 = B / r1;
 
903
            while(r2 > r1)
 
904
            {
 
905
                   r1 *=2;
 
906
                   r2 = B / r1;
 
907
            }
 
908
                R1[i] = r1;
 
909
                R2[i] = r2;
 
910
        }       
 
911
}
 
912
 
 
913
static void
 
914
createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir dir, int vertBS)
 
915
{               
 
916
        int i, j, k, t;
 
917
        int radixArr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
 
918
    int R1Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
 
919
    int R2Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
 
920
        int radix, R1, R2;
 
921
        int numRadices;
 
922
        
 
923
        int maxThreadsPerBlock = plan->max_work_item_per_workgroup;
 
924
        int maxArrayLen = plan->max_radix;
 
925
        int batchSize = plan->min_mem_coalesce_width;   
 
926
        clFFT_DataFormat dataFormat = plan->format;
 
927
        int vertical = (dir == cl_fft_kernel_x) ? 0 : 1;        
 
928
        
 
929
        getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices);
 
930
                
 
931
        int numPasses = numRadices;
 
932
        
 
933
        string localString(""), kernelName("");
 
934
        string *kernelString = plan->kernel_string;
 
935
        cl_fft_kernel_info **kInfo = &plan->kernel_info; 
 
936
        int kCount = 0;
 
937
        
 
938
        while(*kInfo)
 
939
        {
 
940
                kInfo = &(*kInfo)->next;
 
941
                kCount++;
 
942
        }
 
943
        
 
944
        int N = n;
 
945
        int m = (int)log2(n);
 
946
        int Rinit = vertical ? BS : 1;
 
947
        batchSize = vertical ? min(BS, batchSize) : batchSize;
 
948
        int passNum;
 
949
        
 
950
        for(passNum = 0; passNum < numPasses; passNum++) 
 
951
        {
 
952
                
 
953
                localString.clear();
 
954
                kernelName.clear();
 
955
                
 
956
                radix = radixArr[passNum];
 
957
                R1 = R1Arr[passNum];
 
958
                R2 = R2Arr[passNum];
 
959
                                
 
960
                int strideI = Rinit;
 
961
                for(i = 0; i < numPasses; i++)
 
962
                        if(i != passNum)
 
963
                                strideI *= radixArr[i];
 
964
                
 
965
                int strideO = Rinit;
 
966
                for(i = 0; i < passNum; i++)
 
967
                        strideO *= radixArr[i];
 
968
                
 
969
                int threadsPerXForm = R2;
 
970
                batchSize = R2 == 1 ? plan->max_work_item_per_workgroup : batchSize;
 
971
                batchSize = min(batchSize, strideI);
 
972
                int threadsPerBlock = batchSize * threadsPerXForm;
 
973
                threadsPerBlock = min(threadsPerBlock, maxThreadsPerBlock);
 
974
                batchSize = threadsPerBlock / threadsPerXForm;
 
975
                assert(R2 <= R1);
 
976
                assert(R1*R2 == radix);
 
977
                assert(R1 <= maxArrayLen);
 
978
                assert(threadsPerBlock <= maxThreadsPerBlock);
 
979
                
 
980
                int numIter = R1 / R2;
 
981
                int gInInc = threadsPerBlock / batchSize;
 
982
                
 
983
                
 
984
                int lgStrideO = log2(strideO);
 
985
                int numBlocksPerXForm = strideI / batchSize;
 
986
                int numBlocks = numBlocksPerXForm;
 
987
                if(!vertical)
 
988
                        numBlocks *= BS;
 
989
                else
 
990
                        numBlocks *= vertBS;
 
991
                
 
992
                kernelName = string("fft") + num2str(kCount);
 
993
                *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info));
 
994
                (*kInfo)->kernel = 0;
 
995
                if(R2 == 1)
 
996
                        (*kInfo)->lmem_size = 0;
 
997
                else
 
998
                {
 
999
                    if(strideO == 1)
 
1000
                        (*kInfo)->lmem_size = (radix + 1)*batchSize;
 
1001
                    else
 
1002
                            (*kInfo)->lmem_size = threadsPerBlock*R1;
 
1003
                }
 
1004
                (*kInfo)->num_workgroups = numBlocks;
 
1005
        (*kInfo)->num_xforms_per_workgroup = 1;
 
1006
                (*kInfo)->num_workitems_per_workgroup = threadsPerBlock;
 
1007
                (*kInfo)->dir = dir;
 
1008
                if( (passNum == (numPasses - 1)) && (numPasses & 1) )
 
1009
                    (*kInfo)->in_place_possible = 1;
 
1010
                else
 
1011
                        (*kInfo)->in_place_possible = 0;
 
1012
                (*kInfo)->next = NULL;
 
1013
                (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1));
 
1014
                strcpy((*kInfo)->kernel_name, kernelName.c_str());
 
1015
                
 
1016
                insertVariables(localString, R1);
 
1017
                                                
 
1018
                if(vertical) 
 
1019
                {
 
1020
                        localString += string("xNum = groupId >> ") + num2str((int)log2(numBlocksPerXForm)) + string(";\n");
 
1021
                        localString += string("groupId = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
 
1022
                        localString += string("indexIn = mad24(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n*BS)) + string(");\n");
 
1023
                        localString += string("tid = mul24(groupId, ") + num2str(batchSize) + string(");\n");
 
1024
                        localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
 
1025
                        localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
 
1026
                        int stride = radix*Rinit;
 
1027
                        for(i = 0; i < passNum; i++)
 
1028
                                stride *= radixArr[i];
 
1029
                        localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int) log2(n*BS)) + string("));\n");
 
1030
                        localString += string("bNum = groupId;\n");
 
1031
                }
 
1032
                else 
 
1033
                {
 
1034
                        int lgNumBlocksPerXForm = log2(numBlocksPerXForm);
 
1035
                        localString += string("bNum = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
 
1036
                        localString += string("xNum = groupId >> ") + num2str(lgNumBlocksPerXForm) + string(";\n");
 
1037
                        localString += string("indexIn = mul24(bNum, ") + num2str(batchSize) + string(");\n");
 
1038
                        localString += string("tid = indexIn;\n");
 
1039
                        localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
 
1040
                        localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n"); 
 
1041
                        int stride = radix*Rinit;
 
1042
                        for(i = 0; i < passNum; i++)
 
1043
                                stride *= radixArr[i];
 
1044
                        localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j);\n");                    
 
1045
                        localString += string("indexIn += (xNum << ") + num2str(m) + string(");\n");
 
1046
                        localString += string("indexOut += (xNum << ") + num2str(m) + string(");\n");   
 
1047
                }
 
1048
                
 
1049
                // Load Data
 
1050
                int lgBatchSize = log2(batchSize);
 
1051
                localString += string("tid = lId;\n");
 
1052
                localString += string("i = tid & ") + num2str(batchSize - 1) + string(";\n");
 
1053
                localString += string("j = tid >> ") + num2str(lgBatchSize) + string(";\n"); 
 
1054
                localString += string("indexIn += mad24(j, ") + num2str(strideI) + string(", i);\n");
 
1055
 
 
1056
                if(dataFormat == clFFT_SplitComplexFormat) 
 
1057
                {
 
1058
                        localString += string("in_real += indexIn;\n");
 
1059
                        localString += string("in_imag += indexIn;\n");                 
 
1060
                        for(j = 0; j < R1; j++)
 
1061
                                localString += string("a[") + num2str(j) + string("].x = in_real[") + num2str(j*gInInc*strideI) + string("];\n");
 
1062
                        for(j = 0; j < R1; j++) 
 
1063
                                localString += string("a[") + num2str(j) + string("].y = in_imag[") + num2str(j*gInInc*strideI) + string("];\n");
 
1064
                }
 
1065
                else 
 
1066
                {
 
1067
                        localString += string("in += indexIn;\n");
 
1068
                        for(j = 0; j < R1; j++)
 
1069
                                localString += string("a[") + num2str(j) + string("] = in[") + num2str(j*gInInc*strideI) + string("];\n");
 
1070
            }
 
1071
                
 
1072
                localString += string("fftKernel") + num2str(R1) + string("(a, dir);\n");                                                         
 
1073
                
 
1074
                if(R2 > 1)
 
1075
                {
 
1076
                    // twiddle
 
1077
                    for(k = 1; k < R1; k++) 
 
1078
                    {
 
1079
                            localString += string("ang = dir*(2.0f*M_PI*") + num2str(k) + string("/") + num2str(radix) + string(")*j;\n");
 
1080
                            localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
 
1081
                            localString += string("a[") + num2str(k) + string("] = complexMul(a[") + num2str(k) + string("], w);\n"); 
 
1082
                    }
 
1083
                
 
1084
                    // shuffle
 
1085
                    numIter = R1 / R2;  
 
1086
                    localString += string("indexIn = mad24(j, ") + num2str(threadsPerBlock*numIter) + string(", i);\n");
 
1087
                    localString += string("lMemStore = sMem + tid;\n");
 
1088
                    localString += string("lMemLoad = sMem + indexIn;\n");
 
1089
                    for(k = 0; k < R1; k++) 
 
1090
                            localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
 
1091
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");   
 
1092
                    for(k = 0; k < numIter; k++)
 
1093
                            for(t = 0; t < R2; t++)
 
1094
                                    localString += string("a[") + num2str(k*R2+t) + string("].x = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n");
 
1095
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
 
1096
                    for(k = 0; k < R1; k++) 
 
1097
                            localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
 
1098
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");   
 
1099
                    for(k = 0; k < numIter; k++)
 
1100
                            for(t = 0; t < R2; t++)
 
1101
                                    localString += string("a[") + num2str(k*R2+t) + string("].y = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n");
 
1102
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
 
1103
                
 
1104
                    for(j = 0; j < numIter; j++)
 
1105
                            localString += string("fftKernel") + num2str(R2) + string("(a + ") + num2str(j*R2) + string(", dir);\n");
 
1106
                }
 
1107
                
 
1108
                // twiddle
 
1109
                if(passNum < (numPasses - 1)) 
 
1110
                {
 
1111
                        localString += string("l = ((bNum << ") + num2str(lgBatchSize) + string(") + i) >> ") + num2str(lgStrideO) + string(";\n");
 
1112
                        localString += string("k = j << ") + num2str((int)log2(R1/R2)) + string(";\n"); 
 
1113
                        localString += string("ang1 = dir*(2.0f*M_PI/") + num2str(N) + string(")*l;\n");
 
1114
                        for(t = 0; t < R1; t++) 
 
1115
                        {
 
1116
                                localString += string("ang = ang1*(k + ") + num2str((t%R2)*R1 + (t/R2)) + string(");\n");
 
1117
                                localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
 
1118
                                localString += string("a[") + num2str(t) + string("] = complexMul(a[") + num2str(t) + string("], w);\n");
 
1119
                        }
 
1120
                }
 
1121
                
 
1122
                // Store Data
 
1123
                if(strideO == 1) 
 
1124
                {
 
1125
                        
 
1126
                        localString += string("lMemStore = sMem + mad24(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1/R2)) + string(");\n");
 
1127
                        localString += string("lMemLoad = sMem + mad24(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix+1) + string(", tid & ") + num2str(radix-1) + string(");\n");
 
1128
                        
 
1129
                        for(int i = 0; i < R1/R2; i++)
 
1130
                                for(int j = 0; j < R2; j++)
 
1131
                                        localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].x;\n");
 
1132
                        localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
 
1133
                        if(threadsPerBlock >= radix)
 
1134
            {
 
1135
                for(int i = 0; i < R1; i++)
 
1136
                localString += string("a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n");
 
1137
            }
 
1138
            else
 
1139
            {
 
1140
                int innerIter = radix/threadsPerBlock;
 
1141
                int outerIter = R1/innerIter;
 
1142
                for(int i = 0; i < outerIter; i++)
 
1143
                    for(int j = 0; j < innerIter; j++)
 
1144
                        localString += string("a[") + num2str(i*innerIter+j) + string("].x = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n");
 
1145
            }
 
1146
                        localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
 
1147
                        
 
1148
                        for(int i = 0; i < R1/R2; i++)
 
1149
                                for(int j = 0; j < R2; j++)
 
1150
                                        localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].y;\n");
 
1151
                        localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
 
1152
                        if(threadsPerBlock >= radix)
 
1153
            {
 
1154
                for(int i = 0; i < R1; i++)
 
1155
                    localString += string("a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n");
 
1156
            }
 
1157
            else
 
1158
            {
 
1159
                int innerIter = radix/threadsPerBlock;
 
1160
                int outerIter = R1/innerIter;
 
1161
                for(int i = 0; i < outerIter; i++)
 
1162
                    for(int j = 0; j < innerIter; j++)
 
1163
                        localString += string("a[") + num2str(i*innerIter+j) + string("].y = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n");
 
1164
            }
 
1165
                        localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
 
1166
                        
 
1167
                        localString += string("indexOut += tid;\n");
 
1168
                        if(dataFormat == clFFT_SplitComplexFormat) {
 
1169
                                localString += string("out_real += indexOut;\n");
 
1170
                                localString += string("out_imag += indexOut;\n");
 
1171
                                for(k = 0; k < R1; k++)
 
1172
                                        localString += string("out_real[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
 
1173
                                for(k = 0; k < R1; k++)
 
1174
                                        localString += string("out_imag[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
 
1175
                        }
 
1176
                        else {
 
1177
                                localString += string("out += indexOut;\n");
 
1178
                                for(k = 0; k < R1; k++)
 
1179
                                        localString += string("out[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("];\n");                            
 
1180
                        }
 
1181
                 
 
1182
                }
 
1183
                else 
 
1184
                {
 
1185
                        localString += string("indexOut += mad24(j, ") + num2str(numIter*strideO) + string(", i);\n");
 
1186
                        if(dataFormat == clFFT_SplitComplexFormat) {
 
1187
                                localString += string("out_real += indexOut;\n");
 
1188
                                localString += string("out_imag += indexOut;\n");                       
 
1189
                                for(k = 0; k < R1; k++)
 
1190
                                        localString += string("out_real[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].x;\n");
 
1191
                                for(k = 0; k < R1; k++)
 
1192
                                        localString += string("out_imag[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].y;\n");
 
1193
                        }
 
1194
                        else {
 
1195
                                localString += string("out += indexOut;\n");
 
1196
                                for(k = 0; k < R1; k++)
 
1197
                                        localString += string("out[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("];\n");
 
1198
                        }
 
1199
                }
 
1200
                
 
1201
                insertHeader(*kernelString, kernelName, dataFormat);
 
1202
                *kernelString += string("{\n");
 
1203
                if((*kInfo)->lmem_size)
 
1204
                        *kernelString += string("    __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
 
1205
                *kernelString += localString;
 
1206
                *kernelString += string("}\n");         
 
1207
                
 
1208
                N /= radix;
 
1209
                kInfo = &(*kInfo)->next;
 
1210
                kCount++;
 
1211
        }
 
1212
}
 
1213
 
 
1214
void FFT1D(cl_fft_plan *plan, cl_fft_kernel_dir dir)
 
1215
{       
 
1216
    unsigned int radixArray[10];
 
1217
    unsigned int numRadix;
 
1218
    
 
1219
        switch(dir)
 
1220
        {
 
1221
                case cl_fft_kernel_x:
 
1222
                    if(plan->n.x > plan->max_localmem_fft_size)
 
1223
                    {
 
1224
                        createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
 
1225
                    }
 
1226
                    else if(plan->n.x > 1)
 
1227
                    {
 
1228
                        getRadixArray(plan->n.x, radixArray, &numRadix, 0);
 
1229
                        if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup)
 
1230
                        {
 
1231
                                    createLocalMemfftKernelString(plan);
 
1232
                                }
 
1233
                            else
 
1234
                            {
 
1235
                                getRadixArray(plan->n.x, radixArray, &numRadix, plan->max_radix);
 
1236
                                if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup)
 
1237
                                    createLocalMemfftKernelString(plan);
 
1238
                                else
 
1239
                                        createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
 
1240
                                }
 
1241
                    }
 
1242
                        break;
 
1243
                        
 
1244
                case cl_fft_kernel_y:
 
1245
                        if(plan->n.y > 1)
 
1246
                            createGlobalFFTKernelString(plan, plan->n.y, plan->n.x, cl_fft_kernel_y, 1);
 
1247
                        break;
 
1248
                        
 
1249
                case cl_fft_kernel_z:
 
1250
                        if(plan->n.z > 1)
 
1251
                            createGlobalFFTKernelString(plan, plan->n.z, plan->n.x*plan->n.y, cl_fft_kernel_z, 1);
 
1252
                default:
 
1253
                        return;
 
1254
        }
 
1255
}
 
1256