#include <assert.h>
#include <stdio.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
#include <ranges>
#include <vector>

#include <gp/misc.h>
using namespace std;
// Element type shared by every matrix in this benchmark.
using elt_t = float;
/// Reference matrix product g = a * b using the straightforward triple loop.
///
/// All matrices are row-major: a is d1 x d2, b is d2 x d3, g is d1 x d3.
/// Every element of g is overwritten, so g need not be initialized.
void
mm_simple( elt_t *g, elt_t *a, elt_t *b, int d1, int d2, int d3 )
{
  for ( int row = 0; row < d1; ++row )
    {
      const elt_t *a_row = a + row * d2;   // row `row` of a
      elt_t *g_row = g + row * d3;         // row `row` of g
      for ( int col = 0; col < d3; ++col )
        {
          // Dot product of a's row with b's column, summed in index order.
          elt_t acc = 0;
          for ( int i = 0; i < d2; ++i )
            acc += a_row[i] * b[ i * d3 + col ];
          g_row[col] = acc;
        }
    }
}
/// Tiled (blocked) matrix product that ACCUMULATES into g: g += a * b.
///
/// All matrices are row-major: a is d1 x d2, b is d2 x d3, g is d1 x d3.
/// The caller must zero g beforehand to obtain a plain product.
///
/// The iteration space is cut into t x t x t tiles. Edge tiles are clipped
/// to the matrix bounds, so d1, d2 and d3 need NOT be multiples of t. The
/// previous version relied on assert() for that, which disappears under
/// NDEBUG and then indexes out of bounds.
///
/// Template parameter t: tile edge length (must be positive).
template< int t >
void
mm_tiled( elt_t *g, elt_t *a, elt_t *b, int d1, int d2, int d3 )
{
  static_assert( t > 0, "tile size must be positive" );
  for ( int rr = 0; rr < d1; rr += t )
    {
      const int r_end = std::min( rr + t, d1 );
      for ( int cc = 0; cc < d3; cc += t )
        {
          const int c_end = std::min( cc + t, d3 );
          for ( int kk = 0; kk < d2; kk += t )
            {
              const int k_end = std::min( kk + t, d2 );
              for ( int r = rr; r < r_end; r++ )
                for ( int c = cc; c < c_end; c++ )
                  {
                    // Partial dot product over this k-tile, added to g.
                    elt_t e = 0;
                    for ( int k = kk; k < k_end; k++ )
                      e += a[ r*d2 + k ] * b[ k*d3 + c ];
                    g[ r*d3 + c ] += e;
                  }
            }
        }
    }
}
/// Seeded source of values of type T drawn from
/// std::uniform_real_distribution<T>(-1, 1).
///
/// Every instance starts from the same fixed seed (2735), so two freshly
/// constructed objects yield identical sequences. The implicit copy
/// constructor copies the engine state too, so a copy replays the same values.
template< typename T >
class My_Rand {
public:
  My_Rand()
    : seed( 2735 )
    , re( seed )          // seed is declared before re, so it is already set here
    , uni_pm1( -1, 1 )
  {
  }

  /// Draw the next sample (advances the engine).
  T operator () ()
  {
    return uni_pm1( re );
  }

  int seed;
  std::default_random_engine re;
  std::uniform_real_distribution<T> uni_pm1;
};
// File-wide random source used to fill the benchmark operands.
My_Rand<elt_t> rand_pm1{};
/// Time one run of mm_tiled<t> and check it against a reference product.
///
/// g    : output buffer of d1*d3 elements; zeroed here, then filled by mm_tiled<t>.
/// a, b : row-major operands (d1 x d2 and d2 x d3).
/// ch   : reference result (d1 x d3) that g is compared against.
/// d1, d2, d3 : matrix dimensions.
///
/// Always prints the elapsed wall time. Error statistics are printed only
/// when some element differs from ch by more than the tolerance, or when
/// `verbose` is set to true.
template< int t >
void
mm_tiled_do( elt_t *g, elt_t *a, elt_t *b, int d1, int d2, int d3, elt_t *ch )
{
  // Multiply in size_t: d1 * d3 computed in int could overflow.
  const size_t g_size = size_t( d1 ) * size_t( d3 );

  // mm_tiled<t> accumulates into g, so it must start from zero.
  std::fill( g, g + g_size, elt_t( 0 ) );

  const double t_start_s = time_wall_fp();
  mm_tiled<t>( g, a, b, d1, d2, d3 );
  const double t_end_s = time_wall_fp();

  // Absolute tolerance that scales with d2, the number of products summed
  // into each output element. The tiled kernel sums in a different order
  // than a straight k-loop, so results differ by rounding.
  const float tol = std::numeric_limits<elt_t>::epsilon() * d2 * 5;

  int n_err = 0;       // elements outside tolerance
  int n_err_z = 0;     // ... of which the tiled value is exactly zero
  double max_err = 0;
  for ( size_t i = 0; i < g_size; i++ )
    {
      const elt_t hd = g[i];    // tiled result
      const elt_t hc = ch[i];   // reference result
      const float d = std::fabs( hd - hc );
      if ( d > max_err ) max_err = d;
      if ( d > tol )
        {
          n_err++;
          if ( hd == 0 ) n_err_z++;
        }
    }

  printf("Duration tile size %d: %.3f ms\n",
         t, ( t_end_s - t_start_s ) * 1000 );

  // Flip to true to always print the error summary.
  constexpr bool verbose = false;
  if ( !n_err && !verbose ) return;
  // %zu is the correct conversion for size_t (%zd is for the signed variant).
  printf("Number of errors %d out of %zu. %d zeros\n",
         n_err, g_size, n_err_z);
  printf("Max err = %.9f, tolerance = %.9f\n", max_err, tol );
}
void
mm_do( int d1, int d2, int d3 )
{
vector<elt_t> gs( d1*d3 ), gt( d1*d3 );
vector<elt_t> a( d1*d2 ), b( d2*d3 );
ranges::generate(a,rand_pm1);
ranges::generate(b,rand_pm1);
mm_simple( gs.data(), a.data(), b.data(), d1, d2, d3 );
const double t_start_s = time_wall_fp();
mm_simple( gs.data(), a.data(), b.data(), d1, d2, d3 );
const double t_simple_end_s = time_wall_fp();
const double dur_simple_s = t_simple_end_s - t_start_s;
printf("Duration simple: %.3f ms\n",
dur_simple_s * 1000 );
mm_tiled_do<4>( gt.data(), a.data(), b.data(), d1, d2, d3, gs.data() );
mm_tiled_do<8>( gt.data(), a.data(), b.data(), d1, d2, d3, gs.data() );
mm_tiled_do<16>( gt.data(), a.data(), b.data(), d1, d2, d3, gs.data() );
mm_tiled_do<32>( gt.data(), a.data(), b.data(), d1, d2, d3, gs.data() );
}
/// Run the benchmark on square problems of increasing size.
int
main( [[maybe_unused]] int argc, [[maybe_unused]] char **argv )
{
  const int dims[] = { 32, 128, 1024 };
  for ( const int n : dims )
    mm_do( n, n, n );
}