// Dense matrix multiply, c = a * b, computing one element of c per loop trip.
//
// Reads file-scope variables not visible in this block: a, b (inputs), c
// (output), array_size (number of elements of c to compute), row_stride and
// row_stride_lg.
// NOTE(review): the mask/shift indexing assumes row_stride is a power of two
// with row_stride == 1 << row_stride_lg, and that each matrix is laid out with
// row_stride consecutive elements per row — confirm against the host setup.
//
// Any 1-D launch configuration is valid: each thread starts at its global
// index and advances by the total number of threads in the grid, so every
// element of c is computed by exactly one thread.
__global__ void mm_iter()
{
  // Compute a unique index (number) for this thread.
  // This will be used as an array index.
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  int thread_count = blockDim.x * gridDim.x;

  int row_mask = row_stride - 1;

  // Advance by thread_count, not 1: a stride of 1 made every thread recompute
  // all elements from tid to the end (redundant work and duplicate writes).
  for ( int c_idx = tid; c_idx < array_size; c_idx += thread_count )
    {
      int col = c_idx & row_mask;              // column of c_idx within its row
      int row = c_idx >> row_stride_lg;        // row containing c_idx
      int a_idx_base = row << row_stride_lg;   // first element of that row in a

      float c_value = 0.0f;
      for ( int k = 0; k < row_stride; k++ )
        {
          int a_idx = a_idx_base + k;                 // a[row][k]
          int b_idx = ( k << row_stride_lg ) + col;   // b[k][col]
          c_value += a[a_idx] * b[b_idx];
        }
      c[c_idx] = c_value;
    }
}