#include <stdio.h>
#include <assert.h>
#include <ranges>
#include <random>
#include <vector>
#include <algorithm>
#include <gp/misc.h>

using namespace std;

// Element type used by every matrix in this file.
using elt_t = float;

void
mm_simple( elt_t *g, elt_t *a, elt_t *b, int d1, int d2, int d3 )
{
  /// Straightforward (Un-tiled) Matrix Multiplication:  g = a * b
  //
  // All matrices are stored row-major in flat arrays.
  //
  // Matrix a: d1 rows, d2 columns.  (Read.)
  // Matrix b: d2 rows, d3 columns.  (Read.)
  // Matrix g: d1 rows, d3 columns.  (Every element is overwritten.)

  for ( int row = 0;  row < d1;  ++row )
    {
      // Start of the current row in a and in g.
      elt_t* const a_row = a + row * d2;
      elt_t* const g_row = g + row * d3;

      for ( int col = 0;  col < d3;  ++col )
        {
          // Dot product of row `row` of a with column `col` of b,
          // accumulated in order of increasing inner index.
          elt_t dot = 0;
          for ( int inner = 0;  inner < d2;  ++inner )
            dot += a_row[ inner ] * b[ inner * d3 + col ];

          g_row[ col ] = dot;
        }
    }
}

template< int t >
void
mm_tiled( elt_t *g, elt_t *a, elt_t *b, int d1, int d2, int d3 )
{
  /// Tiled Matrix Multiplication. Tile size is t.
  //
  // Adds the product a * b to g. Because g is accumulated into (+=)
  // rather than overwritten, the caller must clear g beforehand to
  // obtain just the product.
  //
  // All matrices are stored row-major in flat arrays.

  // Matrix a: d1 rows, d2 columns.
  // Matrix b: d2 rows, d3 columns.
  // Matrix g: d1 rows, d3 columns.

  // A non-positive tile size would keep the tile loops below from ever
  // advancing, and t == 0 would divide by zero in the checks that
  // follow, so reject it at compile time.
  //
  static_assert( t > 0, "mm_tiled: tile size t must be positive" );

  // Code only works if tile size is a factor of each dimension.
  //
  assert( d1 % t == 0 ); assert( d2 % t == 0 ); assert( d3 % t == 0 );

  // rr, cc, kk: Index of the first row / column / inner element of
  // the tile being worked on.
  //
  for ( int rr = 0;  rr < d1;  rr += t )
    for ( int cc = 0;  cc < d3;  cc += t )
      for ( int kk = 0;  kk < d2;  kk += t )
        // This loop body executes
        //   ( d1 / t ) ( d2 / t ) ( d3 / t ) = d1*d2*d3 / t^3 times.
        // For square arrays: d^3 / t^3

        // Loops below operate on one t*t tile of a, b, and g ..
        // .. and so potentially access memory 2 t^2 times (ignoring g) ..
        // .. assuming at least 2t^2 + t + 1 storage elements.
        for ( int r = rr;  r < rr + t;  r++ )
          for ( int c = cc;  c < cc + t;  c++ )
            {
              // Partial dot product over this tile's slice of the
              // inner dimension.
              elt_t e = 0;
              for ( int k = kk;  k < kk + t;  k++ )
                e += a[ r*d2 + k ] * b[ k*d3 + c ];

              // Each kk iteration contributes one partial sum to the
              // same element of g.
              g[ r*d3 + c ] += e;
            }
}

/// Deterministic source of random values of type T, uniformly
/// distributed over [-1,1).
//
//  The engine always starts from the same fixed seed, so a freshly
//  constructed object always yields the same sequence.
//
template< typename T >
class My_Rand {
public:

  // Members are set up by their default initializers below.
  My_Rand() {}

  /// Return the next value in the sequence, advancing the engine.
  T operator () () { return uni_pm1( re ); }

  // Declared in this order because re is initialized from seed.
  int seed = 2735;
  default_random_engine re
    { static_cast< default_random_engine::result_type >( seed ) };
  uniform_real_distribution<T> uni_pm1{ T(-1), T(1) };
};

// File-wide generator used to fill the input matrices.
My_Rand<elt_t> rand_pm1;


