/// LSU EE 7722 GPU Microarchitecture
//
/// Spring 2025
/// Homework 5 -- SOLUTION
//
//  Assignment: https://www.ece.lsu.edu/koppel/gp/2025/hw05.pdf
//
//  The solution is in this file.
//  Search for hw05 and SOLUTION to find it.
//
//  Only this file was modified.

#include <cuda_runtime.h>
#include <gp/cuda-gpuinfo.h>
#include <ptable.h>
#include <nperf.h>
#include <misc.h>
#include <mma.h>

#include <stdio.h>
#include <ranges>
#include <random>
#include <vector>
#include <algorithm>

using namespace nvcuda;

template<typename T1, typename T2>
requires ( is_integral<T1>::value && is_integral<T2>::value )
__device__ __host__ constexpr T1
div_ceil( T1 a, T2 b )
{
  // Integer division of a by b rounded up, e.g. div_ceil(7,2) == 4.
  // Intended for non-negative a and positive b. Usable on the host, on
  // the device, and in constant expressions; the result has a's type.
  const auto padded = a + b - 1;
  return padded / b;
}


// Element type of the FP32 matrices A, B, and C (and their transposes).
typedef float elt_t;

// Dimensions of one problem: A is A_nrows x A_ncols, B is A_ncols x
// B_ncols, and so C = A * B is A_nrows x B_ncols.
struct Shape_Spec {
  const int A_nrows, A_ncols, B_ncols;
};

// Available problem sizes. Presumably one is selected by index in host
// code not shown here -- TODO confirm.
constexpr Shape_Spec shape_specs[] =
  { { 128,  32,  32768 }, // 0
    { 64,   32,  64    }, // 1  Easy to hand analyze
    { 5120, 2048+16, 4096 } // 2
  };

// Type used to hold clock() samples (the value returned by clock() is
// stored into this 32-bit field).
typedef uint32_t pClock_t;

// Size, and alignment, of one Timing_Item in bytes.
constexpr size_t timing_item_size = 16;

// Per-warp timing record. Lane 0 of a warp fills the *_start fields in
// timing_start and the *_end fields in timing_end. Field order and size
// must not change: the size is checked by the static_assert below, and
// the records are presumably read back by host code -- TODO confirm.
struct __align__(timing_item_size) Timing_Item {
  pClock_t time_start;  // clock() value when the warp started.
  uint32_t smid_start;  // SM the warp was on when it started.
  pClock_t time_end;    // clock() value when the warp finished.
  uint32_t smid_end; // To detect preemption.
};
static_assert( timing_item_size == sizeof(Timing_Item) );

// Pointer to the array of timing records, indexed by global warp number
// (see timing_start / timing_end). Kept in constant memory; presumably
// set by host code, e.g. with cudaMemcpyToSymbol -- TODO confirm.
struct Timing { Timing_Item *timing_items; };
__constant__ Timing timing_dev;


// Plain-old-data description of the current problem, kept in constant
// memory (shape_dev) where the kernels read it.
//
// Naming: a "T" suffix denotes a transposed (column-major) copy. For
// example, kernels read element (r,c) of A from AT_dev[ r + c * A_nrows ]
// and write element (r,c) of C to CT_dev[ r + c * A_nrows ].
// The *_tf32_* and *_fp16_* arrays are reduced-precision copies made by
// convert_float_tf32 for the tensor-core kernel.
class Shape_POD {
public:
  // Matrix dimensions: A is A_nrows x A_ncols, B is A_ncols x B_ncols.
  size_t A_nrows, A_ncols, B_ncols;
  // Buffer sizes in bytes (presumably set by host code -- TODO confirm).
  size_t A_bytes, AT_bytes, B_bytes, BT_bytes, CT_bytes;
  size_t A_tf32_bytes, AT_tf32_bytes, B_tf32_bytes, BT_tf32_bytes;
  size_t A_fp16_bytes, AT_fp16_bytes, B_fp16_bytes, BT_fp16_bytes;
  // Device pointers: FP32 inputs, their transposes, and the output CT.
  elt_t *A_dev, *AT_dev, *B_dev, *BT_dev, *CT_dev;
  // Device pointers: inputs rounded to tf32 (stored as float).
  float *A_tf32_dev, *AT_tf32_dev, *B_tf32_dev, *BT_tf32_dev;
  // Device pointers: inputs converted to fp16.
  half *A_fp16_dev, *AT_fp16_dev, *B_fp16_dev, *BT_fp16_dev;
};


// Return the ID of the SM on which the calling thread is currently running.
//
// The asm is volatile so that the compiler does not merge, hoist, or
// otherwise reuse the result across separate calls (e.g., the call in
// timing_start and the one in timing_end). A non-volatile asm with no
// inputs may be treated as a pure computation and CSE'd. That would make
// smid_end always equal smid_start and defeat preemption detection.
__device__ uint32_t
smid_get()
{
  uint32_t smid = 0;
  asm volatile ( "mov.u32 %0, %%smid;" : "=r" (smid) );
  return smid;
}

__device__ void
timing_start(const int wp)
{
  // Record the start time and current SM of warp number wp.
  // Only lane 0 of the warp writes the record; afterwards the whole warp
  // synchronizes with __syncwarp().
  constexpr int warp_lanes = 32;
  const bool is_leader = ( threadIdx.x % warp_lanes ) == 0;
  if ( is_leader )
    {
      Timing_Item* const rec = &timing_dev.timing_items[wp];
      rec->time_start = clock();
      rec->smid_start = smid_get();
    }
  __syncwarp();
}
__device__ void
timing_end(const int wp)
{
  // Record the end time and current SM of warp number wp.
  // The warp first synchronizes with __syncwarp(); then only lane 0
  // writes the record.
  constexpr int warp_lanes = 32;
  __syncwarp();
  if ( threadIdx.x % warp_lanes != 0 ) return;
  Timing_Item& rec = timing_dev.timing_items[wp];
  rec.time_end = clock();
  rec.smid_end = smid_get();
}


// Problem description (dimensions and device pointers) read by all kernels
// through a reference named ld. Presumably written by host code, e.g. with
// cudaMemcpyToSymbol -- TODO confirm.
__constant__ Shape_POD shape_dev;


__global__ void M_transpose
( elt_t* __restrict__ MT_dev, const elt_t* __restrict__ M_dev,
  int M_nrows, int M_ncols )
{
  // Transpose M, stored row-major as M_nrows x M_ncols, into MT_dev, so
  // that element (r,c) of M ends up at MT_dev[ r + c * M_nrows ].
  //
  // Each thread handles elements idx, idx + stride, idx + 2 stride, ...,
  // where stride is the total number of threads in the grid. Therefore
  // any launch configuration covers every element.
  const int stride = blockDim.x * gridDim.x;
  const int n_elts = M_nrows * M_ncols;

  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  while ( idx < n_elts )
    {
      const int r = idx / M_ncols;
      const int c = idx % M_ncols;
      MT_dev[ c * M_nrows + r ] = M_dev[ idx ];
      idx += stride;
    }
}


template< int t_A_nrows = 0, int t_A_ncols = 0 >
__global__ void mm_simple()
{
  // Compute C = A * B,
  //
  //   A is m rows by k columns,
  //   B is k rows by n columns.
  //   C must then be m rows by n columns.
  //
  //   A and B are read row-major (A_dev, B_dev); C is written
  //   column-major into CT_dev.
  //
  // Each row of C is assigned to a particular block: block b computes
  // rows b, b + G, b + 2G, ..., where G is the number of blocks.
  // Within a block, thread t computes columns t, t + blockDim.x, ...
  // of each of the block's rows.
  //
  // Let: S, # of SMs.
  //
  // Disadvantages
  //
  //   Every block that computes a row of C reads all of B, so data
  //   movement when reading B can be as high as S k n (when m >= S),
  //   which may be too high.
  //
  //   Workload imbalance when m/S < 1 or close to 1 but not equal to 1.
  //
  //   Adjacent threads store to CT_dev addresses A_nrows elements apart
  //   (stores are not coalesced).
  //
  //   Two load instructions per FMADD instruction.
  //   Ideal Slowdown for CC 9.0:  ( 4 + 4 + 1 ) / 1 = 9 ..
  //   .. nine times slower than peak FP32 rate.
  //
  // Pedagogical Advantage
  //
  //   Easy to understand example of data distribution ..
  //   .. by using blockIdx to compute rows ..
  //   .. and using threadIdx but not blockIdx to compute columns.

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int wp = tid / wp_sz;

  // Template arguments, when nonzero, make the dimensions compile-time
  // constants; otherwise they are read from shape_dev.
  const int A_ncols = t_A_ncols ?: ld.A_ncols;
  const int A_nrows = t_A_nrows ?: ld.A_nrows;
  const int B_ncols = ld.B_ncols;

  timing_start(wp);

  for ( int C_row = blockIdx.x; C_row < A_nrows; C_row += gridDim.x )
    for ( int C_col = threadIdx.x; C_col < B_ncols; C_col += blockDim.x )
      {
        elt_t C_accum = 0;

        // All threads of the block read the same element of A; adjacent
        // threads read adjacent elements of B.
        for ( int A_col = 0; A_col < A_ncols; A_col++ )
          {
            elt_t A_elt = ld.A_dev[ C_row * A_ncols + A_col ];
            elt_t B_elt = ld.B_dev[ A_col * B_ncols + C_col ];
            C_accum += A_elt * B_elt;
          }

        ld.CT_dev[ C_row + C_col * A_nrows ] = C_accum;
      }

  timing_end(wp);
}


template< int t_A_nrows = 0, int t_A_ncols = 0 >
__global__ void mm_simple2()
{
  // Compute C = A * B,
  //
  //   A is m rows by k columns,
  //   B is k rows by n columns.
  //   C must then be m rows by n columns.
  //
  //   A and B are read row-major (A_dev, B_dev); C is written
  //   column-major into CT_dev.
  //
  // The m n elements of C, numbered in row-major order, are distributed
  // over all threads of the grid: thread tid computes elements tid,
  // tid + num_threads, ...
  //
  // Adjacent threads in a warp compute adjacent columns in a row of C.
  //
  // Let: S, # of SMs.
  //
  // Disadvantages
  //
  //   Data movement when reading A can be >= S m k, which may be too high.
  //   Data movement when reading B can be >= S k n, which may be too high.
  //
  //   Two load instructions per FMADD instruction.
  //   Ideal Slowdown for CC 9.0:  ( 4 + 4 + 1 ) / 1 = 9 ..
  //   .. nine times slower than peak FP32 rate.
  //
  // Advantage Over mm_simple
  //
  //   Better load balance when m/S < 1 (mm_simple gives whole rows of C
  //   to blocks).
  //

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_blocks = gridDim.x;
  const int num_threads = blockDim.x * num_blocks;
  const int wp = tid / wp_sz;

  // Template arguments, when nonzero, make the dimensions compile-time
  // constants; otherwise they are read from shape_dev.
  const int A_ncols = t_A_ncols ?: ld.A_ncols;
  const int A_nrows = t_A_nrows ?: ld.A_nrows;
  const int B_ncols = ld.B_ncols;

  timing_start(wp);

  for ( int i = tid; true; i += num_threads )
    {
      const int C_col = i % B_ncols;
      const int C_row = i / B_ncols;
      if ( C_row >= A_nrows ) break;

      elt_t C_accum = 0;

      for ( int A_col = 0; A_col < A_ncols; A_col++ )
        {
          elt_t A_elt = ld.A_dev[ C_row * A_ncols + A_col ];
          elt_t B_elt = ld.B_dev[ A_col * B_ncols + C_col ];

          C_accum += A_elt * B_elt;
        }

      ld.CT_dev[ C_col * A_nrows + C_row ] = C_accum;
    }

  timing_end(wp);
}


template< int t_A_nrows = 0, int t_A_ncols = 0,
          int m_wd = 1, int m_ht = 1, int m_dp = 1>
