1
// Element (i, j) of the tile of matrix `a` staged in the __local buffer
// `shmem`. The tile is row-major with a row stride of PPT * BLOCK_SIZE floats
// and starts at shmem[0].
#define A(i, j) shmem[(j) + (i) * PPT * BLOCK_SIZE]
2
// Element (i, j) of the tile of matrix `b` staged in the same __local buffer
// `shmem`. It uses the same row-major layout as the A tile (row stride
// PPT * BLOCK_SIZE) but starts (PPT * BLOCK_SIZE)^2 floats in, directly
// after the A tile. With row and column indices kept below PPT * BLOCK_SIZE,
// `shmem` therefore needs 2 * (PPT * BLOCK_SIZE)^2 floats.
// NOTE(review): that buffer size is presumably set by the host; confirm.
#define B(i, j) shmem[(j) + (i) * PPT * BLOCK_SIZE + PPT * PPT * BLOCK_SIZE * BLOCK_SIZE]
4
__kernel void multiply(__global float *res, __global float *a, __global float *b, unsigned long size, __local float *shmem) {
5
float sum[PPT][PPT] = {0};
7
int tx = get_local_id(0);
8
int ty = get_local_id(1);
10
int i = get_global_id(1);
11
int j = get_global_id(0);
16
for(k = 0; k < size; k += PPT * BLOCK_SIZE) {
18
for (y = 0; y < PPT; ++y) {
20
for (x = 0; x < PPT; ++x) {
21
A(ty * PPT + y, tx * PPT + x) = a[(i * PPT + y) * size + (k + tx * PPT + x)];
22
B(ty * PPT + y, tx * PPT + x) = b[(k + ty * PPT + y) * size + (j * PPT + x)];
26
barrier(CLK_LOCAL_MEM_FENCE);
28
#pragma unroll PPT * BLOCK_SIZE
29
for (l = 0; l < PPT * BLOCK_SIZE; ++l) {
31
for (y = 0; y < PPT; ++y) {
33
for (x = 0; x < PPT; ++x) {
34
sum[y][x] += A(ty * PPT + y, l) * B(l, tx * PPT + x);
35
// sum[y][x] += A(ty * PPT + y, l) * b[(k + l) * size + (j * PPT + x)];
36
// sum[y][x] += a[(i * PPT + y) * size + (k + l)] * b[(k + l) * size + (j * PPT + x)];
41
barrier(CLK_LOCAL_MEM_FENCE);
45
for (y = 0; y < PPT; ++y) {
47
for (x = 0; x < PPT; ++x) {
48
res[(i * PPT + y) * size + j * PPT + x] = sum[y][x];