template< int t >
void
mm_tiled_do( elt_t *g, elt_t *a, elt_t *b, int d1, int d2, int d3, elt_t *ch )
{
  /// Time mm_tiled<t> and Check Its Result Against a Reference
  //
  // Matrix a:  d1 rows, d2 columns.  Input.
  // Matrix b:  d2 rows, d3 columns.  Input.
  // Matrix g:  d1 rows, d3 columns.  Cleared, then set to the tiled product.
  // Matrix ch: d1 rows, d3 columns.  Reference product to compare g with.
  //
  // Always prints the duration of the tiled multiplication. Error
  // statistics are printed only if some element of g differs from ch
  // by more than the tolerance.

  // Convert before multiplying so that the element count is computed
  // in size_t rather than in int.
  const size_t g_size = size_t( d1 ) * size_t( d3 );

  // mm_tiled accumulates into g, so it must start out as all zeros.
  fill_n( g, g_size, elt_t( 0 ) );

  const double t_start_s = time_wall_fp();

  mm_tiled<t>( g, a, b, d1, d2, d3 );

  const double t_end_s = time_wall_fp();

  // n_err:   Number of elements that differ by more than the tolerance.
  // n_err_z: Number of those for which the tiled result is exactly zero.
  int n_err = 0, n_err_z = 0;
  double max_err = 0;

  // Tolerance grows with d2, the number of terms in each element's sum.
  float tol = numeric_limits<elt_t>::epsilon() * d2 * 5;
  for ( size_t i=0; i<g_size; i++ )
    {
      elt_t hd = g[i];
      elt_t hc = ch[i];
      float d = fabs( hd - hc );
      if ( d > max_err ) max_err = d;
      if ( d > tol )
        {
          n_err++;
          if ( hd == 0 ) n_err_z++;
        }
    }

  printf("Duration tile size %d: %.3f ms\n",
         t, ( t_end_s - t_start_s ) * 1000 );

  // Set to true to print the error statistics even when there are no errors.
  const bool verbose = false;
  if ( !n_err && !verbose ) return;

  // g_size is a size_t, so it is printed with %zu.
  printf("Number of errors %d out of %zu. %d zeros\n",
         n_err, g_size, n_err_z);
  printf("Max err = %.9f,  tolerance = %.9f\n", max_err, tol );
}

void
mm_do( int d1, int d2, int d3 )
{
  /// Benchmark Un-tiled and Tiled Multiplication for One Problem Size
  //
  // Multiplies a random d1 x d2 matrix by a random d2 x d3 matrix,
  // first with mm_simple and then with mm_tiled at several tile
  // sizes, printing the duration of each. The mm_simple result serves
  // as the reference that the tiled results are checked against.

  // Compute the element counts in size_t rather than in int.
  const size_t a_size = size_t( d1 ) * size_t( d2 );
  const size_t b_size = size_t( d2 ) * size_t( d3 );
  const size_t g_size = size_t( d1 ) * size_t( d3 );

  // gs: Product computed by mm_simple.  gt: Product computed by mm_tiled.
  vector<elt_t> gs( g_size ), gt( g_size );
  vector<elt_t> a( a_size ), b( b_size );

  // ranges::generate takes its generator by value. Passing rand_pm1
  // itself would draw from a copy and leave the global generator
  // untouched, so a and b would both be filled with the same sequence.
  // Calling the global through a lambda advances it instead.
  //
  auto next_rand = []{ return rand_pm1(); };
  ranges::generate( a, next_rand );
  ranges::generate( b, next_rand );

  // Untimed first run. NOTE(review): presumably a warm-up so that the
  // timed run below is not charged for first-touch costs -- confirm.
  mm_simple( gs.data(), a.data(), b.data(), d1, d2, d3 );

  const double t_start_s = time_wall_fp();

  mm_simple( gs.data(), a.data(), b.data(), d1, d2, d3 );

  const double t_simple_end_s = time_wall_fp();
  const double dur_simple_s = t_simple_end_s - t_start_s;
  printf("Duration simple: %.3f ms\n",
         dur_simple_s * 1000 );

  // Each tile size must be a factor of d1, d2, and d3.
  mm_tiled_do<4>( gt.data(), a.data(), b.data(), d1, d2, d3, gs.data() );
  mm_tiled_do<8>( gt.data(), a.data(), b.data(), d1, d2, d3, gs.data() );
  mm_tiled_do<16>( gt.data(), a.data(), b.data(), d1, d2, d3, gs.data() );
  mm_tiled_do<32>( gt.data(), a.data(), b.data(), d1, d2, d3, gs.data() );

}


int
main( int argc, char **argv )
{
  /// Run the benchmark on square matrices of increasing size.
  //
  // The command-line arguments are not used.

  // Each size is a multiple of 32, the largest tile size mm_do uses.
  const int sizes[] = { 32, 128, 1024 };

  for ( const int d : sizes ) mm_do( d, d, d );
}