__global__ void mm_tile_wd_simple()
{
  // Compute C = A * B,
  //
  //   A is m rows by k columns,
  //   B is k rows by n columns.
  //   C must then be m rows by n columns.
  //
  //   A is read from AT_dev (column-major), B from BT_dev (column-major),
  //   and C is written column-major into CT_dev.
  //
  // Each thread computes m_wd elements of C along a row ..
  // .. using one load of an element of A ..
  // .. to update m_wd elements of C.
  //
  // Adjacent threads in a warp compute adjacent rows of C ..
  // .. in the same set of m_wd columns.
  //
  // Let: S, # of SMs.
  //
  // Disadvantages
  //
  //   Overhead of checking for an out-of-range column number (due to m_wd).
  //
  //   Data movement when reading A can be >= S m k, which may be too high.
  //   Data movement when reading B can be >= S k n, which may be too high.
  //
  //   (1+1/m_wd) load instructions per FMADD instruction.
  //   Ideal Slowdown for CC 9.0:
  //       ( 4 + 4/m_wd + 1 ) / 1
  //     m_wd -> 8
  //       ( 4 + 4/8 + 1 ) / 1    =  5.5
  //   .. 5.5 times slower than peak FP32 rate.
  //
  // Advantage Over mm_simple2
  //
  //   Fewer loads per FMADD: (1+1/m_wd) ..
  //   .. 5.5 slowdown is not as bad as 9.
  //
  // Advantages Over mm_simple
  //
  //   Better load balance when m/S < 1.
  //
  // Template parameters m_ht and m_dp are not used by this kernel.

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_blocks = gridDim.x;
  const int num_threads = blockDim.x * num_blocks;
  const int wp = tid / wp_sz;

  const int A_ncols = t_A_ncols ?: ld.A_ncols;
  const int A_nrows = t_A_nrows ?: ld.A_nrows;
  const int B_ncols = ld.B_ncols;

  timing_start(wp);

  for ( int i = tid; true; i += num_threads )
    {
      // Consecutive i -> consecutive rows of C; then the next m_wd columns.
      const int C_row = i % A_nrows;
      const int C_col_0 = i / A_nrows * m_wd;
      if ( C_col_0 >= B_ncols ) break;

      elt_t C_accums[m_wd]{};

      for ( int A_col = 0; A_col < A_ncols; A_col++ )
        {
          // Columns past the last column of C get 0 rather than being left
          // uninitialized. They are still multiplied below (the products
          // are discarded), and reading an uninitialized value is
          // undefined behavior.
          elt_t B_elts[m_wd];
          for ( int l_wd = 0;  l_wd < m_wd;  l_wd++ )
            {
              const int C_col = C_col_0 + l_wd;
              B_elts[l_wd] =
                C_col < B_ncols ? ld.BT_dev[ C_col*A_ncols + A_col ] : 0;
            }

          // One load of A is used for m_wd multiply-adds.
          elt_t A_elt = ld.AT_dev[ C_row + A_col * A_nrows ];
          for ( int l_wd = 0;  l_wd < m_wd;  l_wd++ )
            C_accums[l_wd] += A_elt * B_elts[l_wd];
        }

      for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
        if ( const int C_col = C_col_0 + l_wd;
             C_col < B_ncols )
          ld.CT_dev[ C_col * A_nrows + C_row ] = C_accums[l_wd];
    }

  timing_end(wp);
}

template< int t_A_nrows = 0, int t_A_ncols = 0,
          int m_wd = 1, int m_ht = 1, int m_dp = 1>
__global__ void mm_tile_wd()
{
  // Compute C = A * B,
  //
  //   A is m rows by k columns,
  //   B is k rows by n columns.
  //   C must then be m rows by n columns.
  //
  //   A is read from AT_dev (column-major), B from BT_dev (column-major),
  //   and C is written column-major into CT_dev.
  //
  // Each thread computes m_wd elements of C along a row ..
  // .. using one load of an element of A ..
  // .. to update m_wd elements of C.
  //
  // Adjacent threads in a warp compute adjacent rows of C ..
  // .. in the same set of m_wd columns.
  //
  // Optimization.
  //
  //   When all m_wd columns of a thread's tile are in range, ii_iter is
  //   called with a literal false for C_col_check. Presumably this lets
  //   the compiler generate a version of the tile loop without column
  //   checks.
  //
  // Let: S, # of SMs.
  //
  // Disadvantages
  //
  //   Data movement when reading A can be >= S m k, which may be too high.
  //   Data movement when reading B can be >= S k n, which may be too high.
  //
  //   (1+1/m_wd) load instructions per FMADD instruction.
  //   Ideal Slowdown for CC 9.0:
  //       ( 4 + 4/m_wd + 1 ) / 1
  //     m_wd -> 8
  //       ( 4 + 4/8 + 1 ) / 1    =  5.5
  //   .. 5.5 times slower than peak FP32 rate.
  //
  // Advantage Over mm_tile_wd_simple
  //
  //   Fewer run-time checks for an out-of-range column.
  //
  // Advantages Over mm_simple
  //
  //   Fewer loads per FMADD: (1+1/m_wd) ..
  //   .. 5.5 slowdown is not as bad as 9.
  //
  //   Better load balance when m/S < 1.
  //
  // Template parameters m_ht and m_dp are not used by this kernel.

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_blocks = gridDim.x;
  const int num_threads = blockDim.x * num_blocks;
  const int wp = tid / wp_sz;

  const int A_ncols = t_A_ncols ?: ld.A_ncols;
  const int A_nrows = t_A_nrows ?: ld.A_nrows;
  const int B_ncols = ld.B_ncols;

  timing_start(wp);

  // Compute the m_wd elements of row C_row starting at column C_col_0.
  // If C_col_check, columns >= B_ncols use B = 0 and are not stored.
  auto ii_iter = [&](int C_col_0, int C_row, bool C_col_check)
  {
    elt_t C_accums[m_wd]{};

    for ( int A_col = 0; A_col < A_ncols; A_col++ )
      {
        elt_t B_elts[m_wd];
        for ( int l_wd = 0;  l_wd < m_wd;  l_wd++ )
          {
            const int C_col = C_col_0 + l_wd;
            const bool C_col_invalid = C_col_check && C_col >= B_ncols;
            B_elts[l_wd] =
              C_col_invalid ? 0 : ld.BT_dev[ C_col*A_ncols + A_col ];
          }

        elt_t A_elt = ld.AT_dev[ C_row + A_col * A_nrows ];
        for ( int l_wd = 0;  l_wd < m_wd;  l_wd++ )
          C_accums[l_wd] += A_elt * B_elts[l_wd];
      }

    for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
      if ( const int C_col = C_col_0 + l_wd;  !C_col_check || C_col < B_ncols )
        ld.CT_dev[ C_col * A_nrows + C_row ] = C_accums[l_wd];
  };

  for ( int i = tid; true; i += num_threads )
    {
      // Consecutive i -> consecutive rows of C; then the next m_wd columns.
      const int C_row = i % A_nrows;
      const int C_col_0 = i / A_nrows * m_wd;

      if ( C_col_0 >= B_ncols ) break;

      if ( C_col_0 + m_wd <= B_ncols )
        ii_iter(C_col_0,C_row,false);  // Don't check if C_col out of range.
      else
        ii_iter(C_col_0,C_row,true);   // Do check if C_col out of range.
    }

  timing_end(wp);
}

template< int t_A_nrows = 0, int t_A_ncols = 0,
          int m_wd = 1, int m_ht = 1, int m_dp = 1>
__global__ void mm_tile_wd_ht()
{
  // Compute C = A * B,
  //
  //   A is m rows by k columns,
  //   B is k rows by n columns.
  //   C must then be m rows by n columns.
  //
  //   A is read from AT_dev (column-major), B from BT_dev (column-major),
  //   and C is written column-major into CT_dev.
  //
  // Each thread computes m_wd elements of C along a row ..
  // .. and m_ht elements of C along a column ..
  // .. using m_ht loads of elements of A ..
  // .. and   m_wd loads of elements of B ..
  // .. to update m_wd * m_ht elements of C.
  //
  // Adjacent threads in a warp compute adjacent rows of C ..
  // .. in the same set of m_wd columns ..
  // .. with each thread's m_ht rows at a stride of sector_len ..
  // .. (32 when m_ht > 1, otherwise 1).
  //
  // Work distribution: the columns of C are split into groups of m_wd,
  // and each block is given a contiguous range of column groups. Within
  // a block, threads sweep all rows of C for one column group before
  // moving on to the next column group.
  //
  // Let: S, # of SMs.
  //
  // Disadvantages
  //
  //   Overhead of checking for out-of-range row and column numbers
  //   (due to m_ht and m_wd).
  //
  //   Every block reads all of A, so data movement when reading A can be
  //   >= S m k, which may be too high.
  //
  //   Workload imbalance due to each thread computing m_wd m_ht elements of C.
  //
  // Advantage Over mm_tile_wd
  //
  //   (1/m_ht+1/m_wd) load instructions per FMADD instruction.
  //   Ideal Slowdown for CC 9.0:
  //       ( 4/m_ht + 4/m_wd + 1 ) / 1
  //     m_wd -> 8, m_ht -> 8,
  //       ( 4/8 + 4/8 + 1 ) / 1    =  2
  //   .. only 2 times slower than peak FP32 rate.
  //
  // Template parameter m_dp is not used by this kernel.

  // Each thread computes m_ht (rows) by m_wd (columns) tiles of C.
  //
  // A tile is computed in ii_iter.
  //
  //   One execution of ii_iter computes m_wd * m_ht values of C.
  //
  //   The number of iterations in the outer loop of ii_iter (A_col)
  //   is A_ncols.
  //
  // This requires:
  //   Storage for m_wd * m_ht intermediate values of C (in C_accums).
  //
  // NOTE(review): when m_wd > 1 or m_ht > 1, the warp vote functions below
  // use a full-warp mask. This appears to assume blockDim.x is a multiple
  // of 32.

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int wp = tid / wp_sz;
  const int num_blocks = gridDim.x;

  const int A_ncols = t_A_ncols ?: ld.A_ncols;
  const int A_nrows = t_A_nrows ?: ld.A_nrows;
  const int B_ncols = ld.B_ncols;

  // Divide the columns of C into groups of m_wd columns. This block
  // handles column groups [block_C_cg_start, block_C_cg_stop), i.e.,
  // columns up to (but not including) block_C_col_stop.
  const int n_C_col_groups = div_ceil( B_ncols, m_wd );
  const int n_C_cg_per_block = div_ceil( n_C_col_groups, num_blocks );
  const int block_C_cg_start = blockIdx.x * n_C_cg_per_block;
  const int block_C_cg_stop = block_C_cg_start + n_C_cg_per_block;
  const int block_C_col_stop = min( B_ncols, block_C_cg_stop * m_wd );

  // Stride between the m_ht rows handled by one thread.
  constexpr int sector_len = m_ht > 1 ? 32 : 1;
  // Chunk of A: the ch_A_nrows rows covered by sector_len consecutive
  // threads, each handling m_ht rows.
  constexpr int ch_A_nrows = m_ht * sector_len;
  // Number of chunks needed to cover all rows of A (and of C).
  // NOTE(review): because this is constexpr, it appears to compile only
  // when t_A_nrows is nonzero (otherwise A_nrows comes from ld at run time).
  constexpr int n_ch_A = div_ceil( A_nrows, ch_A_nrows );

  timing_start(wp);

  // Compute the m_ht x m_wd tile of C whose first element is at
  // (C_row_0, C_col_0); its rows are sector_len apart. C_col_check
  // enables column tests against block_C_col_stop, and C_row_check
  // enables row tests against A_nrows. Out-of-range inputs are replaced
  // by 0, and out-of-range results are not stored.
  auto ii_iter = [&]( int C_col_0, int C_row_0,
                      bool C_col_check, bool C_row_check)
  {
    elt_t C_accums[m_wd][m_ht]{};

    for ( int A_col = 0; A_col < A_ncols; A_col++ )
      {
        // m_wd elements of row A_col of B; each is used m_ht times.
        elt_t B_elts[m_wd];
        for ( int l_wd = 0;  l_wd < m_wd;  l_wd++ )
          {
            const int C_col = C_col_0 + l_wd;
            const bool C_col_invalid = C_col_check && C_col >= block_C_col_stop;
            B_elts[l_wd] =
              C_col_invalid ? 0 : ld.BT_dev[ C_col*A_ncols + A_col ];
          }

        // m_ht elements of column A_col of A; each is used m_wd times.
        for ( int l_ht = 0; l_ht < m_ht;  l_ht++ )
          {
            const int C_row = C_row_0 + l_ht * sector_len;
            const bool C_row_invalid = C_row_check && C_row >= A_nrows;
            elt_t A_elt =
              C_row_invalid ? 0 : ld.AT_dev[ C_row + A_col * A_nrows ];

          for ( int l_wd = 0;  l_wd < m_wd;  l_wd++ )
            C_accums[l_wd][l_ht] += A_elt * B_elts[l_wd];
          }
      }

    for ( int l_ht = 0; l_ht < m_ht;  l_ht++ )
      for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
        {
          const int C_col = C_col_0 + l_wd;
          const int C_row = C_row_0 + l_ht * sector_len;
          if ( ( !C_col_check || C_col < block_C_col_stop )
               && ( !C_row_check || C_row < A_nrows ) )
            ld.CT_dev[ C_col * A_nrows + C_row ] = C_accums[l_wd][l_ht];
        }
  };

  for ( int i = threadIdx.x; true; i += blockDim.x )
    {
      // Decompose i, from fastest to slowest varying, into: row within a
      // sector (C_row_lo), chunk of A (C_row_hi), and column group (ie,
      // relative to the block's first column group).
      int ie = i;
      const int C_row_lo = ie % sector_len;
      ie /= sector_len;
      const int C_row_hi = ie % n_ch_A;
      ie /= n_ch_A;
      const int C_row_0 = C_row_hi * ch_A_nrows + C_row_lo;

      const int i_cg_0 = block_C_cg_start + ie;
      const int C_col_0 = i_cg_0 * m_wd;

      // no_work: the tile starts past the block's columns or past the
      // last row. need_check: the tile may extend past the block's last
      // column or past the last row of A.
      const bool no_work = C_col_0 >= block_C_col_stop || C_row_0 >= A_nrows;
      const bool need_check = !no_work
        && ( C_col_0 + m_wd > block_C_col_stop
             || C_row_0 + ch_A_nrows > A_nrows );
      const uint32_t wp_mask = 0xffffffff;

      // With m_wd = m_ht = 1, each thread decides on its own and exits
      // individually. Otherwise the decisions are made warp-wide,
      // presumably so that all lanes run the same version of ii_iter.
      const bool d_check = m_wd > 1 || m_ht > 1;
      if ( !d_check && C_col_0 >= block_C_col_stop ) break;

      // Nonzero if any lane (or, without d_check, this thread) needs checks.
      const uint32_t we_check =
        d_check ? __ballot_sync(wp_mask, need_check) : need_check;

      // In the checked call, rows need testing only when A_nrows is not a
      // multiple of ch_A_nrows.
      if ( no_work )
        { /* This part intentionally left blank. */ }
      else if ( we_check == 0 )
        ii_iter(C_col_0,C_row_0,false,false);
      else
        ii_iter(C_col_0,C_row_0,true,A_nrows!=ch_A_nrows*n_ch_A);

      if ( !d_check ) continue;

      // The column group only increases with i, so once every lane is past
      // the block's last column the whole warp is done.
      const bool im_done = C_col_0 >= block_C_col_stop;
      if ( __all_sync(wp_mask,im_done) ) break;
    }

  timing_end(wp);
}




