1
/*
 * A(i, j): element (i, j) of the first staging area kept in `shmem`
 * (the kernel fills it from `a`). Rows are stored contiguously with a
 * row stride of PPT * BLOCK_SIZE floats.
 *
 * Expands to an lvalue, so it can be both read and assigned. It relies on
 * an identifier named `shmem` being in scope at the point of use.
 * PPT and BLOCK_SIZE are parenthesized so that expression-valued
 * definitions (e.g. -DPPT=1+1) still yield the intended index.
 */
#define A(i, j) shmem[(i) * (PPT) * (BLOCK_SIZE) + (j)]
2
/*
 * B(i, j): element (i, j) of the second staging area kept in `shmem`
 * (the kernel fills it from `b`). The area starts (PPT * BLOCK_SIZE)^2
 * floats into `shmem` and uses a row stride of PPT * BLOCK_SIZE floats.
 *
 * Expands to an lvalue, so it can be both read and assigned. It relies on
 * an identifier named `shmem` being in scope at the point of use.
 * PPT and BLOCK_SIZE are parenthesized so that expression-valued
 * definitions (e.g. -DPPT=1+1) still yield the intended index.
 */
#define B(i, j) shmem[(PPT) * (PPT) * (BLOCK_SIZE) * (BLOCK_SIZE) + (i) * (PPT) * (BLOCK_SIZE) + (j)]
4
/*
 * multiply - tiled product kernel: every work-item accumulates a PPT x PPT
 * set of partial sums from `a` and `b` and stores them into `res`.
 *
 * res, a, b : global float buffers, all indexed as row * size + column.
 * size      : row stride used for all three buffers and upper bound of the
 *             k loop (presumably the side length of square matrices;
 *             confirm against the host code).
 * shmem     : local scratch buffer, accessed only through the A() and B()
 *             macros. With local ids below BLOCK_SIZE the indices used
 *             here span 2 * (PPT * BLOCK_SIZE)^2 floats.
 *
 * NOTE(review): bx/by are derived from get_local_size() while per-item
 * offsets use BLOCK_SIZE, so the addressing appears to assume a
 * BLOCK_SIZE x BLOCK_SIZE work-group - confirm with the launch code.
 * NOTE(review): no bounds checks against `size` are present; presumably
 * `size` must be a multiple of PPT * BLOCK_SIZE - confirm.
 */
__kernel void multiply(__global float *res, __global float *a, __global float *b, unsigned long size, __local float *shmem) {
5
/* Per-work-item accumulators, one per output element this item produces. */
float sum[PPT][PPT] = {0};
7
/* First output column covered by this work-group. */
int bx = get_group_id(0) * get_local_size(0) * PPT;
8
/* First output row covered by this work-group. */
int by = get_group_id(1) * get_local_size(1) * PPT;
10
/* Position of this work-item inside its work-group (column, then row). */
int tx = get_local_id(0);
11
int ty = get_local_id(1);
13
/* Global ids; none of the statements below reference i or j - TODO confirm they are needed. */
int i = get_global_id(1);
14
int j = get_global_id(0);
20
/* Walk the shared dimension in steps of PPT * BLOCK_SIZE elements. */
for(k = 0; k < size; k += PPT * BLOCK_SIZE) {
22
/* Each work-item copies PPT x PPT elements, spaced BLOCK_SIZE apart, into both staging areas. */
for (y = 0; y < PPT; ++y) {
24
for (x = 0; x < PPT; ++x) {
25
/* From a: rows by.., columns k.. */
A(y * BLOCK_SIZE + ty, x * BLOCK_SIZE + tx) = a[(by + y * BLOCK_SIZE + ty) * size + (k + x * BLOCK_SIZE + tx)];
26
/* From b: rows k.., columns bx.. */
B(y * BLOCK_SIZE + ty, x * BLOCK_SIZE + tx) = b[(k + y * BLOCK_SIZE + ty) * size + (bx + x * BLOCK_SIZE + tx)];
30
/* Work-group barrier on local memory; presumably separates the loads above from the reads below - confirm loop nesting. */
barrier(CLK_LOCAL_MEM_FENCE);
32
/* Requests unrolling of the l loop by its full trip count. */
#pragma unroll PPT * BLOCK_SIZE
33
for (l = 0; l < PPT * BLOCK_SIZE; ++l) {
35
for (y = 0; y < PPT; ++y) {
37
for (x = 0; x < PPT; ++x) {
38
/* Multiply-accumulate: row (y * BLOCK_SIZE + ty) of A with column (x * BLOCK_SIZE + tx) of B. */
sum[y][x] += A(y * BLOCK_SIZE + ty, l) * B(l, x * BLOCK_SIZE + tx);
43
/* Second barrier; presumably keeps the next k step from overwriting the staging areas while they are still being read - confirm. */
barrier(CLK_LOCAL_MEM_FENCE);
47
/* Store the accumulated values to res, using the same row/column offsets as the loads. */
for (y = 0; y < PPT; ++y) {
49
for (x = 0; x < PPT; ++x) {
50
res[(by + y * BLOCK_SIZE + ty) * size + bx + x * BLOCK_SIZE + tx] = sum[y][x];