/// LSU EE 7722 GPU Microarchitecture
//
 /// Spring 2024
 /// Homework 1 -- SOLUTION
 //
 //  Assignment: https://www.ece.lsu.edu/koppel/gp/2024/hw01.pdf
 //
 //  Put solution in kernel norm_group.
 //
 //  Modify this file only.

#include <cuda_runtime.h>
#include <gp/cuda-gpuinfo.h>
#include <ptable.h>
#include <nperf.h>
#include <misc.h>

#include <stdio.h>
#include <random>
#include <vector>
#include <algorithm>
#include <ranges>

typedef float elt_t;

struct App
{
  int n_l;
  int d_l;
  elt_t *l_in_d, *l_out_d;
};

__constant__ App c_app;

template< int D_L = 0 >
__global__ void
norm_base(elt_t* __restrict__ l_out, const elt_t* __restrict__ l_in)
{
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int n_threads = blockDim.x * gridDim.x;

  const int d_l = D_L ?: c_app.d_l;
  const int n_l = c_app.n_l;

  /// Do not modify this kernel.

  for ( int h = tid;  h < n_l;  h += n_threads )
    {
      elt_t sum = 0;
      for ( int i = 0;  i < d_l;  i++ ) sum += l_in[ h * d_l + i ];
      const elt_t avg = sum / d_l;

      for ( int i = 0;  i < d_l;  i++ )
        l_out[ h * d_l + i ] = l_in[ h * d_l + i ] - avg;
    }
}
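
//  Access-pattern note (a rough sketch, not part of the assignment
//  statement): in norm_base each thread handles a whole vector, so on
//  any one pass of the inner loops the threads of a warp read elements
//  that are d_l apart.  For large d_l each access lands in a different
//  memory sector, which is the behavior norm_group below is meant to
//  improve by having a group of threads share one vector.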

__device__ elt_t
group_sum(elt_t thd_val, uint group_size)
{
  assert( group_size <= 32 );
  /// Do not modify this device function either.
  const uint32_t mask =
    group_size > 1 && group_size < 32 ? __activemask() : ~0;
  elt_t sum = thd_val;
  for ( int dist = 1;  dist < group_size; dist <<= 1 )
    sum += __shfl_xor_sync(mask, sum, dist );
  return sum;
}
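
// Illustration of group_sum, assuming group_size = 4 (other powers of
// two up to 32 behave analogously): the loop runs with dist = 1 and
// then dist = 2.  With the XOR shuffle the lanes of a group first
// exchange with partners 0<->1 and 2<->3, then with partners 0<->2 and
// 1<->3, so after log2(group_size) steps every lane of the group holds
// the sum of all four thread values.  This works because group_size is
// a power of two dividing 32, so a group occupies lanes that differ
// only in their low log2(group_size) bits.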

double
prob_1b_workload_imbalance
(GPU_Info& info, App& app, int n_blocks, int thd_p_block, int grp_size_raw)
{
  /// SOLUTION -- Problem 1b

  const int grp_size = grp_size_raw ?: 1;
  int n_threads = n_blocks * thd_p_block;
  const int64_t total_thd_h_work = app.n_l * grp_size;
  const int64_t h_n_iters_max =
    ( total_thd_h_work + n_threads - 1 ) / n_threads;
  const double imbalance =
    double(total_thd_h_work) / ( h_n_iters_max * n_threads );
  return imbalance;
}
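
// Worked example for the figure above, using hypothetical numbers:
// n_l = 1000 vectors, grp_size = 4, and n_blocks * thd_p_block = 1536
// threads.  Then total_thd_h_work = 4000 thread-vectors, h_n_iters_max
// = ceil(4000/1536) = 3, and the returned value is 4000 / (3*1536)
// ~= 0.87.  That is, about 13% of the thread-iterations of the h loop
// do no useful work; a value of 1.0 indicates a perfectly balanced load.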


template<int D_L = 0, int grp_size = 1>
__global__ void
norm_group(elt_t* __restrict__ l_out, const elt_t* __restrict__ l_in)
{
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int n_threads = blockDim.x * gridDim.x;

  const int d_l = D_L ?: c_app.d_l;
  const int n_l = c_app.n_l;

  /// SOLUTION -- Problem 1a

  /// Solution Outline
  //
  //  - Each vector is operated on by grp_size threads. Say, grp_size = 4.
  //    The code only works if grp_size is a power of 2 and <= 32.
  //    Note: h (in the loop below) is the vector number.
  //
  //  - The set of grp_size threads operating on a vector is called a group.
  //
  //  - Each thread in a group is assigned a sub_lane, the sub_lanes
  //    are numbered from 0 to grp_size-1.
  //
  //  - TO AVOID BANK CONFLICTS, have consecutive threads operate
  //    on consecutive elements of a vector.
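  //
  //  - For example, with grp_size = 4 and d_l = 8, sub_lanes 0..3 of a
  //    group read elements 0,1,2,3 of their vector on the first pass of
  //    the element loop and elements 4,5,6,7 on the second, so adjacent
  //    threads in a group always touch adjacent addresses.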


  // Determine this thread's sub_lane.
  //
  const int sub_lane = threadIdx.x % grp_size;

  // Determine first vector to operate on.
  //
  const int h_start = tid / grp_size;
  //
  // Notice that grp_size consecutive threads have the same value of h_start.

  for ( int h = h_start;  h < n_l;  h += n_threads / grp_size )
    {
      // Index of first element of vector number h.
      const size_t idx_vec_start = h * d_l;

      // Index of first element of vector h that this thread will operate on.
      const size_t idx_vec_thd_start = idx_vec_start + sub_lane;
      //
      // This assignment of elements to threads will avoid bank conflicts
      // if grp_size = 32 or if grp_size = d_l.

      elt_t thd_sum = 0;
      for ( int i = 0;  i < d_l;  i += grp_size )
        thd_sum += l_in[ idx_vec_thd_start + i ];

      const elt_t sum = group_sum(thd_sum,grp_size);
      const elt_t avg = sum / d_l;

      for ( int i = 0;  i < d_l;  i += grp_size )
        l_out[ idx_vec_thd_start + i ] = l_in[ idx_vec_thd_start + i ] - avg;
    }
}


template<int D_L = 0, int grp_size = 1>
__global__ void
norm_group_bad(elt_t* __restrict__ l_out, const elt_t* __restrict__ l_in)
{
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int n_threads = blockDim.x * gridDim.x;

  const int d_l = D_L ?: c_app.d_l;
  const int n_l = c_app.n_l;

  /// BAD SOLUTION BELOW

  /// Solution Outline
  //
  //  - Each vector is operated on by grp_size threads. Say, grp_size = 4.
  //    The code only works if grp_size is a power of 2 and <= 32.
  //    Note: h (in the loop below) is the vector number.
  //
  //  - The set of grp_size threads operating on a vector is called a group.
  //
  //  - Each thread in a group is assigned a sub_lane, the sub_lanes
  //    are numbered from 0 to grp_size-1.
  //
  //  - TO AVOID BANK CONFLICTS, have consecutive threads operate
  //    on consecutive elements of a vector.
  //    LET'S IGNORE THAT ADVICE.

  // Determine this thread's sub_lane.
  //
  const int sub_lane = threadIdx.x % grp_size;

  // Determine first vector to operate on.
  //
  const int h_start = tid / grp_size;
  //
  // Notice that grp_size consecutive threads have the same value of h_start.

  const int elt_p_thd = d_l / grp_size;

  for ( int h = h_start;  h < n_l;  h += n_threads / grp_size )
    {
      // Index of first element of vector number h.
      const size_t idx_vec_start = h * d_l;

      // Index of first element of vector h that this thread will operate on.
      /// WARNING: Inefficient way to do it. (This is the bad part.)
      const size_t idx_vec_thd_start = idx_vec_start + sub_lane * elt_p_thd;
      //
      // This assignment of elements to threads will not avoid bank conflicts
      // when d_l > 32, even if grp_size = 32.
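      //
      // For example, with grp_size = 32 and d_l = 1024, elt_p_thd = 32:
      // on any one pass of the loops below thread sub_lane s touches
      // element s*32 + i, so a warp's 32 accesses are 128 bytes apart
      // (32 separate sectors and, under the usual 32 x 4-byte banking
      // model, all in the same L1 bank) rather than the 128 contiguous
      // bytes that norm_group reads per pass.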

      elt_t thd_sum = 0;
      // BAD SOLUTION: Iterate over elt_p_thd consecutive elements.
      for ( int i = 0;  i < elt_p_thd;  i++ )
        thd_sum += l_in[ idx_vec_thd_start + i ];

      const elt_t sum = group_sum(thd_sum,grp_size);
      const elt_t avg = sum / d_l;

      // BAD SOLUTION: Iterate over elt_p_thd consecutive elements.
      for ( int i = 0;  i < elt_p_thd;  i++ )
        l_out[ idx_vec_thd_start + i ] = l_in[ idx_vec_thd_start + i ] - avg;
    }
}



GPU_Info
print_gpu_and_kernel_info()
{
  GPU_Info info;

  gpu_info_print();

  // Choose a GPU; gpu_choose_index usually picks the better of the
  // available devices.
  //
  int dev = gpu_choose_index();
  CE(cudaSetDevice(dev));
  printf("Using GPU %d\n",dev);
  info.get_gpu_info(dev);

  return info;
}



int
main(int argc, char **argv)
{
  //  const bool debug = false;
  App app;

  // Must be called before any CUDA API calls.
  NPerf_init();

  // Get info about GPU and each kernel.
  //
  GPU_Info info = print_gpu_and_kernel_info();

  const uint num_mp = info.cuda_prop.multiProcessorCount;

  // Examine argument 1, block count, default is number of MPs.
  //
  const int arg1_int = argc < 2 ? num_mp : atoi(argv[1]);
  const uint num_blocks =
    arg1_int == 0 ? num_mp : arg1_int < 0 ? -arg1_int * num_mp : arg1_int;

  // Examine argument 2, number of threads per block.
  //
  //  const bool opt_p = argc >= 3 && string(argv[2]) == "p";
  const bool opt_p = true;
  const int thd_per_block_arg = argc < 3 ? 0 : atoi(argv[2]);
  const int thd_per_block_goal =
    thd_per_block_arg == 0 ? 1024 : thd_per_block_arg;
  const int num_threads = num_blocks * thd_per_block_goal;

  const bool vary_warps = thd_per_block_arg == 0;

  const int l2_size_bytes = info.cuda_prop.l2CacheSize;

  // This is a guess.
  const float l2_bandwidth_Bps = 32 * bit_floor(num_mp-1) * info.clock_freq_hz;

  const float default_n_l2_units = 0.25;

  // Examine argument 3, size of array.
  //
  //   If positive, multiples of 1 MiB.
  //   If negative, multiples of cache size.
  //
  const float arg3_val = argc < 4 ? -default_n_l2_units : atof(argv[3]);

  const int n_bytes_raw =
    arg3_val < 0 ? -arg3_val * l2_size_bytes : arg3_val * ( 1 << 20 );
  const int n_elts_max = n_bytes_raw / sizeof(elt_t);

  if ( num_threads <= 0 || n_elts_max <= 0 )
    {
      printf("Usage: %s [ NUM_CUDA_BLOCKS ] [THD_PER_BLOCK|p] "
             "[-DATA_SIZE_L2_UNITS|DATA_SIZE_MiB]\n",
             argv[0]);
      exit(1);
    }

  // Collect performance data using a wrapper to NVIDIA CUPTI event
  // counter API.
  //
  NPerf_metric_collect("sm__inst_executed.sum");
  NPerf_metric_collect("gld_efficiency");
  if ( opt_p )
    {
      NPerf_metric_collect
        ("sm__instruction_throughput.avg.pct_of_peak_sustained_elapsed");

      // From L2 cache to L1 cache.
      NPerf_metric_collect("l1tex__m_xbar2l1tex_read_bytes.sum");
      // From L1 cache to L2 cache.
      NPerf_metric_collect("l1tex__m_l1tex2xbar_write_bytes.sum");
      // Between DRAM and L2, reads and writes.
      NPerf_metric_collect("dram__bytes_read.sum");
      NPerf_metric_collect("dram__bytes_write.sum");

      NPerf_metric_collect("l1tex__t_requests.sum");
      NPerf_metric_collect("l1tex__data_bank_conflicts_pipe_lsu.sum");
    }
  //
  // Note: The more metrics that are collected, the more times a kernel
  // will need to be run.

  const size_t max_size_elts = n_elts_max;
  const size_t max_size_bytes = max_size_elts * sizeof( app.l_in_d[0] );
  const size_t overrun_size_bytes = 1024 * sizeof( app.l_in_d[0] );

  struct App_Kernel_Info {
    App_Kernel_Info
    (int idx, Kernel_Info& k,const char *name_base, const char *name,
     int grp_sz, int d_l):
      idx(idx),k_ptr(k.func_ptr), name(name), name_base(name_base),
      grp_size(grp_sz), d_l(d_l){}
    int idx;
    GPU_Info_Func k_ptr;
    const char *name;
    const char *name_base;
    int grp_size, d_l;
  };

  vector<App_Kernel_Info> kernels;

  #define PROCESS_KERNEL(k) \
    { const int idx = kernels.size(); \
      kernels.emplace_back(idx,info.GET_INFO((k)),#k,#k,1,0); }

  #define TDH_GRP_KERNEL3(k,kb,grp_sz,d_l) \
    { const int idx = kernels.size();                   \
      kernels.emplace_back(idx,info.GET_INFO(k),kb,#k,grp_sz,d_l); }

  #define TDH_GRP_KERNEL2(k,grp_sz,d_l) \
    TDH_GRP_KERNEL3((k<d_l,grp_sz>),#k,grp_sz,d_l);

  #define TDH_GRP_KERNEL1(k,grp_sz) \
    TDH_GRP_KERNEL2(k,grp_sz,0); TDH_GRP_KERNEL2(k,grp_sz,4); \
    TDH_GRP_KERNEL2(k,grp_sz,8); TDH_GRP_KERNEL2(k,grp_sz,32);

  #define TDH_GRP_KERNEL(k) \
    TDH_GRP_KERNEL1(k,1); TDH_GRP_KERNEL1(k,2); TDH_GRP_KERNEL1(k,4); \
    TDH_GRP_KERNEL1(k,8); TDH_GRP_KERNEL1(k,16); TDH_GRP_KERNEL1(k,32);

  #define TDH_KERNEL3(k,kb,d_l) \
    { const int idx = kernels.size();                   \
      kernels.emplace_back(idx,info.GET_INFO(k),kb,#k,0,d_l); }

  #define TDH_KERNEL2(k,d_l) TDH_KERNEL3((k<d_l>),#k,d_l);

  #define TDH_KERNEL(k) \
    TDH_KERNEL2(k,0); TDH_KERNEL2(k,4); \
    TDH_KERNEL2(k,8); TDH_KERNEL2(k,32);
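
  //  For reference: TDH_KERNEL(norm_base) below registers the
  //  instantiations norm_base<D_L> for D_L in {0,4,8,32}, and
  //  TDH_GRP_KERNEL(norm_group) registers norm_group<D_L,grp_size> for
  //  the same D_L values and grp_size in {1,2,4,8,16,32}.  A D_L of 0
  //  means the kernel reads d_l from c_app at run time.  For example,
  //  TDH_GRP_KERNEL2(norm_group,4,8) produces an entry for
  //  norm_group<8,4> under base name "norm_group" with grp_size = 4
  //  and d_l = 8.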

  TDH_KERNEL(norm_base);
  TDH_GRP_KERNEL(norm_group);

  vector<string> knames = { "norm_base", "norm_group" };

  map<string, map<int, vector<App_Kernel_Info*> > > kmap;

  for ( auto& k: kernels ) kmap[k.name_base][k.d_l].push_back(&k);

  for ( auto& n: knames ) assert( kmap.contains(n) );

  for ( auto& [name,k_name]: kmap )
    for ( auto& [d_l,kd]: k_name )
      ranges::sort( kd, {}, [](auto *a){ return a->grp_size; } );

  // Allocate storage for CPU copy of data.
  //
  vector<elt_t> l_in(n_elts_max);
  vector<elt_t> l_out(n_elts_max);
  vector<elt_t> l_out_check(n_elts_max);

  // Allocate storage for GPU copy of data.
  //
  CE( cudaMalloc( &app.l_in_d, max_size_bytes + overrun_size_bytes ) );
  CE( cudaMalloc( &app.l_out_d, max_size_bytes + overrun_size_bytes ) );

  // Initialize input array.
  //
  for ( auto& e: l_in ) e = drand48();

  double elapsed_time_s = 86400; // Reassigned to minimum run time.
  const int output_width = stdout_width_get();
  {
    // Prepare events used for timing.
    //
    cudaEvent_t gpu_start_ce, gpu_stop_ce;
    CE(cudaEventCreate(&gpu_start_ce));
    CE(cudaEventCreate(&gpu_stop_ce));

    // Copy input array from CPU to GPU.
    //
    CE( cudaMemcpy
        ( app.l_in_d, l_in.data(), max_size_bytes, cudaMemcpyHostToDevice ) );

    // Launch kernel multiple times and keep track of the best time.
    printf("Launching with %d blocks of up to %d threads. \n",
           num_blocks, thd_per_block_goal);

    for ( int d_l: { 4, 8, 32, 128, 1024 } )
      {
        const int n_l = n_elts_max / d_l;
        const size_t in_size_elts = d_l * n_l;
        const size_t out_size_elts = in_size_elts;
        const size_t in_size_bytes = in_size_elts * sizeof( elt_t );
        const size_t out_size_bytes = in_size_bytes;
        const int64_t num_ops_fp = 2 * in_size_elts;
        const int64_t num_ops_ls = 2 * in_size_elts + out_size_elts;
        const int64_t amt_data_bytes = in_size_bytes + out_size_bytes;
        app.n_l = n_l;
        app.d_l = d_l;

        // Copy App structure to GPU.
        //
        CE( cudaMemcpyToSymbol
            ( c_app, &app, sizeof(app), 0, cudaMemcpyHostToDevice ) );

        for ( int i=0; i<n_l; i++ )
          {
            const auto idxs = views::iota( i*d_l, (i+1)*d_l );
            elt_t sum = 0;
            for ( auto j: idxs ) sum += l_in[j];
            const elt_t avg = sum / d_l;
            for ( auto j: idxs ) l_out_check[j] = l_in[j] - avg;
          }

        for ( auto kname: knames )
          {
            auto& k_name = kmap[kname];
            assert( k_name.contains(0) );
            auto& k_name_dh = k_name.contains(d_l) ? k_name[d_l] : k_name[0];
            vector<App_Kernel_Info*> k_run;
            for ( auto& um: k_name_dh | views::reverse )
              {
                if ( um->grp_size > d_l ) continue;
                k_run.push_back(um);
                if ( k_run.size() == 5 ) break;
              }
            assert( k_run.size() );

        for ( auto ki: k_run | views::reverse )
          {
            const int kernel = ki->idx;
            cudaFuncAttributes& cfa = info.ki[kernel].cfa;
            const auto func_ptr = info.ki[kernel].func_ptr;
            const int wp_limit = cfa.maxThreadsPerBlock >> 5;

            const int thd_limit = wp_limit << 5;
            const int thd_per_block_no_vary = min(thd_per_block_goal,thd_limit);

            pTable table;

            vector<int> n_wps = { 1, 2, 4, 8, 12, 16, 24, 32 };
            while ( n_wps.back() > wp_limit ) n_wps.pop_back();
            if ( !vary_warps ) n_wps.resize(1);
            const int wp_start = n_wps.front();

            for ( int wp_cnt: n_wps )
              {
                const int thd_per_block =
                  vary_warps ? wp_cnt << 5 : thd_per_block_no_vary;

                // Zero the output array.
                //
                CE(cudaMemset(app.l_out_d,0,out_size_bytes));

                // Measure execution time starting "now", which is after the
                // data has been sent to the GPU.
                //
                CE(cudaEventRecord(gpu_start_ce,0));

                // Launch Kernel
                //
                typedef void (*KPtr)(elt_t*, const elt_t*);
                for ( NPerf_data_reset(); NPerf_need_run_get(); )
                  KPtr( info.ki[kernel].func_ptr )
                    <<< num_blocks, thd_per_block >>>(app.l_out_d,app.l_in_d);

                // Stop measuring execution time now, which is before the data
                // is returned from the GPU.
                //
                CE(cudaEventRecord(gpu_stop_ce,0));
                CE(cudaEventSynchronize(gpu_stop_ce));
                float cuda_time_ms = -1.1;
                CE( cudaEventElapsedTime
                    ( &cuda_time_ms, gpu_start_ce, gpu_stop_ce) );

                const double this_elapsed_time_s =
                  NPerf_metrics_collection_get()
                  ? NPerf_kernel_et_get() : cuda_time_ms * 0.001;

                const double thpt_compute_gflops =
                  num_ops_fp / this_elapsed_time_s * 1e-9;
                const double thpt_data_gbps =
                  amt_data_bytes / this_elapsed_time_s * 1e-9;

                const double fp_bw =
                  sizeof(elt_t) == 4 ? info.chip_sp_flops :
                  sizeof(elt_t) == 8 ? info.chip_dp_flops : 1;

                // Number of load/store operations per second.
                const double chip_ls_ops = info.chip_sp_flops / 4;

                const double t_bound_fp_s = num_ops_fp / fp_bw;
                const double t_bound_ls_s = num_ops_ls / chip_ls_ops;
                const double t_bound_insn_s = t_bound_fp_s + t_bound_ls_s;
                const double t_bound_mem_s = amt_data_bytes / info.chip_bw_Bps;
                const double t_bound_l2_s = amt_data_bytes / l2_bandwidth_Bps;

                const double frac_fp =
                  min( 2.0, t_bound_fp_s / this_elapsed_time_s );
                const double frac_insn =
                  min( 2.0, t_bound_insn_s / this_elapsed_time_s );
                const double frac_mem =
                  min( 2.0, t_bound_mem_s / this_elapsed_time_s );
                const double frac_l2 =
                  min( 2.0, t_bound_l2_s / this_elapsed_time_s );

                const double comm_frac =
                  min(2.0,1e9 * thpt_data_gbps / info.chip_bw_Bps);
                const double comm_l2_frac =
                  min(2.0,1e9 * thpt_data_gbps / l2_bandwidth_Bps );
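
                // Worked example with hypothetical numbers: for
                // amt_data_bytes = 64 MiB and chip_bw_Bps = 800 GB/s,
                // t_bound_mem_s ~= 84 us.  If the kernel ran in 168 us,
                // frac_mem ~= 0.5, meaning it moved data at about half
                // the chip's off-chip bandwidth.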

                // Number of warps, rounded up.
                //
                const int num_wps = ( thd_per_block + 31 ) >> 5;

                // The maximum number of active blocks per MP for this
                // kernel when launched with a block size of thd_per_block.
                //
                const int max_bl_per_mp =
                  info.get_max_active_blocks_per_mp(kernel,thd_per_block);

                // Compute number of blocks available per MP based only on
                // the number of blocks.  This may be larger than the
                // number of blocks that can run.
                //
                const int bl_per_mp_available =
                  0.999 + double(num_blocks) / num_mp;

                // The number of active blocks is the minimum of what
                // can fit and how many are available.
                //
                const int bl_per_mp =
                  min( bl_per_mp_available, max_bl_per_mp );

                // Based on the number of blocks, compute the number of warps.
                //
                const int act_wps = num_wps * bl_per_mp;
                const int act_thds_gpu =
                  min( num_mp * act_wps * 32, num_blocks * thd_per_block );
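
                // Worked example with hypothetical numbers: for
                // thd_per_block = 256, num_wps = 8; if num_blocks is
                // twice num_mp, bl_per_mp_available = 2; with
                // max_bl_per_mp = 6, bl_per_mp = min(2,6) = 2 and
                // act_wps = 16 warps resident per MP.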

                if ( wp_cnt == wp_start )
                  printf("\nKernel %s.  Uses %d registers.  n_l %d  d_l %d\n",
                         info.ki[kernel].name,
                         info.ki[kernel].cfa.numRegs, n_l, d_l );

                table.row_start();
                table.entry("wp",num_wps);
                if ( num_blocks > num_mp )
                  table.entry("ac",act_wps);
                table.entry
                  ("Imb","%4.2f",
                   prob_1b_workload_imbalance
                   (info,app,num_blocks,thd_per_block,ki->grp_size));
                table.entry("t/µs","%4.0f", this_elapsed_time_s * 1e6);
                table.entry
                  ("I/el","%4.1f",
                   NPerf_metric_value_get("sm__inst_executed.sum") * 32.0
                   / in_size_elts );

                if ( opt_p )
                  {
                    if ( false )
                    table.entry
                      ("%","%2.0f",
                       NPerf_metric_value_get
                       ("sm__instruction_throughput"
                        ".avg.pct_of_peak_sustained_elapsed") );

                    table.entry
                      ("BXW","%4.1f",
                       NPerf_metric_value_get
                       ("l1tex__data_bank_conflicts_pipe_lsu.sum") /
                       NPerf_metric_value_get ("l1tex__t_requests.sum"));

                    table.header_span_start("L2-Cache");
                    const double data_l2_l1_bytes =
                      NPerf_metric_value_get
                      ("l1tex__m_xbar2l1tex_read_bytes.sum");
                    const double data_l1_l2_bytes =
                      NPerf_metric_value_get
                      ("l1tex__m_l1tex2xbar_write_bytes.sum");

                    table.entry
                      ("N*R", "%4.1f", data_l2_l1_bytes / in_size_bytes );
                    table.entry
                      ("N*W", "%4.1f", data_l1_l2_bytes / out_size_bytes );

                    table.entry
                      ("%pk", "%3.0f",
                       100.0 * ( data_l1_l2_bytes + data_l2_l1_bytes )
                       / ( this_elapsed_time_s * l2_bandwidth_Bps ) );
                    table.entry
                      ("GB/s", "%4.0f",
                       1e-9*
                       ( NPerf_metric_value_get
                         ("l1tex__m_xbar2l1tex_read_bytes.sum")
                         + NPerf_metric_value_get
                         ("l1tex__m_l1tex2xbar_write_bytes.sum") )
                       / this_elapsed_time_s );

                    table.header_span_end();
                    table.header_span_start("DRAM");
                    if ( false )
                    table.entry
                      ("N*RW", "%4.1f",
                       ( NPerf_metric_value_get("dram__bytes_read.sum")
                         + NPerf_metric_value_get("dram__bytes_write.sum") )
                       / ( in_size_bytes + out_size_bytes ) );

                    table.entry
                      ("GB/s","%4.0f",
                       1e-9 *
                       ( NPerf_metric_value_get("dram__bytes_write.sum")
                         + NPerf_metric_value_get("dram__bytes_read.sum") )
                       / this_elapsed_time_s );

                    table.header_span_end();
                  }

                table.entry("FP θ","%4.0f", thpt_compute_gflops);
                // table.entry("GB/s","%4.0f", thpt_data_gbps);

                const size_t max_st_len =
                  max(5, output_width - 1 - table.row_len_get() );
                pStringF fmt("%%-%zds",max_st_len);

                typedef struct { double f; char c; } Elt;
                const bool l2_contained = amt_data_bytes < l2_size_bytes;
                const double frac_data = l2_contained ? frac_l2 : frac_mem;
                vector<Elt> segments =
                  { { frac_fp, '+' }, { frac_insn, '-' }, { frac_data, '*' } };

                string util_hdr = "=== Util: FP++  Insn--  ";
                util_hdr += l2_contained ? "L2** " : "Mem** ";
                if ( max_st_len > util_hdr.length() )
                  util_hdr += string(max_st_len - util_hdr.length(),'=');

                ranges::sort( segments, {}, &Elt::f );
                string bar;
                for ( Elt& e: segments )
                  if ( size_t p = e.f * max_st_len + 0.5; p > bar.length() )
                    bar += string( p - bar.length(), e.c );

                if ( bar.length() > max_st_len )
                  {
                    bar.resize(max_st_len);
                    bar[max_st_len-1] = '>';
                  }
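
                // Illustration with hypothetical fractions: for
                // frac_fp = 0.2, frac_insn = 0.5, and frac_data = 0.9
                // the bar is '+' over the first 20% of its width, '-'
                // from 20% to 50%, and '*' from 50% to 90%.  A trailing
                // '>' marks a bar that would overflow the width.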

                table.entry(util_hdr,fmt, bar, pTable::pT_Left);

                elapsed_time_s = min(this_elapsed_time_s,elapsed_time_s);

                // Copy output array from GPU to CPU.
                //
                CE( cudaMemcpy
                    ( l_out.data(), app.l_out_d, out_size_bytes,
                      cudaMemcpyDeviceToHost ) );
                int err_count = 0;
                elt_t max_err = 0;
                for ( int i=0; i<n_l; i++ )
                  for ( int j=0; j<d_l; j++ )
                    {
                      const int idx = i * d_l + j;

                      const elt_t err = fabs( l_out_check[idx] - l_out[idx] );
                      set_max( max_err, err );
                      if ( err > 1e-5 )
                        {
                          err_count++;
                          if ( err_count < 5 )
                            printf
                              ( "Error at vec %d elt %d: "
                                "%.7f != %.7f (correct)\n",
                                i, j, l_out[idx], l_out_check[idx] );
                        }
                    }
                if ( err_count )
                  printf("Total errors %d, max error %f\n",
                         err_count, max_err );
              }
            printf("%s",table.body_get());
          }}
      }
  }

}