template< int t_A_nrows = 0, int t_A_ncols = 0,
          int m_wd = 1, int m_ht = 1, int m_dp = 1>
__global__ void mm_tile_wd_ht_dp()
{
  // Compute C = A * B,
  //
  //   A is m rows by k columns,
  //   B is k rows by n columns.
  //   C must then be m rows by n columns.
  //
  //   B is read from BT_dev (column-major). A is read from AT_dev
  //   (column-major) when m_dp < 8, and otherwise from A_dev (row-major).
  //   C is written column-major into CT_dev.
  //
  // Each group of m_dp threads computes m_wd elements of C along a row ..
  // .. and m_ht elements of C along a column ..
  // .. using m_ht loads of elements of A ..
  // .. and   m_wd loads of elements of B ..
  // .. to update m_wd * m_ht elements of C ..
  // .. with each thread of the group handling every m_dp'th value of k.
  // .. When the k loop is completed, the m_dp partial sums of C are added
  // .. and then written to C by one of the m_dp threads (sub_lane 0).
  //
  // Consecutive groups of m_dp threads compute consecutive rows of C ..
  // .. (in runs of m_ht_stride groups) in the same set of m_wd columns ..
  // .. with each group's m_ht rows at a stride of m_ht_stride ..
  // .. (1 if m_ht == 1; otherwise 32 if m_dp < 8, 8 if m_dp < 32, else 1).
  //
  // The assignment of rows and columns to threads was chosen to reduce
  // data movement of A and B to the SMs. Each block handles a
  // ch_A_nrows x ch_B_ncols chunk of C per iteration, and the blocks are
  // arranged as a C_row_blk_mod x C_col_blk_mod grid of chunks.
  //
  // NOTE(review): the shuffle reduction appears to assume that m_dp is a
  // power of 2 no larger than 32.
  //
  // Let: S, # of SMs.
  //
  // Disadvantage
  //
  //   Looks overdone.
  //
  // Advantages Over mm_tile_wd_ht
  //
  //   Reduced workload imbalance, because m_dp times as many threads can
  //   be kept busy as in mm_tile_wd_ht.
  //
  //   Reduced data movement.

  // Each group of m_dp threads computes m_ht (rows) by m_wd (columns)
  // tiles of C.
  //
  // A tile is computed in ii_iter.
  //
  //   In one execution of ii_iter ..
  //   .. each thread computes m_wd * m_ht *PARTIAL* values of C ..
  //   .. and then the partial values of the m_dp threads are added together ..
  //   .. to get the complete values of C.
  //
  //   The number of iterations in the outer loop of ii_iter (A_col)
  //   is A_ncols / m_dp, rounded up.
  //
  // This requires:
  //   Storage for m_pa * m_wd * m_ht intermediate values of C (in C_accums).
  //   (m_pa > 1 only if m_dp = m_wd = m_ht = 1.)
  //
  // Benefits:
  //   Each load of a value of A is reused m_wd times.
  //   Each load of a value of B is reused m_ht times.
  //   Better load balance.

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int lane = threadIdx.x % wp_sz;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int wp = tid / wp_sz;
  const int num_blocks = gridDim.x;

  const int A_ncols = t_A_ncols ?: ld.A_ncols;
  const int A_nrows = t_A_nrows ?: ld.A_nrows;
  const int B_ncols = ld.B_ncols;

  //  In this routine m_dp threads compute m_wd columns and m_ht rows
  //  of the output matrix per iteration.

  // Position of this thread within its group of m_dp threads. It handles
  // A columns (k values) A_col_o + sub_lane.
  const int sub_lane = lane % m_dp;
  // With m_dp >= 8, the k values read by a group's threads are adjacent
  // in row-major A_dev, so A_dev is used instead of AT_dev.
  const bool use_AT = m_dp < 8;

  // Stride between the m_ht rows of C handled by one group.
  const int m_ht_stride_lg =
    m_ht == 1 ? 0 :
    use_AT ? 5 :
    m_dp < 32 ? 3 : 0;
  const int m_ht_stride = 1 << m_ht_stride_lg;

  // Chunk sizes: per iteration a block computes ch_C_nelts elements of C,
  // arranged as ch_A_nrows rows by ch_B_ncols columns.
  //
  // Minimum chunk height: one full set of a group's strided rows, or 32
  // rows when m_ht == 1.
  const uint ch_A_nrows_min = m_ht > 1 ? m_ht_stride * m_ht : 32;
  const int ch_A_nrows_min_lg = bit_width(ch_A_nrows_min-1);

  const uint ch_C_nelts = blockDim.x * m_wd * m_ht / m_dp;
  const int ch_C_nelts_lg = bit_width(ch_C_nelts-1);
  assert( 1 << ch_C_nelts_lg == ch_C_nelts );  // Must be a power of 2.

  // Choose a roughly square chunk (height about sqrt(ch_C_nelts)). It is
  // at least ch_A_nrows_min rows tall, while leaving room for at least
  // m_wd columns.
  constexpr int m_wd_lg = bit_width( unsigned( m_wd ) - 1 );
  const int ch_A_nrows_lg =
    min( ch_C_nelts_lg - m_wd_lg,
         max( ch_A_nrows_min_lg, ch_C_nelts_lg / 2) );
  const int ch_A_nrows = 1 << ch_A_nrows_lg;
  // Note: - binds tighter than <<, so this is
  // 1 << ( ch_C_nelts_lg - ch_A_nrows_lg ).
  const int ch_B_ncols = 1 << ch_C_nelts_lg - ch_A_nrows_lg;
  assert( ch_A_nrows >= m_ht );
  assert( ch_B_ncols >= m_wd );
  assert( ch_A_nrows * ch_B_ncols == ch_C_nelts );
  // Number of thread groups spanning the chunk's rows and its columns.
  const int ch_A_nthds = ch_A_nrows / m_ht;
  assert( ch_A_nthds * m_ht == ch_A_nrows );
  const int ch_B_nthds = ch_B_ncols / m_wd;
  assert( ch_A_nthds * ch_B_nthds * m_dp == blockDim.x );
  assert( ch_B_nthds * m_wd == ch_B_ncols );
  // Number of whole chunks along the rows and along the columns of C
  // (at least 1).
  const int n_ch_A_floor = max( 1, A_nrows / ch_A_nrows );
  const int n_ch_B_floor = max( 1, B_ncols / ch_B_ncols );

  // Arrange the blocks as a C_row_blk_mod x C_col_blk_mod grid over the
  // chunks of C:
  //   - Try factors i = 1 .. round(sqrt(num_blocks)), with
  //     fac2 = num_blocks / i.
  //   - Keep the pair that keeps the most blocks busy (util, capped by the
  //     number of whole chunks in each direction); ties go to the larger i.
  //   - The factor i goes to the direction with fewer whole chunks.
  const int num_blocks_sqrt = sqrt(num_blocks) + 0.5f;
  int util_max = 0;
  int util_max_fac1 = 0, util_max_fac2 = 0;
  const bool n_ch_A_smaller = n_ch_A_floor <= n_ch_B_floor;
  const int n_ch_X_min = min( n_ch_A_floor, n_ch_B_floor );
  const int n_ch_X_max = max( n_ch_A_floor, n_ch_B_floor );
  for ( int i = 1;  i <= num_blocks_sqrt; i++ )
    {
      const int fac2 = num_blocks / i;
      int util = min(i,n_ch_X_min) * min(fac2,n_ch_X_max);
      if ( util < util_max ) continue;
      util_max = util;
      util_max_fac1 = i;
      util_max_fac2 = fac2;
    }
  const int C_row_blk_mod = n_ch_A_smaller ? util_max_fac1 : util_max_fac2;
  const int C_col_blk_mod = n_ch_A_smaller ? util_max_fac2 : util_max_fac1;

  timing_start(wp);
  // Number of independent accumulator sets ("planes"). Four are used only
  // when m_dp == 1 and m_wd * m_ht == 1. Presumably this shortens the
  // chain of dependent additions into a single accumulator.
  constexpr int m_pa = m_dp > 1 ? 1 : m_ht * m_wd == 1 ? 4 : 1;

  // Compute this thread's partial sums for the m_ht x m_wd tile of C whose
  // first element is at (C_row_0, C_col_0). Then sum the partial sums of
  // the m_dp threads in the group, and have sub_lane 0 store the result.
  // C_col_check and C_row_check enable tests against B_ncols and A_nrows.
  auto ii_iter = [&]( int C_col_0, int C_row_0,
                      bool C_col_check, bool C_row_check)
  {
    elt_t C_accums[m_pa][m_wd][m_ht]{};
    // NOTE(review): the shuffle mask is whichever lanes happen to be active
    // here. Lanes whose tile is no_work skip ii_iter, and no_work is the
    // same for every thread of a group, so this presumably covers each
    // participating group. However, __activemask() does not guarantee
    // convergence; an explicitly computed mask would be more robust.
    const uint32_t am = m_dp > 1 ? __activemask() : 0;

    for ( int A_col_o = 0; A_col_o < A_ncols; A_col_o += m_dp )
      {
        // Rotate through the accumulator planes (plane 0 when checking columns).
        const int plane = C_col_check ? 0 : A_col_o % m_pa;
        const int A_col = A_col_o + sub_lane;
        // Stop once this thread's k passes A_ncols (can happen only when
        // m_dp does not divide A_ncols).
        if ( A_ncols/m_dp*m_dp != A_ncols && A_col >= A_ncols ) break;
        elt_t B_elts[m_wd];
        for ( int l_wd = 0;  l_wd < m_wd;  l_wd++ )
          {
            const int C_col = C_col_0 + l_wd;
            const bool C_col_invalid = C_col_check && C_col >= B_ncols;
            B_elts[l_wd] =
              C_col_invalid ? 0 : ld.BT_dev[ C_col*A_ncols + A_col ];
          }

        for ( int l_ht = 0; l_ht < m_ht;  l_ht++ )
          {
            const int C_row = C_row_0 + l_ht * m_ht_stride;
            const bool C_row_invalid = C_row_check && C_row >= A_nrows;
            elt_t A_elt =
              C_row_invalid ? 0 :
              use_AT ? ld.AT_dev[ C_row + A_col * A_nrows ]
              :        ld.A_dev[  A_col + C_row * A_ncols ];

            for ( int l_wd = 0;  l_wd < m_wd;  l_wd++ )
              C_accums[plane][l_wd][l_ht] += A_elt * B_elts[l_wd];
          }
      }

    for ( int l_ht = 0; l_ht < m_ht;  l_ht++ )
      for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
        {
          // Sum the accumulator planes, then the group's m_dp partial sums
          // (tree reduction with shuffles; sub_lane 0 ends with the total).
          elt_t C_accum = C_accums[0][l_wd][l_ht];
          for ( int l_pa = 1; l_pa < m_pa; l_pa++ )
            C_accum += C_accums[l_pa][l_wd][l_ht];
          for ( int dist = 1; dist < m_dp; dist *= 2 )
            C_accum += __shfl_down_sync( am, C_accum, dist, m_dp);

          const int C_col = C_col_0 + l_wd;
          const int C_row = C_row_0 + l_ht * m_ht_stride;
          if ( sub_lane == 0
               && ( !C_col_check || C_col < B_ncols )
               && ( !C_row_check || C_row < A_nrows ) )
            ld.CT_dev[ C_col * A_nrows + C_row ] = C_accum;
        }
  };

  // First tile of this thread's group. Groups are numbered tid_um within
  // the block; the low part of tid_um picks the row within the chunk, and
  // the high part picks the column group within the chunk. blockIdx.x
  // picks the chunk position (C_row_mid, C_col_mid) in the block grid.
  const int tid_um = threadIdx.x / m_dp;
  const int tid_um_lo = tid_um % ch_A_nthds;
  const int C_row_lo = tid_um_lo % m_ht_stride;
  const int C_row_lmid = ( tid_um_lo - C_row_lo ) * m_ht;
  const int C_row_mid = blockIdx.x % C_row_blk_mod;
  const int C_row_thd = C_row_lo + C_row_lmid + C_row_mid * ch_A_nrows;

  const int C_col_lo = tid_um / ch_A_nthds;
  const int C_col_mid = blockIdx.x / C_row_blk_mod;
  const int C_col_thd = C_col_lo * m_wd + C_col_mid * ch_B_ncols;
  // Blocks outside the C_row_blk_mod x C_col_blk_mod grid have no work.
  if ( C_col_mid >= C_col_blk_mod ) { timing_end(wp); return; }
  assert( C_row_thd < C_row_blk_mod * ch_A_nrows );
  assert( C_col_thd < C_col_blk_mod * ch_B_ncols );

  const uint32_t wp_mask = 0xffffffff;

  // When m_ht > 1, A_nrows is rounded up to a multiple of
  // m_ht_stride * m_ht. Each group's strided row pattern then lines up
  // when rows wrap around to the next set of columns.
  const int A_nrows_r =
    m_ht > 1
    ? div_ceil( A_nrows, m_ht_stride * m_ht ) * m_ht_stride * m_ht : A_nrows;

  for ( int i = 0; true; i++ )
    {
      // Advance down the rows by the height of the whole block grid.
      // When rows wrap past A_nrows_r, move right by the width of the grid.
      const int C_row_0_raw = C_row_thd + i * C_row_blk_mod * ch_A_nrows;
      const int C_row_0 = C_row_0_raw % A_nrows_r;
      const int iC_col = C_row_0_raw / A_nrows_r;

      const int C_col_0 = C_col_thd + iC_col * C_col_blk_mod * ch_B_ncols;
      const bool thd_no_C_col = C_col_0 >= B_ncols;

      // no_work: the tile starts past the last column or row of C.
      // need_check: the tile may extend past the last column or row.
      const bool no_work = C_col_0 >= B_ncols || C_row_0 >= A_nrows;
      const bool need_check = !no_work
        && ( C_col_0 + m_wd > B_ncols
             || C_row_0 + ( m_ht - 1 ) * m_ht_stride >= A_nrows );

      // With m_wd = m_ht = 1, each thread decides and exits on its own.
      // Otherwise the decisions are made warp-wide.
      const bool d_check = true && ( m_wd > 1 || m_ht > 1 );
      if ( !d_check && C_col_0 >= B_ncols ) break;

      // Nonzero if any lane (or, without d_check, this thread) needs checks.
      const uint32_t we_check =
        d_check ? __ballot_sync(wp_mask, need_check) : need_check;

      // The warp is done when every lane is past the last column of C.
      if ( d_check && !we_check && __all_sync(wp_mask,thd_no_C_col) ) break;

      if ( no_work )
        { /* This part intentionally left blank. */ }
      else if ( we_check == 0 )
        ii_iter(C_col_0,C_row_0,false,false);
      else
        ii_iter(C_col_0,C_row_0,true,true);

    }

  timing_end(wp);
}


