#include <cuda_runtime.h>
#include <gp/cuda-gpuinfo.h>
#include <ptable.h>
#include <nperf.h>
#include <misc.h>
#include <mma.h>
#include <stdio.h>
#include <algorithm>
#include <bit>
#include <limits>
#include <map>
#include <random>
#include <ranges>
#include <string>
#include <type_traits>
#include <vector>
using namespace nvcuda;
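// Integer ceiling division, usable in both host and device code.
// E.g., div_ceil(10,4) == 3.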
template<typename T1, typename T2>
requires is_integral<T1>::value && is_integral<T2>::value
__device__ __host__ constexpr T1
div_ceil( T1 a, T2 b ) { return (a+b-1)/b; }
typedef float elt_t;
struct Shape_Spec {
const int A_nrows, A_ncols, B_ncols;
};
constexpr Shape_Spec shape_specs[] =
{ { 128, 32, 32768 }, { 64, 32, 64 }, { 5120, 2048+16, 4096 } };
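// Per-warp timing record. At kernel start and end the first thread of
// each warp records the clock() value and the ID of the SM it is
// running on; the host reduces these to per-SM execution-time windows.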
typedef uint32_t pClock_t;
constexpr size_t timing_item_size = 16;
struct __align__(timing_item_size) Timing_Item {
pClock_t time_start;
uint32_t smid_start;
pClock_t time_end;
uint32_t smid_end; };
static_assert( timing_item_size == sizeof(Timing_Item) );
struct Timing { Timing_Item *timing_items; };
__constant__ Timing timing_dev;
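// Problem shape and device array pointers in POD form so that the
// structure can be copied to constant memory (shape_dev, below). The
// AT, BT, and CT arrays are transposed copies; the _tf32 and _fp16
// arrays hold reduced-precision copies of A and B for the tensor-core
// kernels.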
class Shape_POD {
public:
size_t A_nrows, A_ncols, B_ncols;
size_t A_bytes, AT_bytes, B_bytes, BT_bytes, CT_bytes;
size_t A_tf32_bytes, AT_tf32_bytes, B_tf32_bytes, BT_tf32_bytes;
size_t A_fp16_bytes, AT_fp16_bytes, B_fp16_bytes, BT_fp16_bytes;
elt_t *A_dev, *AT_dev, *B_dev, *BT_dev, *CT_dev;
float *A_tf32_dev, *AT_tf32_dev, *B_tf32_dev, *BT_tf32_dev;
half *A_fp16_dev, *AT_fp16_dev, *B_fp16_dev, *BT_fp16_dev;
};
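// Return the ID of the SM executing the calling thread, read from the
// %smid special register.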
__device__ uint32_t
smid_get()
{
uint smid = 0;
asm( "mov.u32 %0, %%smid;" : "=r" (smid) );
return smid;
}
__device__ void
timing_start(const int wp)
{
constexpr int wp_sz = 32;
if ( threadIdx.x % wp_sz == 0 )
{
Timing_Item& ti = timing_dev.timing_items[wp];
ti.time_start = clock();
ti.smid_start = smid_get();
}
__syncwarp();
}
__device__ void
timing_end(const int wp)
{
constexpr int wp_sz = 32;
__syncwarp();
if ( threadIdx.x % wp_sz == 0 )
{
Timing_Item& ti = timing_dev.timing_items[wp];
ti.time_end = clock();
ti.smid_end = smid_get();
}
}
__constant__ Shape_POD shape_dev;
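// Transpose the M_nrows x M_ncols row-major matrix M_dev into MT_dev.
// Consecutive threads read consecutive elements of M_dev (coalesced);
// the writes are strided by M_nrows.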
__global__ void M_transpose
( elt_t* __restrict__ MT_dev, const elt_t* __restrict__ M_dev,
int M_nrows, int M_ncols )
{
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = blockDim.x * gridDim.x;
const int M_elts = M_ncols * M_nrows;
for ( int i = tid; i < M_elts; i += num_threads )
{
const int M_col = i % M_ncols;
const int M_row = i / M_ncols;
MT_dev[ M_row + M_col * M_nrows ] = M_dev[ i ];
}
}
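// Baseline kernel computing C = A B with one thread per C element:
// blocks stride over C rows, threads within a block over C columns. A
// nonzero template argument makes the corresponding A dimension a
// compile-time constant; the GNU "x ?: y" extension below falls back to
// the run-time value from shape_dev when the argument is zero.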
template< int t_A_nrows = 0, int t_A_ncols = 0 >
__global__ void mm_simple()
{
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int wp = tid / wp_sz;
const int A_ncols = t_A_ncols ?: ld.A_ncols;
const int A_nrows = t_A_nrows ?: ld.A_nrows;
const int B_ncols = ld.B_ncols;
timing_start(wp);
for ( int C_row = blockIdx.x; C_row < A_nrows; C_row += gridDim.x )
for ( int C_col = threadIdx.x; C_col < B_ncols; C_col += blockDim.x )
{
elt_t C_accum = 0;
for ( int A_col = 0; A_col < A_ncols; A_col++ )
{
elt_t A_elt = ld.A_dev[ C_row * A_ncols + A_col ];
elt_t B_elt = ld.B_dev[ A_col * B_ncols + C_col ];
C_accum += A_elt * B_elt;
}
ld.CT_dev[ C_row + C_col * A_nrows ] = C_accum;
}
timing_end(wp);
}
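// Like mm_simple, but threads are assigned to C elements through a
// single flattened index, so utilization does not depend on how A_nrows
// and B_ncols compare with the grid and block dimensions.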
template< int t_A_nrows = 0, int t_A_ncols = 0 >
__global__ void mm_simple2()
{
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_blocks = gridDim.x;
const int num_threads = blockDim.x * num_blocks;
const int wp = tid / wp_sz;
const int A_ncols = t_A_ncols ?: ld.A_ncols;
const int A_nrows = t_A_nrows ?: ld.A_nrows;
const int B_ncols = ld.B_ncols;
timing_start(wp);
for ( int i = tid; true; i += num_threads )
{
const int C_col = i % B_ncols;
const int C_row = i / B_ncols;
if ( C_row >= A_nrows ) break;
elt_t C_accum = 0;
for ( int A_col = 0; A_col < A_ncols; A_col++ )
{
elt_t A_elt = ld.A_dev[ C_row * A_ncols + A_col ];
elt_t B_elt = ld.B_dev[ A_col * B_ncols + C_col ];
C_accum += A_elt * B_elt;
}
ld.CT_dev[ C_col * A_nrows + C_row ] = C_accum;
}
timing_end(wp);
}
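// Each thread computes a 1 x m_wd tile of C, reusing each loaded A
// element m_wd times. Operands come from the transposed copies:
// consecutive threads handle consecutive C rows, so their AT reads are
// coalesced, while their BT reads usually hit one address (broadcast).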
template< int t_A_nrows = 0, int t_A_ncols = 0,
int m_wd = 1, int m_ht = 1, int m_dp = 1>
__global__ void mm_tile_wd_simple()
{
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_blocks = gridDim.x;
const int num_threads = blockDim.x * num_blocks;
const int wp = tid / wp_sz;
const int A_ncols = t_A_ncols ?: ld.A_ncols;
const int A_nrows = t_A_nrows ?: ld.A_nrows;
const int B_ncols = ld.B_ncols;
timing_start(wp);
for ( int i = tid; true; i += num_threads )
{
const int C_row = i % A_nrows;
const int C_col_0 = i / A_nrows * m_wd;
if ( C_col_0 >= B_ncols ) break;
elt_t C_accums[m_wd]{};
for ( int A_col = 0; A_col < A_ncols; A_col++ )
{
elt_t B_elts[m_wd];
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
if ( const int C_col = C_col_0 + l_wd;
C_col < B_ncols )
B_elts[l_wd] = ld.BT_dev[ C_col*A_ncols + A_col ];
elt_t A_elt = ld.AT_dev[ C_row + A_col * A_nrows ];
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
C_accums[l_wd] += A_elt * B_elts[l_wd];
}
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
if ( const int C_col = C_col_0 + l_wd;
C_col < B_ncols )
ld.CT_dev[ C_col * A_nrows + C_row ] = C_accums[l_wd];
}
timing_end(wp);
}
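// Same m_wd tiling, but the per-tile work is factored into a lambda
// taking a C_col_check flag, so column bounds checks are executed only
// for the partial tiles at the right edge of C and cost nothing on the
// (common) full tiles.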
template< int t_A_nrows = 0, int t_A_ncols = 0,
int m_wd = 1, int m_ht = 1, int m_dp = 1>
__global__ void mm_tile_wd()
{
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_blocks = gridDim.x;
const int num_threads = blockDim.x * num_blocks;
const int wp = tid / wp_sz;
const int A_ncols = t_A_ncols ?: ld.A_ncols;
const int A_nrows = t_A_nrows ?: ld.A_nrows;
const int B_ncols = ld.B_ncols;
timing_start(wp);
auto ii_iter = [&](int C_col_0, int C_row, bool C_col_check)
{
elt_t C_accums[m_wd]{};
for ( int A_col = 0; A_col < A_ncols; A_col++ )
{
elt_t B_elts[m_wd];
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
{
const int C_col = C_col_0 + l_wd;
const bool C_col_invalid = C_col_check && C_col >= B_ncols;
B_elts[l_wd] =
C_col_invalid ? 0 : ld.BT_dev[ C_col*A_ncols + A_col ];
}
elt_t A_elt = ld.AT_dev[ C_row + A_col * A_nrows ];
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
C_accums[l_wd] += A_elt * B_elts[l_wd];
}
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
if ( const int C_col = C_col_0 + l_wd; !C_col_check || C_col < B_ncols )
ld.CT_dev[ C_col * A_nrows + C_row ] = C_accums[l_wd];
};
for ( int i = tid; true; i += num_threads )
{
const int C_row = i % A_nrows;
const int C_col_0 = i / A_nrows * m_wd;
if ( C_col_0 >= B_ncols ) break;
if ( C_col_0 + m_wd <= B_ncols )
ii_iter(C_col_0,C_row,false); else
ii_iter(C_col_0,C_row,true); }
timing_end(wp);
}
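// Each thread computes an m_wd x m_ht tile of C. The tile's m_ht rows
// are spaced sector_len elements apart (32 when m_ht > 1) so that a
// warp's AT loads still cover whole 128-byte sectors. Groups of C
// columns are partitioned statically among blocks, and the ballot/all
// votes keep a warp converged when only some of its threads need bounds
// checks or have run out of work.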
template< int t_A_nrows = 0, int t_A_ncols = 0,
int m_wd = 1, int m_ht = 1, int m_dp = 1>
__global__ void mm_tile_wd_ht()
{
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int wp = tid / wp_sz;
const int num_blocks = gridDim.x;
const int A_ncols = t_A_ncols ?: ld.A_ncols;
const int A_nrows = t_A_nrows ?: ld.A_nrows;
const int B_ncols = ld.B_ncols;
const int n_C_col_groups = div_ceil( B_ncols, m_wd );
const int n_C_cg_per_block = div_ceil( n_C_col_groups, num_blocks );
const int block_C_cg_start = blockIdx.x * n_C_cg_per_block;
const int block_C_cg_stop = block_C_cg_start + n_C_cg_per_block;
const int block_C_col_stop = min( B_ncols, block_C_cg_stop * m_wd );
constexpr int sector_len = m_ht > 1 ? 32 : 1;
constexpr int ch_A_nrows = m_ht * sector_len;
const int n_ch_A = div_ceil( A_nrows, ch_A_nrows );
timing_start(wp);
auto ii_iter = [&]( int C_col_0, int C_row_0,
bool C_col_check, bool C_row_check)
{
elt_t C_accums[m_wd][m_ht]{};
for ( int A_col = 0; A_col < A_ncols; A_col++ )
{
elt_t B_elts[m_wd];
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
{
const int C_col = C_col_0 + l_wd;
const bool C_col_invalid = C_col_check && C_col >= block_C_col_stop;
B_elts[l_wd] =
C_col_invalid ? 0 : ld.BT_dev[ C_col*A_ncols + A_col ];
}
for ( int l_ht = 0; l_ht < m_ht; l_ht++ )
{
const int C_row = C_row_0 + l_ht * sector_len;
const bool C_row_invalid = C_row_check && C_row >= A_nrows;
elt_t A_elt =
C_row_invalid ? 0 : ld.AT_dev[ C_row + A_col * A_nrows ];
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
C_accums[l_wd][l_ht] += A_elt * B_elts[l_wd];
}
}
for ( int l_ht = 0; l_ht < m_ht; l_ht++ )
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
{
const int C_col = C_col_0 + l_wd;
const int C_row = C_row_0 + l_ht * sector_len;
if ( ( !C_col_check || C_col < block_C_col_stop )
&& ( !C_row_check || C_row < A_nrows ) )
ld.CT_dev[ C_col * A_nrows + C_row ] = C_accums[l_wd][l_ht];
}
};
for ( int i = threadIdx.x; true; i += blockDim.x )
{
int ie = i;
const int C_row_lo = ie % sector_len;
ie /= sector_len;
const int C_row_hi = ie % n_ch_A;
ie /= n_ch_A;
const int C_row_0 = C_row_hi * ch_A_nrows + C_row_lo;
const int i_cg_0 = block_C_cg_start + ie;
const int C_col_0 = i_cg_0 * m_wd;
const bool no_work = C_col_0 >= block_C_col_stop || C_row_0 >= A_nrows;
const bool need_check = !no_work
&& ( C_col_0 + m_wd > block_C_col_stop
|| C_row_0 + ch_A_nrows > A_nrows );
const uint32_t wp_mask = 0xffffffff;
const bool d_check = m_wd > 1 || m_ht > 1;
if ( !d_check && C_col_0 >= block_C_col_stop ) break;
const uint32_t we_check =
d_check ? __ballot_sync(wp_mask, need_check) : need_check;
if ( no_work )
{ }
else if ( we_check == 0 )
ii_iter(C_col_0,C_row_0,false,false);
else
ii_iter(C_col_0,C_row_0,true,A_nrows!=ch_A_nrows*n_ch_A);
if ( !d_check ) continue;
const bool im_done = C_col_0 >= block_C_col_stop;
if ( __all_sync(wp_mask,im_done) ) break;
}
timing_end(wp);
}
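// Adds a depth dimension to the tiling: groups of m_dp lanes split the
// dot product over A's columns and combine their partial sums with a
// __shfl_down_sync reduction. The chunk of C covered per block
// iteration (ch_A_nrows x ch_B_ncols) and the factorization of the grid
// into row and column block groups are computed at run time from the
// launch configuration.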
template< int t_A_nrows = 0, int t_A_ncols = 0,
int m_wd = 1, int m_ht = 1, int m_dp = 1>
__global__ void mm_tile_wd_ht_dp()
{
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int lane = threadIdx.x % wp_sz;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int wp = tid / wp_sz;
const int num_blocks = gridDim.x;
const int A_ncols = t_A_ncols ?: ld.A_ncols;
const int A_nrows = t_A_nrows ?: ld.A_nrows;
const int B_ncols = ld.B_ncols;
const int sub_lane = lane % m_dp;
const bool use_AT = m_dp < 8;
const int m_ht_stride_lg =
m_ht == 1 ? 0 :
use_AT ? 5 :
m_dp < 32 ? 3 : 0;
const int m_ht_stride = 1 << m_ht_stride_lg;
const uint ch_A_nrows_min = m_ht > 1 ? m_ht_stride * m_ht : 32;
const int ch_A_nrows_min_lg = bit_width(ch_A_nrows_min-1);
const uint ch_C_nelts = blockDim.x * m_wd * m_ht / m_dp;
const int ch_C_nelts_lg = bit_width(ch_C_nelts-1);
assert( 1 << ch_C_nelts_lg == ch_C_nelts );
constexpr int m_wd_lg = bit_width( unsigned( m_wd ) - 1 );
const int ch_A_nrows_lg =
min( ch_C_nelts_lg - m_wd_lg,
max( ch_A_nrows_min_lg, ch_C_nelts_lg / 2) );
const int ch_A_nrows = 1 << ch_A_nrows_lg;
const int ch_B_ncols = 1 << ( ch_C_nelts_lg - ch_A_nrows_lg );
assert( ch_A_nrows >= m_ht );
assert( ch_B_ncols >= m_wd );
assert( ch_A_nrows * ch_B_ncols == ch_C_nelts );
const int ch_A_nthds = ch_A_nrows / m_ht;
assert( ch_A_nthds * m_ht == ch_A_nrows );
const int ch_B_nthds = ch_B_ncols / m_wd;
assert( ch_A_nthds * ch_B_nthds * m_dp == blockDim.x );
assert( ch_B_nthds * m_wd == ch_B_ncols );
const int n_ch_A_floor = max( 1, A_nrows / ch_A_nrows );
const int n_ch_B_floor = max( 1, B_ncols / ch_B_ncols );
const int num_blocks_sqrt = sqrt(num_blocks) + 0.5f;
int util_max = 0;
int util_max_fac1 = 0, util_max_fac2 = 0;
const bool n_ch_A_smaller = n_ch_A_floor <= n_ch_B_floor;
const int n_ch_X_min = min( n_ch_A_floor, n_ch_B_floor );
const int n_ch_X_max = max( n_ch_A_floor, n_ch_B_floor );
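// Choose a factorization num_blocks = fac1 * fac2 that maximizes the
// number of blocks landing on distinct (row-chunk, column-chunk) pairs.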
for ( int i = 1; i <= num_blocks_sqrt; i++ )
{
const int fac2 = num_blocks / i;
int util = min(i,n_ch_X_min) * min(fac2,n_ch_X_max);
if ( util < util_max ) continue;
util_max = util;
util_max_fac1 = i;
util_max_fac2 = fac2;
}
const int C_row_blk_mod = n_ch_A_smaller ? util_max_fac1 : util_max_fac2;
const int C_col_blk_mod = n_ch_A_smaller ? util_max_fac2 : util_max_fac1;
timing_start(wp);
constexpr int m_pa = m_dp > 1 ? 1 : m_ht * m_wd == 1 ? 4 : 1;
auto ii_iter = [&]( int C_col_0, int C_row_0,
bool C_col_check, bool C_row_check)
{
elt_t C_accums[m_pa][m_wd][m_ht]{};
const uint32_t am = m_dp > 1 ? __activemask() : 0;
for ( int A_col_o = 0; A_col_o < A_ncols; A_col_o += m_dp )
{
const int plane = C_col_check ? 0 : A_col_o % m_pa;
const int A_col = A_col_o + sub_lane;
if ( A_ncols/m_dp*m_dp != A_ncols && A_col >= A_ncols ) break;
elt_t B_elts[m_wd];
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
{
const int C_col = C_col_0 + l_wd;
const bool C_col_invalid = C_col_check && C_col >= B_ncols;
B_elts[l_wd] =
C_col_invalid ? 0 : ld.BT_dev[ C_col*A_ncols + A_col ];
}
for ( int l_ht = 0; l_ht < m_ht; l_ht++ )
{
const int C_row = C_row_0 + l_ht * m_ht_stride;
const bool C_row_invalid = C_row_check && C_row >= A_nrows;
elt_t A_elt =
C_row_invalid ? 0 :
use_AT ? ld.AT_dev[ C_row + A_col * A_nrows ]
: ld.A_dev[ A_col + C_row * A_ncols ];
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
C_accums[plane][l_wd][l_ht] += A_elt * B_elts[l_wd];
}
}
for ( int l_ht = 0; l_ht < m_ht; l_ht++ )
for ( int l_wd = 0; l_wd < m_wd; l_wd++ )
{
elt_t C_accum = C_accums[0][l_wd][l_ht];
for ( int l_pa = 1; l_pa < m_pa; l_pa++ )
C_accum += C_accums[l_pa][l_wd][l_ht];
for ( int dist = 1; dist < m_dp; dist *= 2 )
C_accum += __shfl_down_sync( am, C_accum, dist, m_dp);
const int C_col = C_col_0 + l_wd;
const int C_row = C_row_0 + l_ht * m_ht_stride;
if ( sub_lane == 0
&& ( !C_col_check || C_col < B_ncols )
&& ( !C_row_check || C_row < A_nrows ) )
ld.CT_dev[ C_col * A_nrows + C_row ] = C_accum;
}
};
const int tid_um = threadIdx.x / m_dp;
const int tid_um_lo = tid_um % ch_A_nthds;
const int C_row_lo = tid_um_lo % m_ht_stride;
const int C_row_lmid = ( tid_um_lo - C_row_lo ) * m_ht;
const int C_row_mid = blockIdx.x % C_row_blk_mod;
const int C_row_thd = C_row_lo + C_row_lmid + C_row_mid * ch_A_nrows;
const int C_col_lo = tid_um / ch_A_nthds;
const int C_col_mid = blockIdx.x / C_row_blk_mod;
const int C_col_thd = C_col_lo * m_wd + C_col_mid * ch_B_ncols;
if ( C_col_mid >= C_col_blk_mod ) { timing_end(wp); return; }
assert( C_row_thd < C_row_blk_mod * ch_A_nrows );
assert( C_col_thd < C_col_blk_mod * ch_B_ncols );
const uint32_t wp_mask = 0xffffffff;
const int A_nrows_r =
m_ht > 1
? div_ceil( A_nrows, m_ht_stride * m_ht ) * m_ht_stride * m_ht : A_nrows;
for ( int i = 0; true; i++ )
{
const int C_row_0_raw = C_row_thd + i * C_row_blk_mod * ch_A_nrows;
const int C_row_0 = C_row_0_raw % A_nrows_r;
const int iC_col = C_row_0_raw / A_nrows_r;
const int C_col_0 = C_col_thd + iC_col * C_col_blk_mod * ch_B_ncols;
const bool thd_no_C_col = C_col_0 >= B_ncols;
const bool no_work = C_col_0 >= B_ncols || C_row_0 >= A_nrows;
const bool need_check = !no_work
&& ( C_col_0 + m_wd > B_ncols
|| C_row_0 + ( m_ht - 1 ) * m_ht_stride >= A_nrows );
const bool d_check = true && ( m_wd > 1 || m_ht > 1 );
if ( !d_check && C_col_0 >= B_ncols ) break;
const uint32_t we_check =
d_check ? __ballot_sync(wp_mask, need_check) : need_check;
if ( d_check && !we_check && __all_sync(wp_mask,thd_no_C_col) ) break;
if ( no_work )
{ }
else if ( we_check == 0 )
ii_iter(C_col_0,C_row_0,false,false);
else
ii_iter(C_col_0,C_row_0,true,true);
}
timing_end(wp);
}
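// Convert A, AT, B, and BT to the reduced-precision formats used by the
// tensor-core kernels. __float_to_tf32 returns a float whose mantissa
// has been rounded to TF32 precision; __float2half converts to FP16.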
__global__ void
convert_float_tf32()
{
const Shape_POD& ld = shape_dev;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = blockDim.x * gridDim.x;
const int A_ncols = ld.A_ncols;
const int A_nrows = ld.A_nrows;
const int B_ncols = ld.B_ncols;
const int A_nelts = A_nrows * A_ncols;
const int B_nelts = A_ncols * B_ncols;
using namespace nvcuda::wmma;
for ( int i = tid; i < A_nelts; i += num_threads )
{
ld.A_tf32_dev[i] = __float_to_tf32( ld.A_dev[i] );
ld.AT_tf32_dev[i] = __float_to_tf32( ld.AT_dev[i] );
ld.A_fp16_dev[i] = __float2half( ld.A_dev[i] );
ld.AT_fp16_dev[i] = __float2half( ld.AT_dev[i] );
}
__syncwarp();
for ( int i = tid; i < B_nelts; i += num_threads )
{
ld.B_tf32_dev[i] = __float_to_tf32( ld.B_dev[i] );
ld.BT_tf32_dev[i] = __float_to_tf32( ld.BT_dev[i] );
ld.B_fp16_dev[i] = __float2half( ld.B_dev[i] );
ld.BT_fp16_dev[i] = __float2half( ld.BT_dev[i] );
}
}
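// Tensor-core matrix multiply using the wmma API. Each warp computes a
// row of m_wd_n adjacent 16 x 16 C tiles. AB_precision selects the
// operand format: half uses 16 x 16 x 16 fragments, tf32 uses
// 16 x 16 x 8. A is read row-major and B column-major (from BT), with
// pointers and strides derived from the chosen layouts.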
using wmma::precision::tf32;
template< int t_A_nrows = 0, int t_A_ncols = 0, int t_B_ncols = 0,
int m_wd_n = 1, int m_ht_n = 1, typename AB_precision = half >
__global__ void
mm_tc()
{
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int wp = tid / wp_sz;
const int n_warps = blockDim.x * gridDim.x / wp_sz;
constexpr int A_ncols = t_A_ncols;
constexpr int A_nrows = t_A_nrows;
constexpr int B_nrows = A_ncols;
constexpr int B_ncols = t_B_ncols;
constexpr int C_nrows = A_nrows;
assert( A_nrows == ld.A_nrows );
assert( B_ncols == ld.B_ncols );
using namespace nvcuda::wmma;
using ab_sto_t = helper_traits<AB_precision>::storage_element_type;
using ab_elt_t = helper_traits<AB_precision>::element_type;
constexpr bool h = is_same_v<ab_sto_t,half>;
constexpr int tm = 16; constexpr int tn = 16; constexpr int tk = h ? 16 : 8;
assert( A_nrows % tm == 0 );
assert( B_ncols % tn == 0 );
assert( A_ncols % tk == 0 );
timing_start(wp);
constexpr int C_nrow_tiles = C_nrows / tm;
constexpr int m_wd = tm * m_wd_n;
using aorg = row_major;
using borg = col_major;
constexpr bool a_row_major = is_same_v<aorg,row_major>;
constexpr bool b_row_major = is_same_v<borg,row_major>;
typedef size_t idx_t;
const idx_t a_row_stride = a_row_major ? A_ncols : 1;
const idx_t a_col_stride = a_row_major ? 1 : A_nrows;
const idx_t b_row_stride = b_row_major ? B_ncols : 1;
const idx_t b_col_stride = b_row_major ? 1 : B_nrows;
const idx_t a_stride = max(a_row_stride,a_col_stride);
const idx_t b_stride = max(b_row_stride,b_col_stride);
ab_sto_t* const a_ptr =
a_row_major
? ( h ? (ab_sto_t*)ld.A_fp16_dev : (ab_sto_t*)ld.A_tf32_dev )
: ( h ? (ab_sto_t*)ld.AT_fp16_dev : (ab_sto_t*)ld.AT_tf32_dev );
ab_sto_t* const b_ptr =
b_row_major
? ( h ? (ab_sto_t*)ld.B_fp16_dev : (ab_sto_t*)ld.B_tf32_dev )
: ( h ? (ab_sto_t*)ld.BT_fp16_dev : (ab_sto_t*)ld.BT_tf32_dev );
for ( int i = wp; true; i += n_warps )
{
const ssize_t C_row_0 = i % C_nrow_tiles * tm;
const int C_col_0_raw = i / C_nrow_tiles * m_wd;
if ( C_col_0_raw >= B_ncols ) break;
const bool partial_cols = C_col_0_raw + m_wd > B_ncols;
const ssize_t C_col_0 = partial_cols ? B_ncols - m_wd : C_col_0_raw;
fragment<matrix_a, tm, tn, tk, ab_elt_t, aorg> tile_a;
fragment<matrix_b, tm, tn, tk, ab_elt_t, borg> tile_b;
fragment<accumulator, tm, tn, tk, float> tile_C_acc[m_wd_n];
for ( auto& tca: tile_C_acc ) fill_fragment(tca, 0);
for ( ssize_t i_k = 0; i_k < A_ncols; i_k += tk )
{
load_matrix_sync
( tile_a, a_ptr + C_row_0 * a_row_stride + i_k * a_col_stride,
a_stride );
for ( ssize_t i_wd = 0; i_wd < m_wd_n; i_wd++ )
{
load_matrix_sync
( tile_b, b_ptr + i_k * b_row_stride
+ ( C_col_0 + i_wd * tn ) * b_col_stride,
b_stride );
mma_sync( tile_C_acc[i_wd], tile_a, tile_b, tile_C_acc[i_wd] );
}
}
for ( int i_wd = 0; i_wd < m_wd_n; i_wd++ )
store_matrix_sync
( &ld.CT_dev[ C_row_0 + ( C_col_0 + i_wd * tn ) * A_nrows ],
tile_C_acc[i_wd], A_nrows, mem_col_major );
}
timing_end(wp);
}
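// Explicit specialization of mm_tc for one fixed FP16 problem size,
// with the layout selection folded away. It is used only if exactly
// this instantiation is requested; the dimension constants below must
// match the template arguments.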
template<>
__global__ void
mm_tc< 2048, 2064, 4096, 1, 1, half >()
{
constexpr int A_nrows = 2048, A_ncols = 2064, B_ncols = 4096;
constexpr int C_nrows = A_nrows, C_ncols = B_ncols, B_nrows = A_ncols;
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int wp = tid / wp_sz;
const int n_warps = blockDim.x * gridDim.x / wp_sz;
assert( A_nrows == ld.A_nrows );
constexpr int tm = 16, tk = 16, tn = 16;
assert( A_ncols % tk == 0 );
assert( A_nrows % tm == 0 );
timing_start(wp);
using namespace nvcuda::wmma;
constexpr int C_nrow_tiles = C_nrows / tm;
half* const a_fp16 = &ld.A_fp16_dev[ 0 ];
half* const bt_fp16 = &ld.BT_fp16_dev[ 0 ];
constexpr int n_tiles = C_nrows / tm * C_ncols / tn;
for ( int i = wp; i < n_tiles; i += n_warps )
{
const ssize_t C_row_0 = i % C_nrow_tiles * tm;
const ssize_t C_col_0 = i / C_nrow_tiles * tn;
fragment<matrix_a, tm, tn, tk, half, row_major> tile_a;
fragment<matrix_b, tm, tn, tk, half, col_major> tile_b;
fragment<accumulator, tm, tn, tk, float> tile_C_acc;
fill_fragment(tile_C_acc, 0);
for ( ssize_t i_k = 0; i_k < A_ncols; i_k += tk )
{
load_matrix_sync
( tile_a, a_fp16 + C_row_0 * A_ncols + i_k, A_ncols );
load_matrix_sync
( tile_b, bt_fp16 + C_col_0 * B_nrows + i_k, B_nrows );
mma_sync( tile_C_acc, tile_a, tile_b, tile_C_acc );
}
store_matrix_sync
( &ld.CT_dev[ C_row_0 + C_col_0 * A_nrows ],
tile_C_acc, A_nrows, mem_col_major );
}
timing_end(wp);
}
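// Homework kernel: each warp computes an m_ht_n x m_wd_n array of
// 16 x 16 C tiles, so each A fragment is reused across m_wd_n column
// tiles and each B fragment across m_ht_n row tiles. The warps of a
// block are arranged into a block tile that is kept roughly square (in
// C elements) based on the power-of-two warp count.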
template< int t_A_nrows = 0, int t_A_ncols = 0, int t_B_ncols = 0,
int m_wd_n = 1, int m_ht_n = 1, typename AB_precision = half >
__global__
__launch_bounds__( m_wd_n*m_ht_n >= 16 ? 256 : m_wd_n*m_ht_n >= 4 ? 512 : 0 )
void
mm_hw05()
{
const Shape_POD& ld = shape_dev;
constexpr int wp_sz = 32;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int wp = tid / wp_sz;
constexpr int A_ncols = t_A_ncols;
constexpr int A_nrows = t_A_nrows;
constexpr int B_nrows = A_ncols;
constexpr int B_ncols = t_B_ncols;
constexpr int C_nrows = A_nrows;
constexpr int C_ncols = B_ncols;
assert( A_nrows == ld.A_nrows );
assert( B_ncols == ld.B_ncols );
using namespace nvcuda::wmma;
using ab_sto_t = helper_traits<AB_precision>::storage_element_type;
using ab_elt_t = helper_traits<AB_precision>::element_type;
constexpr bool h = is_same_v<ab_sto_t,half>;
constexpr int tm = 16; constexpr int tn = 16; constexpr int tk = h ? 16 : 8;
assert( A_nrows % tm == 0 );
assert( B_ncols % tn == 0 );
assert( A_ncols % tk == 0 );
timing_start(wp);
using aorg = row_major;
using borg = col_major;
constexpr bool a_row_major = is_same_v<aorg,row_major>;
constexpr bool b_row_major = is_same_v<borg,row_major>;
ab_sto_t* const a_ptr =
a_row_major
? ( h ? (ab_sto_t*)ld.A_fp16_dev : (ab_sto_t*)ld.A_tf32_dev )
: ( h ? (ab_sto_t*)ld.AT_fp16_dev : (ab_sto_t*)ld.AT_tf32_dev );
ab_sto_t* const b_ptr =
b_row_major
? ( h ? (ab_sto_t*)ld.B_fp16_dev : (ab_sto_t*)ld.B_tf32_dev )
: ( h ? (ab_sto_t*)ld.BT_fp16_dev : (ab_sto_t*)ld.BT_tf32_dev );
typedef size_t idx_t;
const idx_t a_row_stride = a_row_major ? A_ncols : 1;
const idx_t a_col_stride = a_row_major ? 1 : A_nrows;
const idx_t b_row_stride = b_row_major ? B_ncols : 1;
const idx_t b_col_stride = b_row_major ? 1 : B_nrows;
const idx_t a_stride = max(a_row_stride,a_col_stride);
const idx_t b_stride = max(b_row_stride,b_col_stride);
constexpr int m_wd = tn * m_wd_n;
constexpr int m_ht = tm * m_ht_n;
static_assert( m_wd % tn == 0 );
static_assert( m_ht % tm == 0 );
constexpr uint warp_tile_nrows = tm * m_ht_n;
constexpr uint warp_tile_ncols = tn * m_wd_n;
constexpr int warp_tile_nrows_bits = bit_width( warp_tile_nrows - 1 );
constexpr int warp_tile_ncols_bits = bit_width( warp_tile_ncols - 1 );
const uint n_wps_p_bl = blockDim.x / wp_sz;
const int n_wps_p_bl_bits = bit_width( n_wps_p_bl - 1 );
const int block_tile_n_wp_rows_bits =
max( ( n_wps_p_bl_bits >> 1 )
- warp_tile_nrows_bits + warp_tile_ncols_bits, 0 );
const int block_tile_n_wp_cols =
max( 1, n_wps_p_bl >> block_tile_n_wp_rows_bits );
const int block_tile_n_wp_rows = n_wps_p_bl / block_tile_n_wp_cols;
assert( n_wps_p_bl == block_tile_n_wp_cols * block_tile_n_wp_rows );
const int block_tile_nrows = warp_tile_nrows * block_tile_n_wp_rows;
const int block_tile_ncols = warp_tile_ncols * block_tile_n_wp_cols;
const int blk_wp_idx = threadIdx.x / wp_sz;
const int C_row_0_offset =
( blk_wp_idx / block_tile_n_wp_cols ) * warp_tile_nrows;
const int C_col_0_offset =
( blk_wp_idx % block_tile_n_wp_cols ) * warp_tile_ncols;
const int n_block_tiles_col = div_ceil( C_ncols, block_tile_ncols );
constexpr bool partial_cols_possible = C_ncols % warp_tile_ncols;
static_assert( !partial_cols_possible );
for ( int i = blockIdx.x; true; i += gridDim.x )
{
const int C_col_0_block = i % n_block_tiles_col * block_tile_ncols;
const int C_row_0_block = i / n_block_tiles_col * block_tile_nrows;
const int C_col_0_raw = C_col_0_block + C_col_0_offset;
const int C_row_0_raw = C_row_0_block + C_row_0_offset;
if ( C_col_0_raw >= C_ncols ) continue;
if ( C_row_0_raw >= C_nrows ) break;
const int C_col_0 = C_col_0_raw % C_ncols;
const int C_row_0 = C_row_0_raw % C_nrows;
fragment<accumulator, tm, tn, tk, float> tile_C_acc[m_ht_n][m_wd_n];
for ( auto& tca: tile_C_acc )
for ( auto& tcb: tca ) fill_fragment(tcb, 0);
for ( ssize_t i_k = 0; i_k < A_ncols; i_k += tk )
{
fragment<matrix_a, tm, tn, tk, ab_elt_t, aorg> tile_a[m_ht_n];
for ( size_t i_ht = 0; i_ht < m_ht_n; i_ht++ )
load_matrix_sync
( tile_a[i_ht],
a_ptr + ( C_row_0 + i_ht * tm ) * a_row_stride
+ i_k * a_col_stride,
a_stride );
for ( size_t i_wd = 0; i_wd < m_wd_n; i_wd++ )
{
fragment<matrix_b, tm, tn, tk, ab_elt_t, borg> tile_b;
load_matrix_sync
( tile_b, b_ptr + i_k * b_row_stride
+ ( C_col_0 + i_wd * tn ) * b_col_stride,
b_stride );
for ( int i_ht = 0; i_ht < m_ht_n; i_ht++ )
mma_sync( tile_C_acc[i_ht][i_wd], tile_a[i_ht], tile_b,
tile_C_acc[i_ht][i_wd] );
}
}
for ( int i_wd = 0; i_wd < m_wd_n; i_wd++ )
for ( ssize_t i_ht = 0; i_ht < m_ht_n; i_ht++ )
store_matrix_sync
( &ld.CT_dev[ C_row_0 + i_ht * tm
+ ( C_col_0 + i_wd * tn ) * A_nrows ],
tile_C_acc[i_ht][i_wd], A_nrows, mem_col_major );
}
timing_end(wp);
}
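// Choose and initialize a CUDA device using the gpu_info utilities from
// gp/cuda-gpuinfo.h, and return its properties.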
GPU_Info
print_gpu_and_kernel_info()
{
GPU_Info info;
gpu_info_print();
int dev = gpu_choose_index();
CE(cudaSetDevice(dev));
printf("Using GPU %d\n",dev);
info.get_gpu_info(dev);
return info;
}
enum AB_Format { AB_elt_t, AB_tf32, AB_bf16, AB_fp16, AB_ENUM_SIZE };
const char* ab_format_txt[] = { "FP32", "TF32", "BF16", "FP16" };
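// Host-side problem instance: allocates host and device storage,
// fills A and B with random values, computes the reference product
// CT_cpu on the CPU (with OpenMP), and records the operation and
// data-volume estimates used by the performance model.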
class Shape : public Shape_POD {
public:
vector<clock_t> clock_start, clock_end;
vector<elt_t> A, AT, B, BT, CT;
vector<elt_t> CT_cpu;
bool verbose;
int64_t est_mm_fp, est_mm_ls_ops, est_mm_l_bytes, est_mm_s_bytes;
Shape(const Shape_Spec& ss)
{
A_nrows = ss.A_nrows;
A_ncols = ss.A_ncols;
B_ncols = ss.B_ncols;
A_dev = nullptr; verbose = false;
setup();
};
void setup()
{
assert( A_nrows && A_ncols && B_ncols );
assert( !A_dev );
size_t A_size = A_nrows * A_ncols;
A.resize(A_size);
AT.resize(A_size);
size_t B_size = A_ncols * B_ncols;
size_t CT_size = A_nrows * B_ncols;
B.resize( B_size );
BT.resize( B_size );
CT_cpu.resize( CT_size );
const int seed = 2735;
default_random_engine re(seed);
uniform_real_distribution<elt_t> uni_pm1(-1,1);
auto rand_pm1 = [&]() { return uni_pm1(re); };
if ( verbose )
printf("Initializing input and weight vectors randomly.\n");
ranges::generate(B,rand_pm1);
ranges::generate(A, [&](){ return uni_pm1(re)/A_ncols; } );
#pragma omp parallel for
for ( size_t C_col = 0; C_col < B_ncols; C_col++ )
for ( size_t C_row = 0; C_row < A_nrows; C_row++ )
{
elt_t C_accum = 0;
for ( size_t A_col = 0; A_col < A_ncols; A_col++ )
C_accum += A[ C_row * A_ncols + A_col ]
* B[ A_col * B_ncols + C_col ];
CT_cpu[ C_row + C_col * A_nrows ] = C_accum;
}
CT.resize( CT_size );
if ( verbose )
printf("Preparing CUDA storage.\n");
# define CMALc(var, n_elts) \
var##_bytes = n_elts * sizeof( var##_dev[0] ); \
CE( cudaMalloc( &var##_dev, var##_bytes ) );
# define CMALC(var) CMALc(var,var.size())
CMALC( A ); CMALC( AT );
CMALC( B ); CMALC( BT );
CMALc( A_tf32, A_size ); CMALc( AT_tf32, A_size );
CMALc( B_tf32, B_size ); CMALc( BT_tf32, B_size );
CMALc( A_fp16, A_size ); CMALc( AT_fp16, A_size );
CMALc( B_fp16, B_size ); CMALc( BT_fp16, B_size );
CMALC( CT );
# undef CMALC
# undef CMALc
CE( cudaMemcpyAsync
( B_dev, B.data(), B_bytes, cudaMemcpyHostToDevice ) );
CE( cudaMemcpyAsync
( A_dev, A.data(), A_bytes, cudaMemcpyHostToDevice ) );
est_mm_fp = A_nrows * A_ncols * B_ncols;
est_mm_s_bytes = CT_bytes;
est_mm_l_bytes = int64_t(A_bytes) + B_bytes;
est_mm_ls_ops = B_ncols * A_nrows * ( A_ncols * 2 + 1 );
}
void launch_prep()
{
CE( cudaMemsetAsync( CT_dev, 0, CT_bytes ) );
Shape_POD for_dev(*this);
CE( cudaMemcpyToSymbol
(shape_dev,&for_dev,sizeof(for_dev),0,cudaMemcpyHostToDevice) );
}
void from_dev_and_check(AB_Format ab_fmt)
{
CE( cudaMemcpy( CT.data(), CT_dev, CT_bytes, cudaMemcpyDeviceToHost ) );
int n_err = 0, n_err_z = 0;
double max_err = 0;
using wmma::precision::tf32;
constexpr int tf32_prec_bits = 10;
constexpr int fp16_prec_bits = 11;
float tol = ( ab_fmt == AB_tf32
? 1.0 / float( 1 << tf32_prec_bits )
: ab_fmt == AB_fp16
? 1.0 / float( 1 << fp16_prec_bits )
: numeric_limits<elt_t>::epsilon() ) * A_ncols * 5;
for ( size_t i=0; i<CT_cpu.size(); i++ )
{
elt_t hd = CT[i];
elt_t hc = CT_cpu[i];
float d = fabs( hd - hc );
if ( d > max_err ) max_err = d;
if ( d > tol )
{
n_err++;
if ( hd == 0 ) n_err_z++;
}
}
if ( !n_err && !verbose ) return;
printf("Number of errors %d out of %zd. %d zeros\n",
n_err, CT_cpu.size(), n_err_z);
printf("Max err = %.9f, tolerance = %.9f\n", max_err, tol );
}
};
int
main(int argc, char **argv)
{
NPerf_init();
const bool opt_use_nperf_time = false;
GPU_Info info = print_gpu_and_kernel_info();
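// Performance-counter metrics to collect through NPerf. Collection may
// require re-running a kernel several times (see the NPerf_need_run_get
// loop below).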
NPerf_metric_collect("sm__inst_executed.sum");
NPerf_metric_collect("sm__inst_executed.max");
NPerf_metric_collect("sm__inst_executed.avg");
NPerf_metric_collect
("sm__instruction_throughput.avg.pct_of_peak_sustained_elapsed");
NPerf_metric_collect
("sm__average_thread_inst_executed_pred_on_per_inst_executed_realtime.pct");
NPerf_metric_collect
("smsp__thread_inst_executed_pred_off.sum");
NPerf_metric_collect
("smsp__thread_inst_executed_pred_on.sum");
NPerf_metric_collect
("smsp__thread_inst_executed.sum");
NPerf_metric_collect
("smsp__thread_inst_executed_per_inst_executed.ratio");
NPerf_metric_collect("l1tex__t_sectors.sum");
NPerf_metric_collect("l1tex__t_requests.sum");
NPerf_metric_collect("l1tex__data_bank_conflicts_pipe_lsu.sum");
NPerf_metric_collect("smsp__sass_l1tex_tags_mem_global.sum");
NPerf_metric_collect("l1tex__m_xbar2l1tex_read_bytes.sum");
NPerf_metric_collect("l1tex__m_l1tex2xbar_write_bytes.sum");
NPerf_metric_collect("dram__bytes_read.sum");
NPerf_metric_collect("dram__bytes_write.sum");
NPerf_metric_collect
("sm__sass_average_branch_targets_threads_uniform.pct");
constexpr int wp_sz = 32;
vector<Shape> shapes;
for ( auto& ls: shape_specs ) shapes.emplace_back(ls);
const int n_shapes = shapes.size();
struct App_Kernel_Info {
App_Kernel_Info
(Kernel_Info& k,const char *name_base, const char *name, int i,
int p_m_wd, int p_m_ht, int thd_p_elt, AB_Format fmt):
k_ptr(k.func_ptr),name_base(name_base), name(name),
shape_idx{i},m_wd(p_m_wd),m_ht(p_m_ht),thd_p_elt(thd_p_elt),
ab_format(fmt) {}
GPU_Info_Func k_ptr;
const char *name_base;
const char *name;
const int shape_idx;
const int m_wd, m_ht, thd_p_elt;
const AB_Format ab_format;
};
vector<App_Kernel_Info> kernels;
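// Register each kernel instantiation to be tested, recording its tiling
// parameters and operand format alongside the function pointer.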
#define EXAMINE_KERNEL(k,kb,sidx,m_wd,m_ht,thd_p_elt,fmt) \
{ const int idx = kernels.size(); \
kernels.emplace_back \
( info.GET_INFO((k)), kb,#k,sidx,m_wd,m_ht,thd_p_elt,fmt); }
#define SPECIFY_KERNEL(k,sidx) \
EXAMINE_KERNEL((k<shape_specs[sidx].A_ncols>),#k,sidx,1,1,1,AB_elt_t);
#define xSPECIFY_KERNEL2(k,sidx,m_wd) \
EXAMINE_KERNEL((k<shape_specs[sidx].A_ncols,m_wd>),#k,sidx,m_wd,1,1,AB_elt_t);
#define SPECIFY_KERNEL2(k,sidx) \
EXAMINE_KERNEL( (k<shape_specs[sidx].A_nrows,shape_specs[sidx].A_ncols>), \
#k,sidx,1,1,1,AB_elt_t);
#define SPECIFY_KERNEL_tc(k,sidx,fmt,m_wd_n,m_ht_n,ty) \
EXAMINE_KERNEL( (k<shape_specs[sidx].A_nrows,shape_specs[sidx].A_ncols, \
shape_specs[sidx].B_ncols, \
m_wd_n, m_ht_n, ty >), \
#k, sidx, m_wd_n, m_ht_n, 0, fmt);
#define SPECIFY_KERNEL4(k,sidx,m_wd,m_ht,thd_p_elt) \
EXAMINE_KERNEL( ( k< shape_specs[sidx].A_nrows, \
shape_specs[sidx].A_ncols, m_wd, m_ht, thd_p_elt > ), \
#k, sidx, m_wd, m_ht, thd_p_elt, AB_elt_t);
#define SPECIALIZE_KERNEL(sidx) \
SPECIFY_KERNEL4(mm_tile_wd_ht_dp,sidx,8,8,4); \
SPECIFY_KERNEL_tc(mm_tc,sidx,AB_fp16,1,1,half); \
SPECIFY_KERNEL_tc(mm_tc,sidx,AB_fp16,2,1,half); \
SPECIFY_KERNEL_tc(mm_tc,sidx,AB_fp16,4,1,half); \
SPECIFY_KERNEL_tc(mm_tc,sidx,AB_fp16,8,1,half); \
SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,1,1,half); \
SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,2,2,half); \
SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,4,1,half); \
SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,2,8,half); \
SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,8,2,half); \
SPECIFY_KERNEL_tc(mm_hw05,sidx,AB_fp16,4,4,half);
SPECIALIZE_KERNEL(0);
SPECIALIZE_KERNEL(2);
#undef SPECIALIZE_KERNEL
const bool want_kernel_info = false;
if ( want_kernel_info )
{
printf("\nCUDA Kernel Resource Usage:\n");
for ( int i=0; i<info.num_kernels; i++ )
{
printf("For %s:\n", info.ki[i].name);
printf(" %6zd shared, %zd const, %zd loc, %d regs; "
"%d max threads per block.\n",
info.ki[i].cfa.sharedSizeBytes,
info.ki[i].cfa.constSizeBytes,
info.ki[i].cfa.localSizeBytes,
info.ki[i].cfa.numRegs,
info.ki[i].cfa.maxThreadsPerBlock);
}
}
const int num_mp = info.cuda_prop.multiProcessorCount;
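// argv[1]: number of blocks (absent or 0: one per SM; negative: blocks
// per SM). argv[2]: warps per block (absent or 0: 32 warps).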
const int arg1_int = argc < 2 ? num_mp : atoi(argv[1]);
const int num_blocks =
arg1_int == 0 ? num_mp :
arg1_int < 0 ? -arg1_int * num_mp : arg1_int;
const int wp_per_block_arg = argc < 3 ? 0 : atoi(argv[2]);
const int wp_per_block_goal =
wp_per_block_arg == 0 ? 32 : wp_per_block_arg;
const int n_threads = num_blocks * wp_per_block_goal * wp_sz;
printf("Kernel timing based on %s\n",
opt_use_nperf_time ? "sm__cycles_elapsed.max"
: "clock() data from kernels.");
for ( int i=0; i<n_shapes; i++ )
{
Shape& l = shapes[i];
printf
("Shape %d: A_nrows=%zd. A_ncols=%zd. B_ncols=%zd\n",
i, l.A_nrows, l.A_ncols, l.B_ncols );
printf
("Shape %d: A: %zu kiB B: %zu kiB\n",
i, l.A_bytes >> 10, l.B_bytes >> 10);
}
if ( n_threads <= 0 )
{
printf("Usage: %s [ NUM_CUDA_BLOCKS ] [WARPS_PER_BLOCK] "
"[COL PER MP]\n",
argv[0]);
exit(1);
}
cudaFuncCache cache_config;
map<cudaFuncCache,const char*> cc_to_string;
#define CCMAP(e) cc_to_string[e]=#e;
CCMAP(cudaFuncCachePreferNone);
CCMAP(cudaFuncCachePreferShared);
CCMAP(cudaFuncCachePreferL1);
CCMAP(cudaFuncCachePreferEqual);
CE( cudaDeviceSetCacheConfig( cudaFuncCachePreferL1 ) );
CE( cudaDeviceGetCacheConfig( &cache_config ) );
printf("Cache preference set to %s.\n",cc_to_string[cache_config]);
const double clock_period_us = 1e3 / info.cuda_prop.clockRate;
const int max_wps = num_blocks * 32;
vector<Timing_Item> timing_items(max_wps);
Timing timing;
const size_t timing_items_bytes =
timing_items.size() * sizeof(timing_items[0]);
CE( cudaMalloc( &timing.timing_items, timing_items_bytes ) );
CE( cudaMemcpyToSymbol
( timing_dev, &timing, sizeof(timing), 0, cudaMemcpyHostToDevice) );
const int output_width = stdout_width_get();
for ( auto& l: shapes )
{
l.launch_prep();
M_transpose<<< num_mp, 1024 >>>(l.AT_dev,l.A_dev,l.A_nrows,l.A_ncols);
M_transpose<<< num_mp, 1024 >>>(l.BT_dev,l.B_dev,l.A_ncols,l.B_ncols);
convert_float_tf32<<< num_mp, 1024 >>>();
}
{
cudaEvent_t gpu_start_ce, gpu_stop_ce;
CE(cudaEventCreate(&gpu_start_ce));
CE(cudaEventCreate(&gpu_stop_ce,cudaEventBlockingSync));
printf("Launching with %d blocks of up to %d warps. \n",
num_blocks, wp_per_block_goal);
for ( auto& aki: kernels )
{
const char* kname = aki.name_base;
const int sidx = aki.shape_idx;
Shape& l = shapes[sidx];
const bool vary_warps = true;
const bool uses_tc_fp16 = aki.ab_format == AB_fp16;
const bool uses_tc_tf32 = aki.ab_format == AB_tf32;
const bool uses_tc = uses_tc_fp16 || uses_tc_tf32;
Kernel_Info* const ki = &info.get_info(aki.k_ptr);
pStringF shape_etc("Shape %zd x %zd x %zd.",
l.A_nrows,l.A_ncols,l.B_ncols);
pStringF tiling_scalar
("wd=%d, ht=%d, dp=%d",aki.m_wd,aki.m_ht,aki.thd_p_elt);
pStringF tiling_tc
("wd=%d ht=%d tiles.",aki.m_wd,aki.m_ht);
pStringF local_txt(", %zd LOCAL", ki->cfa.localSizeBytes );
printf("\nKernel %s, %d regs%s. %s %s %s\n",
kname, ki->cfa.numRegs,
ki->cfa.localSizeBytes ? local_txt.s : "",
shape_etc.s,
uses_tc ? tiling_tc.s : tiling_scalar.s,
ab_format_txt[aki.ab_format]);
pTable table(stdout);
const int wp_limit_kernel = ki->cfa.maxThreadsPerBlock >> 5;
const int wp_limit_prefer = 32;
const int wp_limit = min(wp_limit_kernel,wp_limit_prefer);
const int thd_limit = wp_limit << 5;
const int thd_per_block_no_vary =
min(wp_per_block_goal*wp_sz,thd_limit);
const int wp_start = 4;
const int wp_stop = vary_warps ? wp_limit : wp_start;
for ( int wp_cnt = wp_start; wp_cnt <= wp_stop; wp_cnt *= 2 )
{
const int thd_per_block = wp_cnt << 5;
if ( vary_warps && wp_cnt == 12 && wp_cnt < wp_stop ) continue;
if ( vary_warps && wp_cnt > 16 && wp_cnt < wp_stop ) continue;
if ( vary_warps && wp_cnt > 4 && wp_cnt & 0x3 ) continue;
const int grid_wps = wp_cnt * num_blocks;
const int64_t num_ops_fp = l.est_mm_fp;
const int64_t est_mm_ls_ops =
uses_tc_fp16 ? l.est_mm_ls_ops / 2 : l.est_mm_ls_ops;
const int tc_tile_m = 16;
const int tc_tile_n = 16;
const int ld_reuse_a = uses_tc ? aki.m_wd * tc_tile_n : aki.m_wd;
const int ld_reuse_b = uses_tc ? aki.m_ht * tc_tile_m : aki.m_ht;
const int64_t num_ops_ls =
est_mm_ls_ops / ( 2 * ld_reuse_a ) +
est_mm_ls_ops / ( 2 * ld_reuse_b );
const int64_t est_mm_l_bytes =
uses_tc_fp16 ? l.est_mm_l_bytes / 2 : l.est_mm_l_bytes;
const int64_t amt_data_bytes =
est_mm_l_bytes + l.est_mm_s_bytes;
{
l.launch_prep();
CE( cudaMemset( timing.timing_items, 0, timing_items_bytes ) );
NPerf_metrics_off();
typedef void (*KPtr)();
KPtr(ki->func_ptr) <<< num_blocks, thd_per_block >>>();
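// Sentinel value: CUDA-event timing is not performed here, and this is
// read only if NPerf metrics collection is unavailable.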
float cuda_time_ms = -1.1;
CE( cudaMemcpy
( timing_items.data(), timing.timing_items,
timing_items_bytes, cudaMemcpyDeviceToHost ) );
l.launch_prep();
NPerf_metrics_on();
for ( NPerf_data_reset(); NPerf_need_run_get(); )
KPtr(ki->func_ptr) <<< num_blocks, thd_per_block >>>();
NPerf_metrics_off();
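// Reduce the per-warp timing records to per-SM [earliest start, latest
// end] windows; a warp whose start and end SM IDs differ is counted as
// a migration and excluded.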
map<int32_t,Timing_Item> sm_start_end;
int n_migs = 0;
for ( auto& ti: views::take(timing_items,grid_wps) )
if ( ti.smid_start != ti.smid_end )
{
n_migs++;
}
else
{
auto& tis = sm_start_end[ti.smid_start];
if ( tis.time_start == tis.time_end ) tis = ti;
set_min( tis.time_start, ti.time_start );
set_max( tis.time_end, ti.time_end );
}
if ( n_migs ) printf("-- Number of migrations: %d\n",n_migs);
assert( sm_start_end.size() == num_mp );
int64_t et_sum = 0;
vector<clock_t> et;
for ( auto& [smid,tis]: sm_start_end )
{
const int64_t elapsed = tis.time_end - tis.time_start;
et_sum += elapsed;
et.push_back( elapsed );
}
ranges::sort( et, ranges::greater() );
const double et_clock_max_us = et[0] * clock_period_us;
const double et_clock_avg_us =
et_sum * clock_period_us / sm_start_end.size();
const double nperf_elapsed_time_s =
NPerf_metrics_collection_get()
? NPerf_kernel_et_get() : cuda_time_ms * 0.001;
const double this_elapsed_time_s =
opt_use_nperf_time ? nperf_elapsed_time_s
: et_clock_max_us * 1e-6;
const double imbalance_penalty =
et_clock_avg_us ? et_clock_max_us / et_clock_avg_us : 0.0;
const double imbalance_penaltyi =
NPerf_metric_value_get("sm__inst_executed.max") /
NPerf_metric_value_get("sm__inst_executed.avg");
const double thpt_compute_gflops =
num_ops_fp / this_elapsed_time_s * 1e-9;
const double thpt_data_gbps =
amt_data_bytes / this_elapsed_time_s * 1e-9;
const double chip_ls_ops = info.chip_sp_flops / 4;
const int num_sched_p_sm = 4;
const int tc_flop_p_hmma_tf32 = 16;
const int tc_ii_p_hmma_tf32 = 2;
const int tc_flop_p_hmma_fp16 = 64;
const int tc_ii_p_hmma_fp16 = 4;
const int chip_tc_insn_p_cyc = num_mp * num_sched_p_sm * wp_sz;
const double chip_tc_fp16_flop_p_cyc =
chip_tc_insn_p_cyc * tc_flop_p_hmma_fp16 / tc_ii_p_hmma_fp16;
const double chip_tc_fp16_flops =
chip_tc_fp16_flop_p_cyc * info.clock_freq_hz;
const double chip_tc_tf32_flop_p_cyc =
chip_tc_insn_p_cyc * tc_flop_p_hmma_tf32 / tc_ii_p_hmma_tf32;
const double chip_tc_tf32_flops =
chip_tc_tf32_flop_p_cyc * info.clock_freq_hz;
const float flop_p_insn =
uses_tc_fp16 ? tc_flop_p_hmma_fp16 :
uses_tc_tf32 ? tc_flop_p_hmma_tf32 : 1;
const double t_bound_fp_tc_fp16 = num_ops_fp / chip_tc_fp16_flops;
const double t_bound_fp_tc_tf32 = num_ops_fp / chip_tc_tf32_flops;
const double t_bound_fp_sp = num_ops_fp / info.chip_sp_flops;
const double t_bound_fp =
uses_tc_fp16 ? t_bound_fp_tc_fp16 :
uses_tc_tf32 ? t_bound_fp_tc_tf32 : t_bound_fp_sp;
const double t_bound_ls = num_ops_ls / chip_ls_ops;
const double t_bound_insn = t_bound_fp + t_bound_ls;
{
const double comp_frac = t_bound_insn / this_elapsed_time_s;
const double bw_frac =
1e9 * thpt_data_gbps / info.chip_bw_Bps;
const double fp_frac = t_bound_fp / this_elapsed_time_s;
const int num_wps = ( thd_per_block + 31 ) >> 5;
const int max_bl_per_mp =
ki->get_max_active_blocks_per_mp(thd_per_block);
const int bl_per_mp_available =
0.999 + double(num_blocks) / num_mp;
const int bl_per_mp =
min( bl_per_mp_available, max_bl_per_mp );
const int act_wps = num_wps * bl_per_mp;
pTable_Row row(table);
table.entry("wp",num_wps);
if ( num_blocks > num_mp )
table.entry("ac",act_wps);
if ( NPerf_metrics_collection_get() )
{
double dram_rd_bytes =
NPerf_metric_value_get("dram__bytes_read.sum");
double dram_wr_bytes =
NPerf_metric_value_get("dram__bytes_write.sum");
double l2_rd_bytes =
NPerf_metric_value_get
("l1tex__m_xbar2l1tex_read_bytes.sum");
double l2_wr_bytes =
NPerf_metric_value_get
("l1tex__m_l1tex2xbar_write_bytes.sum");
table.header_span_start("Insn");
const double normalized_ipi_factor =
32.0 / num_ops_fp * flop_p_insn;
table.entry
("/itr","%5.2f",
NPerf_metric_value_get("sm__inst_executed.sum")
* normalized_ipi_factor );
table.entry
("% ","%2.0f",
NPerf_metric_value_get
("sm__instruction_throughput"
".avg.pct_of_peak_sustained_elapsed") );
const double pred_on_frac =
NPerf_metric_value_get
("smsp__thread_inst_executed_pred_on.sum") /
NPerf_metric_value_get("smsp__thread_inst_executed.sum");
const double diverg_frac =
NPerf_metric_value_get
("smsp__thread_inst_executed_per_inst_executed.ratio")
/ 32.0;
const double lane_eff = pred_on_frac * diverg_frac;
pStringF eff_str("%2d%s", int(lane_eff * 100 + 0.5),
diverg_frac < pred_on_frac ? "d" : "p" );
const bool show_TAc = false;
if ( show_TAc )
table.entry("TAc","%3s",
lane_eff > .99 ? "100" : eff_str.s);
table.header_span_end();
table.header_span_start("L1");
table.entry
("SW","%2.0f",
NPerf_metric_value_get("l1tex__t_sectors.sum") /
NPerf_metric_value_get ("l1tex__t_requests.sum"));
if ( 0 )
table.entry
("TW","%2.0f",
NPerf_metric_value_get
("smsp__sass_l1tex_tags_mem_global.sum") /
NPerf_metric_value_get ("l1tex__t_requests.sum"));
table.entry
("BXW","%4.1f",
NPerf_metric_value_get
("l1tex__data_bank_conflicts_pipe_lsu.sum") /
NPerf_metric_value_get ("l1tex__t_requests.sum"));
table.header_span_end();
if ( true ) {
table.header_span_start(" L1<->L2 ");
table.entry
("N-Rd","%4.0f", l2_rd_bytes / est_mm_l_bytes );
table.entry
("N-Wr","%4.1f", l2_wr_bytes / l.est_mm_s_bytes );
table.entry
("GB/s","%4.0f",
1e-9 * ( l2_rd_bytes + l2_wr_bytes )
/ ( this_elapsed_time_s ) );
table.header_span_end(); }
}
const bool imbalance_more_time =
imbalance_penalty > imbalance_penaltyi;
const double imbalance =
max( imbalance_penalty, imbalance_penaltyi );
pStringF imb_textp
("%2.0f%s",
( imbalance - 1 ) * 10, imbalance_more_time ? "t" : "i");
table.entry("Imb","%3s", imb_textp.s);
table.entry( "t/µs", "%6.0f", this_elapsed_time_s * 1e6 );
if ( 0 )
table.entry("FP θ","%4.0f", thpt_compute_gflops);
if ( false )
table.entry("GB/s","%4.0f", thpt_data_gbps);
const size_t max_st_len =
max(5, output_width - 1 - table.row_len_get() );
pStringF fmt("%%-%zds",max_st_len);
string fp_unit = uses_tc ? "TC" : "FP";
string util_hdr = "=== Util: " + fp_unit + "++ Insn-- ";
if ( max_st_len > util_hdr.length() )
util_hdr += string(max_st_len - util_hdr.length(),'=');
typedef struct { double f; char c; } Elt;
vector<Elt> segments =
{ { fp_frac, '+' }, { comp_frac, '-' }, { 0*bw_frac, '*' } };
ranges::sort( segments, {}, &Elt::f );
string bar;
for ( Elt& e: segments )
if ( size_t p = e.f * max_st_len + 0.5; p > bar.length() )
bar += string( p - bar.length(), e.c );
if ( bar.length() > max_st_len )
{
bar.resize(max_st_len);
bar[max_st_len-1] = '>';
}
table.entry(util_hdr,fmt, bar, pTable::pT_Left);
}
l.from_dev_and_check(aki.ab_format);
}
}
}
}
return 0;
}