__global__ void
convert_float_tf32()
{
  // Make the reduced-precision copies of the input matrices used by the
  // tensor-core kernel (mm_tc):
  //
  //   A, AT, B, BT (float) -> A_tf32, AT_tf32, B_tf32, BT_tf32
  //     (values rounded to tf32, stored as float, via __float_to_tf32)
  //   A, AT, B, BT (float) -> A_fp16, AT_fp16, B_fp16, BT_fp16 (half)
  //
  // Thread tid converts elements tid, tid + num_threads, ..., so any
  // launch configuration covers all elements.
  //
  const Shape_POD& ld = shape_dev;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;

  const int A_ncols = ld.A_ncols;
  const int A_nrows = ld.A_nrows;
  const int B_ncols = ld.B_ncols;
  // A and AT hold the same number of elements, as do B and BT.
  const int A_nelts = A_nrows * A_ncols;
  const int B_nelts = A_ncols * B_ncols;

  using namespace nvcuda::wmma;  // For __float_to_tf32.

  for ( int i = tid; i < A_nelts; i += num_threads )
    {
      ld.A_tf32_dev[i] = __float_to_tf32( ld.A_dev[i] );
      ld.AT_tf32_dev[i] = __float_to_tf32( ld.AT_dev[i] );
      ld.A_fp16_dev[i] = __float2half( ld.A_dev[i] );
      ld.AT_fp16_dev[i] = __float2half( ld.AT_dev[i] );
    }

  // NOTE(review): the loop below does not read anything the loop above
  // writes (assuming the A and B buffers are distinct allocations), so
  // this __syncwarp() does not appear to be required.
  __syncwarp();

  for ( int i = tid; i < B_nelts; i += num_threads )
    {
      ld.B_tf32_dev[i] = __float_to_tf32( ld.B_dev[i] );
      ld.BT_tf32_dev[i] = __float_to_tf32( ld.BT_dev[i] );
      ld.B_fp16_dev[i] = __float2half( ld.B_dev[i] );
      ld.BT_fp16_dev[i] = __float2half( ld.BT_dev[i] );
    }
}

// Makes the wmma tf32 precision tag available unqualified, presumably for
// use as the AB_precision argument of mm_tc.
using wmma::precision::tf32;

template< int t_A_nrows = 0, int t_A_ncols = 0, int t_B_ncols = 0,
          int m_wd_n = 1, int m_ht_n = 1, typename AB_precision = half >
__global__ void
mm_tc()
{
  // Compute C = A * B,
  //
  //   A is m rows by k columns,
  //   B is k rows by n columns.
  //   C must then be m rows by n columns.
  //
  //   C is written column-major (transposed) into CT_dev.
  //
  // Each warp computes a tm-row by ( m_wd_n * tn )-column block of C ..
  // .. using tensor cores ..
  // .. elements of A and B are of type (precision) AB_precision (see template).
  //
  // Template parameters
  //
  //   t_A_nrows, t_A_ncols, t_B_ncols: Matrix dimensions. They must match
  //     shape_dev and be multiples of tm, tk, and tn (checked with assert).
  //   m_wd_n: Number of tn-column tiles of C computed by each warp.
  //   m_ht_n: Not used by this kernel.
  //   AB_precision: half, or wmma::precision::tf32.
  //
  // Adjacent warps (consecutive by warp number) access adjacent blocks
  // along a column.
  //
  // When B_ncols is not a multiple of m_wd_n * tn, warps assigned to the
  // last columns shift their block left so that it ends at column B_ncols.
  // Some columns are then computed and stored twice.
  //
  // Advantages Over mm_tile_wd_ht_dp
  //
  //   Much faster computation due to higher FP throughput of tensor
  //   core instructions.
  //
  // Disadvantages
  //
  //   CC 9.0 does not support float type for elements.
  //
  //   Bank conflicts aggravate load times.
  //
  //   Can improve load efficiency.
  //
  //   Can reduce amount of data read.

  // Multiply Matrices Using Tensor Cores
  //
  // This kernel is written for a variety of matrix shapes and element
  // precisions, which might make the code more difficult to read. To
  // see a simplified version of this kernel, look at the
  // mm_tc< 2048, 2064, 4096, 1, 1, half > specialization that
  // follows this kernel.

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int wp = tid / wp_sz;
  const int n_warps = blockDim.x * gridDim.x / wp_sz;

  constexpr int A_ncols = t_A_ncols;
  constexpr int A_nrows = t_A_nrows;
  constexpr int B_nrows = A_ncols;
  constexpr int B_ncols = t_B_ncols;
  constexpr int C_nrows = A_nrows;

  assert( A_nrows == ld.A_nrows );
  assert( B_ncols == ld.B_ncols );

  using namespace nvcuda::wmma;

  // Extract data type used for matrix storage, and the data type used
  // to declare matrix elements in this routine.
  //
  using ab_sto_t = helper_traits<AB_precision>::storage_element_type;
  using ab_elt_t = helper_traits<AB_precision>::element_type;

  constexpr bool h = is_same_v<ab_sto_t,half>;

  constexpr int tm = 16;  // Number of A and C rows in tile.
  constexpr int tn = 16;  // Number of B and C columns in tile.
  constexpr int tk = h ? 16 : 8; // Number of A columns and B rows in tile.

  assert( A_nrows % tm == 0 );
  assert( B_ncols % tn == 0 );
  assert( A_ncols % tk == 0 );

  timing_start(wp);

  constexpr int C_nrow_tiles = C_nrows / tm;
  // Width, in columns of C, of the block computed by one warp: m_wd_n
  // tiles, each tn columns wide. (Tiles are stepped by tn below.)
  constexpr int m_wd = tn * m_wd_n;
  // The last block is shifted to end at B_ncols (see partial_cols), which
  // requires at least one full block of columns. Otherwise C_col_0 would
  // be negative.
  assert( B_ncols >= m_wd );

  using aorg = row_major;
  using borg = col_major;
  constexpr bool a_row_major = is_same_v<aorg,row_major>;
  constexpr bool b_row_major = is_same_v<borg,row_major>;

  typedef size_t idx_t;

  const idx_t a_row_stride = a_row_major ? A_ncols : 1;
  const idx_t a_col_stride = a_row_major ? 1       : A_nrows;
  const idx_t b_row_stride = b_row_major ? B_ncols : 1;
  const idx_t b_col_stride = b_row_major ? 1       : B_nrows;
  const idx_t a_stride = max(a_row_stride,a_col_stride);
  const idx_t b_stride = max(b_row_stride,b_col_stride);

  // Pick the input copy matching the chosen layout and precision.
  ab_sto_t* const a_ptr =
    a_row_major
    ? ( h ? (ab_sto_t*)ld.A_fp16_dev  : (ab_sto_t*)ld.A_tf32_dev  )
    : ( h ? (ab_sto_t*)ld.AT_fp16_dev : (ab_sto_t*)ld.AT_tf32_dev );

  ab_sto_t* const b_ptr =
    b_row_major
    ? ( h ? (ab_sto_t*)ld.B_fp16_dev  : (ab_sto_t*)ld.B_tf32_dev  )
    : ( h ? (ab_sto_t*)ld.BT_fp16_dev : (ab_sto_t*)ld.BT_tf32_dev );

  for ( int i = wp; true; i += n_warps )
    {
      // Consecutive warps take consecutive row tiles of the same columns.
      const ssize_t C_row_0 = i % C_nrow_tiles * tm;
      const int C_col_0_raw = i / C_nrow_tiles * m_wd;
      if ( C_col_0_raw >= B_ncols ) break;
      const bool partial_cols = C_col_0_raw + m_wd > B_ncols;
      const ssize_t C_col_0 = partial_cols ? B_ncols - m_wd : C_col_0_raw;

      // Declare variables used to hold tiles.
      //
      fragment<matrix_a,    tm, tn, tk, ab_elt_t, aorg> tile_a;
      fragment<matrix_b,    tm, tn, tk, ab_elt_t, borg> tile_b;
      fragment<accumulator, tm, tn, tk, float> tile_C_acc[m_wd_n];

      for ( auto& tca: tile_C_acc ) fill_fragment(tca, 0);

      for ( ssize_t i_k = 0; i_k < A_ncols; i_k += tk )
        {
          // One A tile is reused for all m_wd_n B tiles.
          load_matrix_sync
            ( tile_a,  a_ptr + C_row_0 * a_row_stride + i_k * a_col_stride,
              a_stride );

          for ( ssize_t i_wd = 0; i_wd < m_wd_n; i_wd++ )
            {
              load_matrix_sync
                ( tile_b,  b_ptr + i_k * b_row_stride
                  + ( C_col_0 + i_wd * tn ) * b_col_stride,
                  b_stride );

              mma_sync( tile_C_acc[i_wd], tile_a, tile_b, tile_C_acc[i_wd] );
            }
        }
      for ( int i_wd = 0; i_wd < m_wd_n; i_wd++ )
        store_matrix_sync
          ( &ld.CT_dev[ C_row_0 + ( C_col_0 + i_wd * tn ) * A_nrows ],
            tile_C_acc[i_wd], A_nrows, mem_col_major );
    }

  timing_end(wp);
}

template<>
__global__ void
mm_tc< 2048, 2064, 4096, 1, 1, half >()
{
  // This specialization is easier to read than general template above.
  //
  // Computes C = A * B with fp16 A and B and fp32 accumulation. Each
  // warp computes one tm x tn tile of C per iteration of the i loop and
  // writes it to CT_dev in column-major order (leading dimension A_nrows).

  // These constants must agree with the template arguments above,
  // otherwise the leading dimensions used for loads would be wrong and
  // part of the k dimension would be skipped.
  constexpr int A_nrows = 2048, A_ncols = 2064, B_ncols = 4096;
  constexpr int C_nrows = A_nrows, C_ncols = B_ncols, B_nrows = A_ncols;

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int wp = tid / wp_sz;
  const int n_warps = blockDim.x * gridDim.x / wp_sz;

  // The run-time shape must be the one this specialization was written for.
  assert( A_nrows == ld.A_nrows );
  assert( A_ncols == ld.A_ncols );
  assert( B_ncols == ld.B_ncols );

  // Shape (dimensions) of tile computed by this warp by tensor core.
  //
  constexpr int tm = 16, tk = 16, tn = 16;
  //
  // The tensor cores can only compute specific shapes. The shapes
  // that are available depend on the data type (which is half [fp16])
  // in this example.

  static_assert( A_ncols % tk == 0 );
  static_assert( A_nrows % tm == 0 );
  static_assert( B_ncols % tn == 0 );

  timing_start(wp);

  using namespace nvcuda::wmma;

  constexpr int C_nrow_tiles = C_nrows / tm;

  // Convenient pointers to the input arrays.
  //
  half* const a_fp16 = &ld.A_fp16_dev[ 0 ];
  half* const bt_fp16 = &ld.BT_fp16_dev[ 0 ];
  //
  // Note: "bt" is a column major (transposed) version of the b
  // matrix.

  constexpr int n_tiles = C_nrows / tm * C_ncols / tn;

  for ( int i = wp;  i < n_tiles;  i += n_warps )
    {
      // Tiles are numbered down each column of C first.
      const ssize_t C_row_0 = i % C_nrow_tiles * tm;
      const ssize_t C_col_0 = i / C_nrow_tiles * tn;

      fragment<matrix_a, tm, tn, tk, half, row_major> tile_a;
      fragment<matrix_b, tm, tn, tk, half, col_major> tile_b;
      fragment<accumulator, tm, tn, tk, float> tile_C_acc;

      fill_fragment(tile_C_acc, 0);

      for ( ssize_t i_k = 0; i_k < A_ncols; i_k += tk )
        {
          load_matrix_sync
            ( tile_a,  a_fp16 + C_row_0 * A_ncols + i_k,  A_ncols );

          load_matrix_sync
            ( tile_b,  bt_fp16  + C_col_0 * B_nrows + i_k,  B_nrows );

          mma_sync( tile_C_acc, tile_a, tile_b, tile_C_acc );
        }

      store_matrix_sync
        ( &ld.CT_dev[ C_row_0 + C_col_0 * A_nrows ],
          tile_C_acc, A_nrows, mem_col_major );
    }

  timing_end(wp);
}


template< int t_A_nrows = 0, int t_A_ncols = 0, int t_B_ncols = 0,
          int m_wd_n = 1, int m_ht_n = 1, typename AB_precision = half >
__global__
// Launch bounds improve register use in larger kernels.
__launch_bounds__( m_wd_n*m_ht_n >= 16 ? 256 : m_wd_n*m_ht_n >= 4 ? 512 : 0 )
void
mm_hw05()
{
  /// SOLUTION in this routine.
  //
  // Compute C = A * B,
  //
  //   A is m rows by k columns,
  //   B is k rows by n columns.
  //   C must then be m rows by n columns.
  //
  // Each warp computes a m_ht rows by m_wd column block of C ..
  // .. using tensor cores ..
  // .. elements of A and B are of type (precision) AB_precision (see template).
  //
  // HW05 Solution -- Tiles assigned to warps within a block form
  // an approximate square over output matrix c.
  //
  // Preconditions (checked below): the number of rows of C is a
  // multiple of the warp tile height (tm * m_ht_n), the number of
  // columns of C is a multiple of the warp tile width (tn * m_wd_n),
  // and A_ncols is a multiple of tk.
  //

  const Shape_POD& ld = shape_dev;
  constexpr int wp_sz = 32;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int wp = tid / wp_sz;

  constexpr int A_ncols = t_A_ncols;
  constexpr int A_nrows = t_A_nrows;
  constexpr int B_nrows = A_ncols;
  constexpr int B_ncols = t_B_ncols;
  constexpr int C_nrows = A_nrows;
  constexpr int C_ncols = B_ncols;

  assert( A_nrows == ld.A_nrows );
  assert( B_ncols == ld.B_ncols );

  using namespace nvcuda::wmma;

  // Extract data type used for matrix storage, and the data type used
  // to declare matrix elements in this routine.
  //
  using ab_sto_t = helper_traits<AB_precision>::storage_element_type;
  using ab_elt_t = helper_traits<AB_precision>::element_type;

  constexpr bool h = is_same_v<ab_sto_t,half>;

  constexpr int tm = 16;  // Number of A and C rows in tile.
  constexpr int tn = 16;  // Number of B and C columns in tile.
  constexpr int tk = h ? 16 : 8; // Number of A columns and B rows in tile.

  assert( A_nrows % tm == 0 );
  assert( B_ncols % tn == 0 );
  assert( A_ncols % tk == 0 );

  timing_start(wp);

  using aorg = row_major;
  using borg = col_major;
  constexpr bool a_row_major = is_same_v<aorg,row_major>;
  constexpr bool b_row_major = is_same_v<borg,row_major>;

  ab_sto_t* const a_ptr =
    a_row_major
    ? ( h ? (ab_sto_t*)ld.A_fp16_dev  : (ab_sto_t*)ld.A_tf32_dev  )
    : ( h ? (ab_sto_t*)ld.AT_fp16_dev : (ab_sto_t*)ld.AT_tf32_dev );

  ab_sto_t* const b_ptr =
    b_row_major
    ? ( h ? (ab_sto_t*)ld.B_fp16_dev  : (ab_sto_t*)ld.B_tf32_dev  )
    : ( h ? (ab_sto_t*)ld.BT_fp16_dev : (ab_sto_t*)ld.BT_tf32_dev );

  typedef size_t idx_t;

  const idx_t a_row_stride = a_row_major ? A_ncols : 1;
  const idx_t a_col_stride = a_row_major ? 1       : A_nrows;
  const idx_t b_row_stride = b_row_major ? B_ncols : 1;
  const idx_t b_col_stride = b_row_major ? 1       : B_nrows;
  const idx_t a_stride = max(a_row_stride,a_col_stride);
  const idx_t b_stride = max(b_row_stride,b_col_stride);

  /// Homework 5 -- Should not need to modify code above.
  //
  //  Most of the solution goes below.
  //  Declarations above do not have to be changed, but feel free
  //  to experiment.

  /// SOLUTION -- Problems 1 and 2 in this kernel.

  constexpr int m_wd = tn * m_wd_n;
  constexpr int m_ht = tm * m_ht_n;

  static_assert( m_wd % tn == 0 );
  static_assert( m_ht % tm == 0 );

  constexpr uint warp_tile_nrows = tm * m_ht_n;
  constexpr uint warp_tile_ncols = tn * m_wd_n;
  constexpr int warp_tile_nrows_bits = bit_width( warp_tile_nrows - 1 );
  constexpr int warp_tile_ncols_bits = bit_width( warp_tile_ncols - 1 );

  /// SOLUTION -- Problem 2
  //
  // Assign warps in the block to form a square (or close to it) on
  // matrix C.
  //
  // Terminology:
  //
  //   Warp Tile: the part of c that a warp operates on in one
  //   iteration of the i loop below. A warp tile has warp_tile_nrows
  //   (tm * m_ht_n) rows and warp_tile_ncols (tn * m_wd_n) columns.
  //   These values were computed above.
  //
  //   Block Tile: the part of c that an entire block operates on in
  //   one iteration of the i loop below. A block tile has
  //   block_tile_n_wp_rows * warp_tile_nrows rows and
  //   block_tile_n_wp_cols * warp_tile_ncols columns. Those two
  //   warp counts are computed below.
  //
  // Solution takes advantage of the fact that number of warps per
  // block is a power of 2.
  //
  // The computation of block_tile_n_wp_rows and block_tile_n_wp_cols
  // shown below takes into account the fact that the portion of C
  // covered by a warp may be wide ( m_wd > m_ht ), narrow
  // ( m_wd < m_ht ), or square ( m_wd == m_ht ).
  //
  const uint n_wps_p_bl = blockDim.x / wp_sz;
  const int n_wps_p_bl_bits = bit_width( n_wps_p_bl - 1 );

  const int block_tile_n_wp_rows_bits =
    max( ( n_wps_p_bl_bits >> 1 )
         - warp_tile_nrows_bits + warp_tile_ncols_bits, 0 );

  // Number of columns measured in units of warps.
  const int block_tile_n_wp_cols =
    max( 1, n_wps_p_bl >> block_tile_n_wp_rows_bits );
  const int block_tile_n_wp_rows = n_wps_p_bl / block_tile_n_wp_cols;
  assert( n_wps_p_bl == block_tile_n_wp_cols * block_tile_n_wp_rows );

  const int block_tile_nrows = warp_tile_nrows * block_tile_n_wp_rows;
  const int block_tile_ncols = warp_tile_ncols * block_tile_n_wp_cols;

  const int blk_wp_idx = threadIdx.x / wp_sz;
  const int C_row_0_offset =
    ( blk_wp_idx / block_tile_n_wp_cols ) * warp_tile_nrows;
  const int C_col_0_offset =
    ( blk_wp_idx % block_tile_n_wp_cols ) * warp_tile_ncols;
  //
  // C_row_0_offset, C_col_0_offset is the position of the warp
  // tile relative to the block tile.

  const int n_block_tiles_col = div_ceil( C_ncols, block_tile_ncols );

  // A warp tile must never extend past the last column or the last row
  // of C; otherwise the loads of A and B and the stores to C below would
  // go out of bounds. Both are checked at compile time.
  constexpr bool partial_cols_possible = C_ncols % warp_tile_ncols;
  static_assert( !partial_cols_possible );
  constexpr bool partial_rows_possible = C_nrows % warp_tile_nrows;
  static_assert( !partial_rows_possible );

  for ( int i = blockIdx.x; true; i += gridDim.x )
    {
      const int C_col_0_block = i % n_block_tiles_col * block_tile_ncols;
      const int C_row_0_block = i / n_block_tiles_col * block_tile_nrows;
      //
      // C_row_0_block, C_col_0_block is the upper left of the block
      // tile.

      const int C_col_0_raw = C_col_0_block + C_col_0_offset;
      const int C_row_0_raw = C_row_0_block + C_row_0_offset;

      // Test rows before columns. C_row_0_raw never decreases as i
      // grows, so once it passes the end of C this warp is done.
      // Testing columns first could loop forever. For example, if
      // n_block_tiles_col divides gridDim.x, the column block is the
      // same on every iteration. A warp whose column offset lies past
      // the last column of C (possible when block_tile_ncols does not
      // divide C_ncols) would then take the continue every time.
      if ( C_row_0_raw >= C_nrows ) break;
      if ( C_col_0_raw >= C_ncols ) continue;

      // These modulo operations are purely for the compiler's sake.
      const int C_col_0 = C_col_0_raw % C_ncols;
      const int C_row_0 = C_row_0_raw % C_nrows;

      // Declare 2D array holding output tiles ..
      //
      fragment<accumulator, tm, tn, tk, float> tile_C_acc[m_ht_n][m_wd_n];
      //
      // .. and initialize to 0.
      //
      for ( auto& tca: tile_C_acc )
        for ( auto& tcb: tca ) fill_fragment(tcb, 0);

      for ( ssize_t i_k = 0; i_k < A_ncols; i_k += tk )
        {
          /// SOLUTION -- Problem 1
          //
          // First, load m_ht_n tiles of a into tile_a.
          //
          fragment<matrix_a, tm, tn, tk, ab_elt_t, aorg> tile_a[m_ht_n];
          //
          for ( size_t i_ht = 0; i_ht < m_ht_n; i_ht++ )
            load_matrix_sync
              ( tile_a[i_ht],
                a_ptr + ( C_row_0 + i_ht * tm ) * a_row_stride
                + i_k * a_col_stride,
                a_stride );

          // Next, multiply each tile_b tile by the m_ht_n tiles of a.
          //
          for ( size_t i_wd = 0; i_wd < m_wd_n; i_wd++ )
            {
              // Load one tile of b ..
              //
              fragment<matrix_b, tm, tn, tk, ab_elt_t, borg> tile_b;
              //
              load_matrix_sync
                ( tile_b,  b_ptr + i_k * b_row_stride
                  + ( C_col_0 + i_wd * tn ) * b_col_stride,
                  b_stride );
              //
              // .. and use it to compute m_ht_n tiles of c.
              //
              for ( int i_ht = 0; i_ht < m_ht_n; i_ht++ )
                mma_sync( tile_C_acc[i_ht][i_wd], tile_a[i_ht], tile_b,
                          tile_C_acc[i_ht][i_wd] );
            }
        }

      // Write completed tiles of c to memory (column-major CT_dev).
      //
      for ( int i_wd = 0; i_wd < m_wd_n; i_wd++ )
        for ( ssize_t i_ht = 0; i_ht < m_ht_n; i_ht++ )
          {
            // Form the element offset in idx_t so that the
            // column * A_nrows product cannot overflow int for large C.
            const idx_t C_idx =
              idx_t( C_row_0 + i_ht * tm )
              + idx_t( C_col_0 + i_wd * tn ) * A_nrows;
            store_matrix_sync
              ( &ld.CT_dev[ C_idx ],
                tile_C_acc[i_ht][i_wd], A_nrows, mem_col_major );
          }
    }

  timing_end(wp);
}


GPU_Info
print_gpu_and_kernel_info()
{
  // Holds the properties of the selected device; returned to the caller.
  // Despite this function's name, no kernel information is printed here.
  GPU_Info gpu_info;

  // List the GPUs visible to this process.
  gpu_info_print();

  // Let gpu_choose_index() pick the device, make it current, and
  // report the choice.
  const int chosen_dev = gpu_choose_index();
  CE( cudaSetDevice( chosen_dev ) );
  printf("Using GPU %d\n",chosen_dev);

  gpu_info.get_gpu_info( chosen_dev );
  return gpu_info;
}

// Format of the A and B operands used by a kernel. AB_ENUM_SIZE is a
// count sentinel, not a format.
enum AB_Format { AB_elt_t, AB_tf32, AB_bf16, AB_fp16, AB_ENUM_SIZE };
// Printable name of each AB_Format, indexed by the enum value.
const char* ab_format_txt[] = { "FP32", "TF32", "BF16", "FP16" };
// Keep the name table in step with the enum; a missing entry would be
// read out of bounds when indexed by the last format.
static_assert( sizeof(ab_format_txt) / sizeof(ab_format_txt[0]) == AB_ENUM_SIZE,
               "ab_format_txt needs exactly one entry per AB_Format value." );

class Shape : public Shape_POD {
public:

  // NOTE(review): clcok_end looks like a typo of clock_end; neither
  // member is referenced within this class.
  vector<clock_t> clock_start, clcok_end;
  vector<elt_t> A, AT, B, BT, CT;
  // Computed by the CPU (to check results).
  vector<elt_t> CT_cpu;
  bool verbose;
  int64_t est_mm_fp;       // Number of FP ops (madd=1) needed.
  int64_t est_mm_ls_ops;   // Number of load and store instructions.
  int64_t est_mm_l_bytes;  // Minimum amount of data loaded, bytes.
  int64_t est_mm_s_bytes;  // Minimum amount of data stored, bytes.

  // Take the dimensions from ss, then build inputs, the CPU reference
  // result, and device storage (see setup).
  Shape(const Shape_Spec& ss)
  {
    A_nrows = ss.A_nrows;
    A_ncols = ss.A_ncols;
    B_ncols = ss.B_ncols;
    A_dev = nullptr; verbose = false;
    setup();
  };

  // Allocate host vectors, fill A and B with seeded random values,
  // compute the reference product CT_cpu on the CPU (stored column-major),
  // allocate all device buffers, copy A and B to the device, and compute
  // the work estimates. Must be called only once per object (A_dev null).
  void setup()
  {
    assert( A_nrows && A_ncols && B_ncols );
    assert( !A_dev );

    /// Allocate
    //
    size_t A_size = A_nrows * A_ncols;

    A.resize(A_size);
    AT.resize(A_size);

    size_t B_size = A_ncols * B_ncols;
    size_t CT_size = A_nrows * B_ncols;

    B.resize( B_size );
    BT.resize( B_size );
    CT_cpu.resize( CT_size );

    /// Initialize Randomly
    //
    const int seed = 2735;
    default_random_engine re(seed);
    uniform_real_distribution<elt_t> uni_pm1(-1,1);
    auto rand_pm1 = [&]() { return uni_pm1(re); };

    if ( verbose )
      printf("Initializing input and weight vectors randomly.\n");

    ranges::generate(B,rand_pm1);
    // Elements of A are scaled by 1/A_ncols.
    ranges::generate(A, [&](){ return uni_pm1(re)/A_ncols; } );


    /// Compute
    //

    // Compute matrix product
    //
#pragma omp parallel for
    for ( size_t C_col = 0; C_col < B_ncols; C_col++ )
      for ( size_t C_row = 0; C_row < A_nrows; C_row++ )
        {
          elt_t C_accum = 0;
          for ( size_t A_col = 0; A_col < A_ncols; A_col++ )
            C_accum += A[ C_row * A_ncols + A_col ]
              * B[ A_col * B_ncols + C_col ];
          CT_cpu[ C_row + C_col * A_nrows ] = C_accum;
        }

    // Allocate CUDA storage.

    // These are for values copied back from the GPU.
    CT.resize( CT_size );

    if ( verbose )
      printf("Preparing CUDA storage.\n");

    // Set var_bytes from the element count and cudaMalloc var_dev.
#   define CMALc(var, n_elts)                               \
    var##_bytes = n_elts * sizeof( var##_dev[0] );         \
    CE( cudaMalloc( &var##_dev, var##_bytes ) );

#   define CMALC(var) CMALc(var,var.size())

    CMALC( A ); CMALC( AT );
    CMALC( B ); CMALC( BT );
    CMALc( A_tf32, A_size ); CMALc( AT_tf32, A_size );
    CMALc( B_tf32, B_size ); CMALc( BT_tf32, B_size );
    CMALc( A_fp16, A_size ); CMALc( AT_fp16, A_size );
    CMALc( B_fp16, B_size ); CMALc( BT_fp16, B_size );
    CMALC( CT );

#   undef CMALC
#   undef CMALc

    // Should be done each time input and matrices change, which for
    // here is once. Also, what's the point of having a CPU copy
    // of the transposed A and B matrices?

    CE( cudaMemcpyAsync
        ( B_dev, B.data(), B_bytes, cudaMemcpyHostToDevice ) );
    CE( cudaMemcpyAsync
        ( A_dev, A.data(), A_bytes, cudaMemcpyHostToDevice ) );

    /// Estimate Resources Needed
    //
    // Note: A_bytes, ... computed by CMALC.
    est_mm_fp = A_nrows * A_ncols * B_ncols;
    est_mm_s_bytes = CT_bytes;
    est_mm_l_bytes = int64_t(A_bytes) + B_bytes;
    est_mm_ls_ops = B_ncols * A_nrows * ( A_ncols * 2 + 1 );
  }

  // Zero the device output and copy this object's Shape_POD part
  // (dimensions, byte counts, device pointers) to shape_dev.
  void launch_prep()
  {
    // Prepare for a kernel launch, resetting intermediate data.
    CE( cudaMemsetAsync( CT_dev, 0, CT_bytes ) );

    Shape_POD for_dev(*this);
    CE( cudaMemcpyToSymbol
        (shape_dev,&for_dev,sizeof(for_dev),0,cudaMemcpyHostToDevice) );
  }

  // Copy the GPU result into CT and compare it element-wise with CT_cpu
  // using a tolerance that depends on the operand format ab_fmt.
  // Prints a summary when any element is out of tolerance (or verbose).
  void from_dev_and_check(AB_Format ab_fmt)
  {
    CE( cudaMemcpy( CT.data(), CT_dev, CT_bytes, cudaMemcpyDeviceToHost ) );
    int n_err = 0, n_err_z = 0;
    double max_err = 0;

    // An estimate. Actual value is implementation dependent.
    constexpr int tf32_prec_bits = 10;
    constexpr int fp16_prec_bits = 11;
    constexpr int bf16_prec_bits = 8;

    float tol = ( ab_fmt == AB_tf32
                  ? 1.0 / float( 1 << tf32_prec_bits )
                  : ab_fmt == AB_fp16
                  ? 1.0 / float( 1 << fp16_prec_bits )
                  : ab_fmt == AB_bf16
                  ? 1.0 / float( 1 << bf16_prec_bits )
                  : numeric_limits<elt_t>::epsilon() ) * A_ncols * 5;
    for ( size_t i=0; i<CT_cpu.size(); i++ )
      {
        elt_t hd = CT[i];
        elt_t hc = CT_cpu[i];
        float d = fabs( hd - hc );
        if ( d > max_err ) max_err = d;
        // Written as !( d <= tol ) rather than d > tol so that a NaN
        // difference (e.g. a NaN produced on the GPU) is counted as an
        // error: every ordered comparison involving NaN is false.
        if ( !( d <= tol ) )
          {
            n_err++;
            if ( hd == 0 ) n_err_z++;
          }
      }
    if ( !n_err && !verbose ) return;
    printf("Number of errors %d out of %zu. %d zeros\n",
           n_err, CT_cpu.size(), n_err_z);
    printf("Max err = %.9f,  tolerance = %.9f\n", max_err, tol );
  }

};




int
main(int argc, char **argv)
{
  NPerf_init();

  // Otherwise use time measured by kernels.
  const bool opt_use_nperf_time = false;

  // Get info about GPU and each kernel.
  //
  GPU_Info info = print_gpu_and_kernel_info();
  // For metric naming convention see: 
  //   https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-decoder
  //
  NPerf_metric_collect("sm__inst_executed.sum");
  NPerf_metric_collect("sm__inst_executed.max");
  NPerf_metric_collect("sm__inst_executed.avg");
  NPerf_metric_collect
    ("sm__instruction_throughput.avg.pct_of_peak_sustained_elapsed");

  NPerf_metric_collect
    ("sm__average_thread_inst_executed_pred_on_per_inst_executed_realtime.pct");

  NPerf_metric_collect
    ("smsp__thread_inst_executed_pred_off.sum");
  NPerf_metric_collect
    ("smsp__thread_inst_executed_pred_on.sum");
  NPerf_metric_collect
    ("smsp__thread_inst_executed.sum");

  NPerf_metric_collect
    ("smsp__thread_inst_executed_per_inst_executed.ratio");

  NPerf_metric_collect("l1tex__t_sectors.sum");
  NPerf_metric_collect("l1tex__t_requests.sum");
  NPerf_metric_collect("l1tex__data_bank_conflicts_pipe_lsu.sum");
  NPerf_metric_collect("smsp__sass_l1tex_tags_mem_global.sum");
  //
  // From L2 cache to L1 cache.
  NPerf_metric_collect("l1tex__m_xbar2l1tex_read_bytes.sum");
  // From L1 cache to L2 cache.
  NPerf_metric_collect("l1tex__m_l1tex2xbar_write_bytes.sum");

  NPerf_metric_collect("dram__bytes_read.sum");
  NPerf_metric_collect("dram__bytes_write.sum");

  NPerf_metric_collect
    ("sm__sass_average_branch_targets_threads_uniform.pct");

  constexpr int wp_sz = 32;

  vector<Shape> shapes;

  for ( auto& ls: shape_specs ) shapes.emplace_back(ls);

  const int n_shapes = shapes.size();

  struct App_Kernel_Info {
    App_Kernel_Info
    (Kernel_Info& k,const char *name_base, const char *name, int i,
     int p_m_wd, int p_m_ht, int thd_p_elt, AB_Format fmt):
      k_ptr(k.func_ptr),name_base(name_base), name(name),
      shape_idx{i},m_wd(p_m_wd),m_ht(p_m_ht),thd_p_elt(thd_p_elt),
      ab_format(fmt) {}
    GPU_Info_Func k_ptr;
    const char *name_base;
    const char *name;
    const int shape_idx;
    const int m_wd, m_ht, thd_p_elt;
    const AB_Format ab_format;
  };

  vector<App_Kernel_Info> kernels;

  #define EXAMINE_KERNEL(k,kb,sidx,m_wd,m_ht,thd_p_elt,fmt)                   \
    { const int idx = kernels.size();                                         \
      kernels.emplace_back                                                    \
       ( info.GET_INFO((k)), kb,#k,sidx,m_wd,m_ht,thd_p_elt,fmt); }

  #define SPECIFY_KERNEL(k,sidx) \
    EXAMINE_KERNEL((k<shape_specs[sidx].A_ncols>),sidx,1,1,1,AB_elt_t);

  #define xSPECIFY_KERNEL2(k,sidx,m_wd) \
    EXAMINE_KERNEL((k<shape_specs[sidx].A_ncols,m_wd>),sidx,m_wd,1,1);

  #define SPECIFY_KERNEL2(k,sidx) \
    EXAMINE_KERNEL( (k<shape_specs[sidx].A_nrows,shape_specs[sidx].A_ncols>), \
                    #k,sidx,1,1,1,AB_elt_t);

  #define SPECIFY_KERNEL_tc(k,sidx,fmt,m_wd_n,m_ht_n,ty)                                  \
    EXAMINE_KERNEL( (k<shape_specs[sidx].A_nrows,shape_specs[sidx].A_ncols, \
                       shape_specs[sidx].B_ncols, \
                       m_wd_n, m_ht_n, ty >),                                    \
                    #k, sidx, m_wd_n, m_ht_n, 0, fmt);

  #define SPECIFY_KERNEL4(k,sidx,m_wd,m_ht,thd_p_elt)                         \
    EXAMINE_KERNEL( ( k< shape_specs[sidx].A_nrows,                           \
                         shape_specs[sidx].A_ncols, m_wd, m_ht, thd_p_elt > ), \
                    #k, sidx, m_wd, m_ht, thd_p_elt, AB_elt_t);

  #define SPECIALIZE_KERNEL(sidx) \
    SPECIFY_KERNEL4(mm_tile_wd_ht_dp,sidx,8,8,4); \
    SPECIFY_KERNEL_tc(mm_tc,sidx,AB_fp16,1,1,half); \
    SPECIFY_KERNEL_tc(mm_tc,sidx,AB_fp16,2,1,half); \
    SPECIFY_KERNEL_tc(mm_tc,sidx,AB_fp16,4,1,half); \
    SPECIFY_KERNEL_tc(mm_tc,sidx,AB_fp16,8,1,half); \
    SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,1,1,half); \
    SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,2,2,half); \
    SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,4,1,half); \
    SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,2,8,half); \
    SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,8,2,half); \
    SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,4,4,half);

  /// Homework 5 -- Can modify the m_wd_n and m_ht_n above.
  //
  //  For example, to run a m_wd_n=2, m_ht_n=4 kernel add a line:
  //
  //    SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,2,8,half);
  //
  //  to the macro above, end the line with a backslash.


  SPECIALIZE_KERNEL(0);
  SPECIALIZE_KERNEL(2);

  #undef SPECIALIZE_KERNEL

  const bool want_kernel_info = false;
  if ( want_kernel_info )
    {
      // Print information about kernel.
      //
      printf("\nCUDA Kernel Resource Usage:\n");

      for ( int i=0; i<info.num_kernels; i++ )
        {
          printf("For %s:\n", info.ki[i].name);
          printf("  %6zd shared, %zd const, %zd loc, %d regs; "
                 "%d max threads per block.\n",
                 info.ki[i].cfa.sharedSizeBytes,
                 info.ki[i].cfa.constSizeBytes,
                 info.ki[i].cfa.localSizeBytes,
                 info.ki[i].cfa.numRegs,
                 info.ki[i].cfa.maxThreadsPerBlock);
        }
    }



  // Get number of multiprocessors. (A.k.a. streaming multiprocessors or SMs)
  //
  const int num_mp = info.cuda_prop.multiProcessorCount;

  // Examine argument 1, block count, default is number of MPs.
  //
  const int arg1_int = argc < 2 ? num_mp : atoi(argv[1]);
  const int num_blocks =
     arg1_int == 0 ? num_mp :
     arg1_int < 0  ? -arg1_int * num_mp : arg1_int;

  // Examine argument 2, number of warps per block.
  //
  const int wp_per_block_arg = argc < 3 ? 0 : atoi(argv[2]);
  const int wp_per_block_goal =
   wp_per_block_arg == 0 ? 32 : wp_per_block_arg;
  const int n_threads = num_blocks * wp_per_block_goal * wp_sz;

  printf("Kernel timing based on %s\n",
         opt_use_nperf_time ? "sm__cycles_elapsed.max"
         : "clock() data from kernels.");

  for ( int i=0; i<n_shapes; i++ )
    {
      Shape& l = shapes[i];
      printf
        ("Shape %d: A_nrows=%zd.  A_ncols=%zd.  B_ncols=%zd\n",
         i, l.A_nrows, l.A_ncols, l.B_ncols );
      printf
        ("Shape %d: A: %zu kiB   B: %zu kiB\n",
         i, l.A_bytes >> 10, l.B_bytes >> 10);
    }

  if ( n_threads <= 0 )
    {
      printf("Usage: %s [ NUM_CUDA_BLOCKS ] [WARPS_PER_BLOCK] "
             "[COL PER MP]\n",
             argv[0]);
      exit(1);
    }

  cudaFuncCache cache_config;
  map<cudaFuncCache,const char*> cc_to_string;
  #define CCMAP(e) cc_to_string[e]=#e;
  CCMAP(cudaFuncCachePreferNone);
  CCMAP(cudaFuncCachePreferShared);
  CCMAP(cudaFuncCachePreferL1);
  CCMAP(cudaFuncCachePreferEqual);
  CE( cudaDeviceSetCacheConfig( cudaFuncCachePreferL1 ) );
  CE( cudaDeviceGetCacheConfig( &cache_config ) );

  printf("Cache preference set to %s.\n",cc_to_string[cache_config]);

  const double clock_period_us = 1e3 / info.cuda_prop.clockRate;

  const int max_wps = num_blocks * 32;
  vector<Timing_Item> timing_items(max_wps);
  Timing timing;
  const size_t timing_items_bytes =
    timing_items.size() * sizeof(timing_items[0]);
  CE( cudaMalloc( &timing.timing_items, timing_items_bytes ) );
  CE( cudaMemcpyToSymbol
      ( timing_dev, &timing, sizeof(timing), 0, cudaMemcpyHostToDevice) );

  // The width, in characters, of the output to which we are printing.
  //
  const int output_width = stdout_width_get();

  for ( auto& l: shapes )
    {
      l.launch_prep();
      M_transpose<<< num_mp, 1024 >>>(l.AT_dev,l.A_dev,l.A_nrows,l.A_ncols);
      M_transpose<<< num_mp, 1024 >>>(l.BT_dev,l.B_dev,l.A_ncols,l.B_ncols);
      convert_float_tf32<<< num_mp, 1024 >>>();
    }
  {
    // Prepare events used for timing.
    //
    cudaEvent_t gpu_start_ce, gpu_stop_ce;
    CE(cudaEventCreate(&gpu_start_ce));
    CE(cudaEventCreate(&gpu_stop_ce,cudaEventBlockingSync));

    // Launch kernel multiple times and keep track of the best time.
    printf("Launching with %d blocks of up to %d warps. \n",
           num_blocks, wp_per_block_goal);

    for ( auto& aki: kernels )
      {
        const char* kname = aki.name_base;
        const int sidx = aki.shape_idx;
        Shape& l = shapes[sidx];
        const bool vary_warps = true;

        const bool uses_tc_fp16 = aki.ab_format == AB_fp16;
        const bool uses_tc_tf32 = aki.ab_format == AB_tf32;
        const bool uses_tc = uses_tc_fp16 || uses_tc_tf32;

        Kernel_Info* const ki = &info.get_info(aki.k_ptr);

        pStringF shape_etc("Shape %zd x %zd x %zd.",
                           l.A_nrows,l.A_ncols,l.B_ncols);
        pStringF tiling_scalar
          ("wd=%d, ht=%d, dp=%d",aki.m_wd,aki.m_ht,aki.thd_p_elt);
        pStringF tiling_tc
          ("wd=%d  ht=%d  tiles.",aki.m_wd,aki.m_ht);

        pStringF local_txt(", %zd LOCAL", ki->cfa.localSizeBytes );

        printf("\nKernel %s, %d regs%s. %s %s  %s\n",
               kname, ki->cfa.numRegs,
               ki->cfa.localSizeBytes ? local_txt.s : "",
               shape_etc.s,
               uses_tc ? tiling_tc.s : tiling_scalar.s,
               ab_format_txt[aki.ab_format]);
        pTable table(stdout);

        const int wp_limit_kernel = ki->cfa.maxThreadsPerBlock >> 5;
        const int wp_limit_prefer = 32;
        const int wp_limit = min(wp_limit_kernel,wp_limit_prefer);

        const int thd_limit = wp_limit << 5;
        const int thd_per_block_no_vary =
          min(wp_per_block_goal*wp_sz,thd_limit);

        const int wp_start = 4;
        const int wp_stop = vary_warps ? wp_limit : wp_start;
        //  const int wp_inc = 4;


        //  for ( int wp_cnt = wp_start; wp_cnt <= wp_stop; wp_cnt += wp_inc )
        for ( int wp_cnt = wp_start; wp_cnt <= wp_stop; wp_cnt *= 2 )
          {
            const int thd_per_block = wp_cnt << 5;

            if ( vary_warps && wp_cnt == 12 && wp_cnt < wp_stop ) continue;
            if ( vary_warps && wp_cnt > 16 && wp_cnt < wp_stop ) continue;
            if ( vary_warps && wp_cnt > 4 && wp_cnt & 0x3 ) continue;

            const int grid_wps = wp_cnt * num_blocks;

            /// Compute Expected Computation and Communication
            //
            // Number of multiply/add operations. Ignore everything else.
            //
            const int64_t num_ops_fp = l.est_mm_fp;
            //
            // Load and store instructions.
            const int64_t est_mm_ls_ops =
              uses_tc_fp16 ? l.est_mm_ls_ops / 2 : l.est_mm_ls_ops;

            const int tc_tile_m = 16;
            const int tc_tile_n = 16;
            const int ld_reuse_a = uses_tc ? aki.m_wd * tc_tile_n : aki.m_wd;
            const int ld_reuse_b = uses_tc ? aki.m_ht * tc_tile_m : aki.m_ht;

            const int64_t num_ops_ls =
              est_mm_ls_ops / ( 2 * ld_reuse_a ) +
              est_mm_ls_ops / ( 2 * ld_reuse_b );

            const int64_t est_mm_l_bytes =
              uses_tc_fp16 ? l.est_mm_l_bytes / 2 : l.est_mm_l_bytes;

            //
            // Amount of data in and out of GPU chip --- if perfect.
            //
            const int64_t amt_data_bytes =
              est_mm_l_bytes + l.est_mm_s_bytes;

            {
              l.launch_prep();
              CE( cudaMemset( timing.timing_items, 0, timing_items_bytes ) );

              NPerf_metrics_off();

              // Measure execution time starting "now", which is after data
              // set to GPU.
              //
              //  CE(cudaEventRecord(gpu_start_ce));

              typedef void (*KPtr)();

              /// Launch Kernel -- Without Performance Counter Sampling
              //
              KPtr(ki->func_ptr) <<< num_blocks, thd_per_block >>>();

              // Stop measuring execution time now, which is before is data
              // returned from GPU.
              //
              //  CE(cudaEventRecord(gpu_stop_ce));
              //  CE(cudaEventSynchronize(gpu_stop_ce));
              float cuda_time_ms = -1.1;

              CE( cudaMemcpy
                  ( timing_items.data(), timing.timing_items,
                    timing_items_bytes, cudaMemcpyDeviceToHost ) );

              l.launch_prep();

              NPerf_metrics_on();

              /// Launch Kernel -- With Performance Counter Sampling
              //
              for ( NPerf_data_reset(); NPerf_need_run_get(); )
                KPtr(ki->func_ptr) <<< num_blocks, thd_per_block >>>();

              NPerf_metrics_off();

              map<int32_t,Timing_Item> sm_start_end;
              int n_migs = 0; // Number of migrations.
              for ( auto& ti: views::take(timing_items,grid_wps) )
                if ( ti.smid_start != ti.smid_end )
                  {
                    n_migs++;
                  }
                else
                  {
                    auto& tis = sm_start_end[ti.smid_start];
                    if ( tis.time_start == tis.time_end ) tis = ti;
                    set_min( tis.time_start, ti.time_start );
                    set_max( tis.time_end, ti.time_end );
                  }
              if ( n_migs ) printf("-- Number of migrations: %d\n",n_migs);
              assert( sm_start_end.size() == num_mp );

              int64_t et_sum = 0;
              vector<clock_t> et;
              for ( auto& [smid,tis]: sm_start_end )
                {
                  const int64_t elapsed = tis.time_end - tis.time_start;
                  et_sum += elapsed;
                  et.push_back( elapsed );
                }

              ranges::sort( et, ranges::greater() );

              const double et_clock_max_us = et[0] * clock_period_us;
              const double et_clock_avg_us =
                et_sum * clock_period_us / sm_start_end.size();

              // Note: AoTW NPerf uses sm__cycles_elapsed.max.
              const double nperf_elapsed_time_s =
                NPerf_metrics_collection_get()
                ? NPerf_kernel_et_get() : cuda_time_ms * 0.001;

              const double this_elapsed_time_s =
                opt_use_nperf_time ? nperf_elapsed_time_s
                : et_clock_max_us * 1e-6;

              const double imbalance_penalty =
                et_clock_avg_us ? et_clock_max_us / et_clock_avg_us : 0.0;

              const double imbalance_penaltyi =
                NPerf_metric_value_get("sm__inst_executed.max") /
                NPerf_metric_value_get("sm__inst_executed.avg");

              const double thpt_compute_gflops =
                num_ops_fp / this_elapsed_time_s * 1e-9;
              const double thpt_data_gbps =
                amt_data_bytes / this_elapsed_time_s * 1e-9;

              // Number of load/store operations per second.
              const double chip_ls_ops = info.chip_sp_flops / 4;

              const int num_sched_p_sm = 4;
              const int tc_flop_p_hmma_tf32 = 16; // CC 8.9
              const int tc_ii_p_hmma_tf32 = 2; // CC 8.9 Inferred.
              const int tc_flop_p_hmma_fp16 = 64; // CC 8.9
              const int tc_ii_p_hmma_fp16 = 4; // CC 8.9. Inferred.
              const int chip_tc_insn_p_cyc = num_mp * num_sched_p_sm * wp_sz;
              const double chip_tc_fp16_flop_p_cyc =
                chip_tc_insn_p_cyc * tc_flop_p_hmma_fp16 / tc_ii_p_hmma_fp16;
              const double chip_tc_fp16_flops =
                chip_tc_fp16_flop_p_cyc * info.clock_freq_hz;
              const double chip_tc_tf32_flop_p_cyc =
                chip_tc_insn_p_cyc * tc_flop_p_hmma_tf32 / tc_ii_p_hmma_tf32;
              const double chip_tc_tf32_flops =
                chip_tc_tf32_flop_p_cyc * info.clock_freq_hz;

              const float flop_p_insn =
                uses_tc_fp16 ? tc_flop_p_hmma_fp16 :
                uses_tc_tf32 ? tc_flop_p_hmma_tf32 : 1;

              const double t_bound_fp_tc_fp16 = num_ops_fp / chip_tc_fp16_flops;
              const double t_bound_fp_tc_tf32 = num_ops_fp / chip_tc_tf32_flops;
              const double t_bound_fp_sp = num_ops_fp / info.chip_sp_flops;
              const double t_bound_fp =
                uses_tc_fp16 ? t_bound_fp_tc_fp16 :
                uses_tc_tf32 ? t_bound_fp_tc_tf32 : t_bound_fp_sp;

              const double t_bound_ls = num_ops_ls / chip_ls_ops;
              const double t_bound_insn = t_bound_fp + t_bound_ls;

              {
                const double comp_frac = t_bound_insn / this_elapsed_time_s;
                //  1e9 * thpt_compute_gflops / info.chip_sp_flops;
                const double bw_frac =
                  1e9 * thpt_data_gbps / info.chip_bw_Bps;
                const double fp_frac = t_bound_fp / this_elapsed_time_s;

                // Number of warps, rounded up.
                //
                const int num_wps = ( thd_per_block + 31 ) >> 5;

                // The maximum number of active blocks per MP for this
                // kernel when launched with a block size of thd_per_block.
                //
                const int max_bl_per_mp =
                  ki->get_max_active_blocks_per_mp(thd_per_block);

                // Compute number of blocks available per MP based only on
                // the number of blocks.  This may be larger than the
                // number of blocks that can run.
                //
                const int bl_per_mp_available =
                  0.999 + double(num_blocks) / num_mp;

                // The number of active blocks is the minimum of what
                // can fit and how many are available.
                //
                const int bl_per_mp =
                  min( bl_per_mp_available, max_bl_per_mp );

                // Based on the number of blocks, compute number of warps.
                //
                const int act_wps = num_wps * bl_per_mp;

                pTable_Row row(table);
                table.entry("wp",num_wps);
                if ( num_blocks > num_mp )
                  table.entry("ac",act_wps);
                if ( NPerf_metrics_collection_get() )
                  {
                    double dram_rd_bytes =
                      NPerf_metric_value_get("dram__bytes_read.sum");
                    double dram_wr_bytes =
                      NPerf_metric_value_get("dram__bytes_write.sum");

                    double l2_rd_bytes =
                      NPerf_metric_value_get
                      ("l1tex__m_xbar2l1tex_read_bytes.sum");
                    double l2_wr_bytes =
                      NPerf_metric_value_get
                      ("l1tex__m_l1tex2xbar_write_bytes.sum");

                    table.header_span_start("Insn");

                    const double normalized_ipi_factor =
                      32.0 / num_ops_fp * flop_p_insn;

                    table.entry
                      ("/itr","%5.2f",
                       NPerf_metric_value_get("sm__inst_executed.sum")
                       * normalized_ipi_factor );

                    table.entry
                      ("% ","%2.0f",
                       NPerf_metric_value_get
                       ("sm__instruction_throughput"
                        ".avg.pct_of_peak_sustained_elapsed") );

                    const double pred_on_frac =
                      NPerf_metric_value_get
                      ("smsp__thread_inst_executed_pred_on.sum") /
                      NPerf_metric_value_get("smsp__thread_inst_executed.sum");
                    const double diverg_frac =
                      NPerf_metric_value_get
                      ("smsp__thread_inst_executed_per_inst_executed.ratio")
                      / 32.0;
                    const double lane_eff = pred_on_frac * diverg_frac;

                    pStringF eff_str("%2d%s", int(lane_eff * 100 + 0.5),
                                     diverg_frac < pred_on_frac ? "d" : "p" );
                    const bool show_TAc = false;
                    if ( show_TAc )
                      table.entry("TAc","%3s",
                                  lane_eff > .99 ? "100" : eff_str.s);

                    table.header_span_end();

                    table.header_span_start("L1");

                    table.entry
                      ("SW","%2.0f",
                       NPerf_metric_value_get("l1tex__t_sectors.sum") /
                       NPerf_metric_value_get ("l1tex__t_requests.sum"));

                    if ( 0 )
                      table.entry
                        ("TW","%2.0f",
                         NPerf_metric_value_get
                         ("smsp__sass_l1tex_tags_mem_global.sum") /
                         NPerf_metric_value_get ("l1tex__t_requests.sum"));

                    table.entry
                      ("BXW","%4.1f",
                       NPerf_metric_value_get
                       ("l1tex__data_bank_conflicts_pipe_lsu.sum") /
                       NPerf_metric_value_get ("l1tex__t_requests.sum"));
                    table.header_span_end();

                    if ( true ) {

                      table.header_span_start(" L1<->L2 ");
                      table.entry
                        ("N-Rd","%4.0f", l2_rd_bytes / est_mm_l_bytes );
                      table.entry
                        ("N-Wr","%4.1f", l2_wr_bytes / l.est_mm_s_bytes );
                      table.entry
                        ("GB/s","%4.0f",
                         1e-9 * ( l2_rd_bytes + l2_wr_bytes )
                         / ( this_elapsed_time_s ) );
                      table.header_span_end(); }

                  }

                const bool imbalance_more_time =
                  imbalance_penalty > imbalance_penaltyi;
                const double imbalance =
                  max( imbalance_penalty, imbalance_penaltyi );
                pStringF imb_textp
                  ("%2.0f%s",
                   ( imbalance - 1 ) * 10, imbalance_more_time ? "t" : "i");

                table.entry("Imb","%3s", imb_textp.s);

                table.entry( "t/µs", "%6.0f", this_elapsed_time_s * 1e6 );

                if ( 0 )
                table.entry("FP θ","%4.0f", thpt_compute_gflops);

                if ( false )
                  table.entry("GB/s","%4.0f", thpt_data_gbps);

                const size_t max_st_len =
                  max(5, output_width - 1 - table.row_len_get() );
                pStringF fmt("%%-%zds",max_st_len);
                string fp_unit = uses_tc ? "TC" : "FP";
                string util_hdr = "=== Util: " + fp_unit + "++  Insn-- ";
                if ( max_st_len > util_hdr.length() )
                  util_hdr += string(max_st_len - util_hdr.length(),'=');

                typedef struct { double f; char c; } Elt;
                vector<Elt> segments =
                  { { fp_frac, '+' }, { comp_frac, '-' }, { 0*bw_frac, '*' } };

                ranges::sort( segments, {}, &Elt::f );

                string bar;
                for ( Elt& e: segments )
                  if ( size_t p = e.f * max_st_len + 0.5; p > bar.length() )
                    bar += string( p - bar.length(), e.c );

                if ( bar.length() > max_st_len )
                  {
                    bar.resize(max_st_len);
                    bar[max_st_len-1] = '>';
                  }

                table.entry(util_hdr,fmt, bar, pTable::pT_Left);
              }

              // Copy output array from GPU to CPU and check.
              //
              l.from_dev_and_check(aki.ab_format);
            }
          }
      }
  }

  return 0;
}