/// LSU EE 7722 (Spring 2025), GPU Programming
//

 /// Simple CUDA Example, without LSU ECE helper classes.


/// References
//
// :ccpg: CUDA C Programming Guide Version 12.8
//        https://docs.nvidia.com/cuda/cuda-c-programming-guide

#if 0
/// Background

 // The following background describes the view of NVIDIA GPUs
 // provided by the CUDA system. "View" means how the hardware
 // appears to the application programmer. For CUDA, this should
 // be very close to the actual hardware.


 /// Compute Capability (CC)
 //
 //  An NVIDIA numbering system that identifies the approximate
 //  capabilities of the hardware.
 //
 //  Compute Capabilities
 //
 //  -- CC 1.0, 1.1, 1.2, 1.3
 //     Tesla. (Not to be confused with the Tesla product line.)
 //     Now considered obsolete (the CC, not the product line).
 //
 //  -- CC 2.0, 2.1
 //     Fermi
 //     Now considered obsolete.
 //
 //  -- CC 3.0, 3.2, 3.5, 3.7
 //     Kepler
 //     Outdated.
 //
 //  -- CC 5.2
 //     Maxwell
 //     Outdated.
 //
 //  -- CC 6.0, 6.1, 6.2
 //     Pascal
 //     Product cycle ending.
 //     If cost is no object, good double-precision and half-precision perf.
 //     Some support for machine learning: 16-bit floats.
 //
 //  -- CC 7.0
 //     Volta
 //     Not used much for graphics.
 //     If cost is no object, good double-precision and half-precision perf.
 //     Machine learning support.
 //
 //  -- CC 7.5
 //     Turing
 //     Ray tracing support.
 //     Machine learning support.
 //
 //  -- CC 8.0
 //     Ampere
 //     Sparse matrix support (intended for machine learning).
 //
 //  -- CC 8.9
 //     Ada
 //     Machine learning support: FP8 data types (E4M3, E5M2).
 //
 //  -- CC 9.0
 //     Hopper
 //     Machine learning support: add/min/max (e.g., for ReLU).
 //     Block clustering.
 //     Tensor data movement.
 //
 //  -- CC 10.0
 //     Blackwell
 //     Dynamic range management, fine-grain scaling.
 //     FP4.


 /// Kernels
 //
 //  :Def: Kernel
 //    A procedure that executes on the GPU.
 //    Entry point is a __global__ procedure.
 //    "I launched a kernel to multiply two 1000 by 1000 matrices."
 //
 //  :Example:

     __global__ void my_kernel(float *c)
     { c[ threadIdx.x + blockIdx.x * blockDim.x ]++; }

 //
 //  :Def: Kernel Launch
 //    The initiation of execution of CUDA code.
 //    Done by a CUDA API call.
 //    Specify:
 //      The name of the CUDA C procedure to start. (E.g., my_kernel();)
 //      The grid size.  (The number of blocks.)
 //      The block size. (Number of threads per block.)
 //
 //  :Example:

     const int thd_per_block = 1024;
     const dim3 num_blocks(2, 4, 7);
     my_kernel<<< num_blocks, thd_per_block >>>(c);

 //
 //  :Def: Launch Configuration
 //    The block size and grid size used for a kernel launch.
 //    Choosing the correct launch configuration is very important.
 //
 //
 /// Launch Configuration Criteria
 //
 //  - Number of blocks is a multiple of the number of SMs.
 //  - Number of threads per block is a multiple of the warp size (32).


 /// CUDA Thread Organization
 //
 //  :Def: Thread
 //    Similar to the definition of a thread on a CPU.
 //    A path of execution through the kernel.
 //    Each Thread:
 //      Has its own id local to the block.
 //      The id consists of a thread index, in variable threadIdx, ..
 //      .. and a block index, in variable blockIdx.
 //    "My kernel consists of 16384 threads."
 //
 //  :Def: Block
 //    A grouping of threads.
 //    The number of threads in a block is called the block size ..
 //    .. its value is in variable blockDim.
 //    "My kernel has a block size of 1024 threads."
 //
 //  :Def: Block Cluster
 //    Note: CUDA 12, CC 9.0 and later.
 //    A grouping of blocks.
 //
 //  :Def: Grid
 //    A collection of blocks.
 //    The grid size is specified in the kernel launch.
 //    "My kernel consists of 16 blocks of 1024 threads each."
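 //
 //  :Example:
 //    A sketch (not part of the original example set) showing how the
 //    variables above combine. The built-ins threadIdx, blockIdx,
 //    blockDim, and gridDim are CUDA's; the kernel name, tid, and n
 //    are illustrative.

     __global__ void show_ids(int *ids, int n)
     {
       // Globally unique thread ID: the thread's index within its block
       // plus the offset of the block within the grid.
       const int tid = threadIdx.x + blockIdx.x * blockDim.x;
       // Total number of threads in the grid.
       const int num_threads = blockDim.x * gridDim.x;
       // Grid-stride loop: covers n elements regardless of launch size.
       for ( int i=tid; i<n; i += num_threads ) ids[i] = tid;
     }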
 //
 //  :Def: Warp
 //    A group of threads that (usually) execute together ..
 //    .. meaning one instruction is fetched for all threads in the warp ..
 //    .. if threads have branched to different paths, only the threads
 //    .. on the fetched path are active, the others are inactive.
 //    For all NV GPUs so far the warp size is 32 threads (2025).
 //    One day the warp size may change, but it has been 32 through CC 9.0.
 //    "I chose my block size to be a multiple of the warp size."
 //

 /// Hardware Organization
 //
 //  :Def: Streaming Multiprocessor (SM, SMX, MP)
 //    The hardware to execute a block of threads.
 //    In class called a multiprocessor (the word streaming omitted) for short.
 //    Roughly akin to a core in a CPU.
 //    High-performance GPUs might have about 128 SMs.
 //
 /// Each block is assigned to a particular SM.
 //    All threads in a block execute on the same multiprocessor.
 //    Threads within a block share shared memory.
 //
 //  "Uh-oh, my new GPU has 50 SMs. I hope my code can launch enough
 //   blocks to keep them all busy."
 //
 //
 //  :Def: Functional Unit
 //    A piece of hardware that can perform a particular set of operations.
 //    Typical GPU and non-GPU examples:
 //      Integer ALU: Can perform operations such as add, sub, AND, OR.
 //      Integer multiply.
 //      FP add, mul, madd.
 //      FP div, sqrt, trig.
 //    NVIDIA GPU Units:
 //      FP32 (CUDA Core): Can perform most single-precision non-divide FP.
 //      FP64: Double-precision.
 //      INT32: Common integer operations.
 //      Special Func Unit: reciprocal, reciprocal square root, approx trig.
 //      Load / Store: Read and write from memory.
 //      Tensor Core: Part of matrix / matrix multiply.


 /// Global Memory Access
 //
 //  :Sample: mval = a[tid];
 //
 //  Important rule:
 //
 //    Consecutive threads should access consecutive data items.
 //    As in:  mval = a[ tid ];         // Good. ☺
 //    NOT:    mval = a[ tid * 1000 ];  // BAD.  ☹
 //
 //  Size of contiguous chunks (accessed by consecutive threads)
 //  should be a multiple of 32 bytes.
 //

 /// Possible Locations of Global Data
 //
 //  - Off-Chip Global Memory
 //    Requires about 400 cycles to obtain data.
 //    Subject to off-chip BW limit.
 //    BW limit in RTX 4090: 930 GB/s (measured, see microbenchmarks).
 //
 //  - Level 2 Cache (Not caché, please.)
 //    Size varies, about 73 MiB in an RTX 4090.
 //    Requires about 250 cycles to obtain data.
 //    Much higher BW limit: 4875 GB/s RTX 4090, measured.
 //
 //  - Read-Only Cache (Texture Cache)
 //    Size varies, 16 - 32 kiB per SM.
 //    Available in some CCs, and only for certain code.
 //
 //  - Level 1 Cache
 //    Available in CC 2.x and CC 7.0 and later.
 //    Latency about 36 cycles in an RTX A4000.
 //    Varies, 32 - 64 kiB per SM.
 //    BW Limit: 42500 GB/s RTX 4090, measured.
 //
 //

 /// Memory Requests
 //
 //  How Things Work (CC 3.x to 6.x)
 //
 //  - Threads in a warp execute a load instruction. E.g., mval = a[tid];
 //
 //  - Hardware coalesces these loads based on address into
 //    contiguous *requests* of size 32, 64, or 128 B.
 //
 //  - Requests are sent to the L2 cache, and if necessary, off-chip storage.
 //
 //  - Dependent instructions can execute when requests return.
 //
 //  Implications
 //
 //    Bandwidth consumed is determined by request size ..
 //    .. not by how much data is actually needed.
 //
 //    Possible slowdown with a larger number of requests ..
 //    .. so 10 128-B requests are better than 40 32-B requests ..
 //    .. even though they transfer the same total amount of data.
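 //
 //  :Example:
 //    Rough request arithmetic implied by the rules above (illustrative,
 //    using the 32-B minimum request size):
 //
 //      mval = a[ tid ];         // 32 threads x 4 B = 128 B, contiguous ..
 //                               // .. one 128-B request serves the warp.
 //      mval = a[ tid * 1000 ];  // Addresses 4000 B apart ..
 //                               // .. 32 separate 32-B requests (1024 B
 //                               // .. moved) to deliver 128 B of data.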
 ///
#endif

#include <pthread.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <ctype.h>
#include <time.h>
#include <math.h>
#include <new>

#include <cuda_runtime.h>
#include <gp/cuda-gpuinfo.h>

struct App
{
  int num_threads;
  int array_size;

  float4 *h_in;        // Host address space, data in.
  float *h_out;        // Host address space, data out.
  float *h_out_check;  // Compute correct answer on CPU, to check GPU.

  float4 *d_in;        // Device address space, data in.
  float *d_out;        // Device address space, data out.
};

// In host address space.
App app;

// In device constant address space.
__constant__ App d_app;


///
/// GPU Code (Kernels)
///

__global__ void
kmain_efficient()
{
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  if ( tid >= d_app.num_threads ) return;

  for ( int h=tid; h<d_app.array_size; h += d_app.num_threads )
    {
      float4 p = d_app.d_in[h];  // Good: Consecutive access.
      float sos = p.x * p.x + p.y * p.y + p.z * p.z + p.w * p.w;
      d_app.d_out[h] = sos;
    }
}

__global__ void
kmain_simple()
{
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  if ( tid >= d_app.num_threads ) return;

  const int elt_per_thread =
    ( d_app.array_size + d_app.num_threads - 1 ) / d_app.num_threads;

  /// WARNING: Don't assign work to threads this way.
  const int start = elt_per_thread * tid;  // Bad: Non-consecutive access.
  const int stop = start + elt_per_thread;

  for ( int h=start; h<stop; h++ )
    {
      float4 p = d_app.d_in[h];  // Bad: Non-consecutive access.
      float sos = p.x * p.x + p.y * p.y + p.z * p.z + p.w * p.w;
      d_app.d_out[h] = sos;      // Bad: Non-consecutive access.
    }
}

__global__ void
kmain_tuned()
{
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  if ( tid >= d_app.num_threads ) return;

  const int strip_len = 16;
  // Data "strip" is 32 threads wide and strip_len threads long.

  const int wp_sz = 32;        // Warp size.
  const int wp = tid / wp_sz;  // This thd's warp number within kernel. (0-)
  const int ln = tid % wp_sz;  // This thd's lane number within warp. (0-31)

  const int start = wp * wp_sz * strip_len + ln;

  for ( int h=start; h<d_app.array_size; h += strip_len * d_app.num_threads )
    {
      float soses[strip_len];
      for ( int i=0; i<strip_len; i++ )
        {
          float4 p = d_app.d_in[ h + i * wp_sz ];
          soses[i] = p.x * p.x + p.y * p.y + p.z * p.z + p.w * p.w;
          //
          // Note: We expect the compiler to emit strip_len load
          // instructions, and after those do strip_len sum-of-squares
          // calculations. By performing those strip_len loads before
          // any calculation, the strip_len load instructions can be
          // overlapped. Overlapping is key. Suppose you needed to
          // change sixteen light bulbs. Would you order one light
          // bulb from Amazon, wait for it to arrive, install it, then
          // order the next light bulb?
        }
      for ( size_t i=0; i<strip_len; i++ )
        d_app.d_out[ h + i * wp_sz ] = soses[i];
    }
}


GPU_Info
print_gpu_and_kernel_info()
{
  GPU_Info info;

  gpu_info_print();

  // Choose GPU 0 because it's usually the better choice.
  //
  int dev = gpu_choose_index();
  CE(cudaSetDevice(dev));
  printf("Using GPU %d\n",dev);
  info.get_gpu_info(dev);

  info.GET_INFO(kmain_simple);
  info.GET_INFO(kmain_efficient);
  info.GET_INFO(kmain_tuned);

  // Print information about kernel.
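  //
  // The values printed below are taken from each kernel's
  // cudaFuncAttributes (stored in info.ki[i].cfa): statically
  // allocated shared memory per block, user constant memory, local
  // memory per thread, registers per thread, and the largest block
  // size with which the kernel can be launched.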
  //
  printf("\nCUDA Kernel Resource Usage:\n");

  for ( int i=0; i<info.num_kernels; i++ )
    {
      printf("For %s:\n", info.ki[i].name);
      printf("  %6zd shared, %zd const, %zd loc, %d regs; "
             "%d max threads per block.\n",
             info.ki[i].cfa.sharedSizeBytes,
             info.ki[i].cfa.constSizeBytes,
             info.ki[i].cfa.localSizeBytes,
             info.ki[i].cfa.numRegs,
             info.ki[i].cfa.maxThreadsPerBlock);
    }

  return info;
}


///
/// Main Routine
///

int
main(int argc, char **argv)
{
  // Get info about GPU and each kernel.
  //
  GPU_Info info = print_gpu_and_kernel_info();

  // Get number of multiprocessors. (A.k.a. streaming multiprocessors or SMs.)
  //
  const int num_sm = info.cuda_prop.multiProcessorCount;

  // Examine argument 1, block count, default is the number of SMs.
  //
  const int arg1_int = argc < 2 ? num_sm : atoi(argv[1]);
  const int num_blocks =
    arg1_int == 0 ? num_sm :
    arg1_int < 0  ? -arg1_int * num_sm : arg1_int;

  // Examine argument 2, number of threads per block.
  //
  const int thd_per_block_arg = argc < 3 ? 1024 : atoi(argv[2]);
  const int thd_per_block_goal =
    thd_per_block_arg == 0 ? 1024 : thd_per_block_arg;
  const int num_threads = num_blocks * thd_per_block_goal;

  // If true, run kernels at multiple block sizes.
  //
  const bool vary_warps = thd_per_block_arg == 0;

  // Examine argument 3, size of array in MiB. Default is 2^20 per SM.
  //
  app.array_size = argc < 4 ? num_sm << 20 : int( atof(argv[3]) * (1<<20) );

  if ( num_threads <= 0 || app.array_size <= 0 )
    {
      printf("Usage: %s [ NUM_CUDA_BLOCKS ] [THD_PER_BLOCK] "
             "[DATA_SIZE_MiB]\n", argv[0]);
      exit(1);
    }

  const size_t in_size_bytes = app.array_size * sizeof( app.h_in[0] );
  const size_t out_size_bytes = app.array_size * sizeof( app.h_out[0] );
  const size_t overrun_size_bytes = num_blocks * 1024 * sizeof( app.h_in[0] );

  // Allocate storage for CPU copy of data.
  //
  app.h_in = new float4[app.array_size];
  app.h_out = new float[app.array_size];
  app.h_out_check = new float[app.array_size];

  // Allocate storage for GPU copy of data.
  //
  CE( cudaMalloc( &app.d_in,  in_size_bytes  + overrun_size_bytes ) );
  CE( cudaMalloc( &app.d_out, out_size_bytes + overrun_size_bytes ) );

  printf("Array size: %d 4-component vectors.\n", app.array_size);

  // Initialize input array.
  //
  for ( int i=0; i<app.array_size; i++ )
    for ( int j=0; j<4; j++ )
      ((float*)&app.h_in[i])[j] = drand48();

  // Compute correct answer.
  //
#pragma omp parallel for
  for ( int i=0; i<app.array_size; i++ )
    {
      float4 p = app.h_in[i];
      app.h_out_check[i] = p.x * p.x + p.y * p.y + p.z * p.z + p.w * p.w;
    }

  /// Compute Expected Computation and Communication
  //
  // Number of multiply/add operations. Ignore everything else.
  //
  const int64_t num_ops = 4 * app.array_size;  // Multiply-adds.
  //
  // Amount of data in and out of GPU chip.
  //
  const size_t amt_data_bytes = in_size_bytes + out_size_bytes;

  {
    // Prepare events used for timing.
    //
    cudaEvent_t gpu_start_ce, gpu_stop_ce;
    CE(cudaEventCreate(&gpu_start_ce));
    CE(cudaEventCreate(&gpu_stop_ce));

    // Copy input array from CPU to GPU.
    //
    CE( cudaMemcpy
        ( app.d_in, app.h_in, in_size_bytes, cudaMemcpyHostToDevice ) );

    // Launch each kernel, possibly at several block sizes, and report
    // the execution time of each run.
    printf("Launching with %d blocks of up to %d threads. \n",
           num_blocks, thd_per_block_goal);

    for ( int kernel = 0; kernel < info.num_kernels; kernel++ )
      {
        cudaFuncAttributes& cfa = info.ki[kernel].cfa;
        const int wp_limit = cfa.maxThreadsPerBlock >> 5;

        const int thd_limit = wp_limit << 5;
        const int thd_per_block_no_vary = min(thd_per_block_goal,thd_limit);

        const int wp_start = 1;
        const int wp_stop = vary_warps ? wp_limit : wp_start;
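        // Sweep over block sizes: when vary_warps is set, wp_cnt runs
        // from one warp (32 threads) up to this kernel's warp limit;
        // otherwise the loop below executes once, at the requested
        // block size.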
        const int wp_inc = 1;

        for ( int wp_cnt = wp_start; wp_cnt <= wp_stop; wp_cnt += wp_inc )
          {
            const int thd_per_block =
              vary_warps ? wp_cnt << 5 : thd_per_block_no_vary;

            // When there are more than four warps ..
            // .. limit the number of warps to a multiple of four.
            //
            if ( vary_warps && wp_cnt > 4 && wp_cnt & 0x3 ) continue;

            app.num_threads = thd_per_block * num_blocks;

            // Copy App structure to GPU.
            //
            CE( cudaMemcpyToSymbol
                ( d_app, &app, sizeof(app), 0, cudaMemcpyHostToDevice ) );

            // Zero the output array.
            //
            CE( cudaMemset(app.d_out,0,out_size_bytes) );

            // Measure execution time starting "now", which is after data
            // has been sent to the GPU.
            //
            CE( cudaEventRecord(gpu_start_ce,0) );

            typedef void (*KPtr)();

            /// Launch Kernel
            //
            KPtr(info.ki[kernel].func_ptr)
              <<< num_blocks, thd_per_block >>>();

            // Stop measuring execution time now, which is before data
            // is returned from the GPU.
            //
            CE( cudaEventRecord(gpu_stop_ce,0) );
            CE( cudaEventSynchronize(gpu_stop_ce) );
            float cuda_time_ms = -1.1;
            CE( cudaEventElapsedTime(&cuda_time_ms,gpu_start_ce,gpu_stop_ce) );

            const double this_elapsed_time_s = cuda_time_ms * 0.001;

            const double thpt_compute_gflops =
              num_ops / this_elapsed_time_s * 1e-9;
            const double thpt_data_gbps =
              amt_data_bytes / this_elapsed_time_s * 1e-9;

            if ( vary_warps )
              {
                const char* const stars =
                  "********************************************************************************";
                const int stars_len = 80;
                const double comp_frac =
                  4e9 * thpt_compute_gflops / info.chip_sp_flops;
                const double bw_frac = 1e9 * thpt_data_gbps / info.chip_bw_Bps;
                const bool graph_bw = true;
                const double frac = graph_bw ? bw_frac : comp_frac;
                const int max_st_len = 43;

                // Number of warps, rounded up.
                //
                const int num_wps = ( thd_per_block + 31 ) >> 5;

                // The maximum number of active blocks per SM for this
                // kernel when launched with a block size of thd_per_block.
                //
                const int max_bl_per_sm =
                  info.get_max_active_blocks_per_mp(kernel,thd_per_block);

                // Compute the number of blocks available per SM based only
                // on the number of blocks launched. This may be larger than
                // the number of blocks that can run at one time.
                //
                const int bl_per_sm_available =
                  0.999 + double(num_blocks) / num_sm;

                // The number of active blocks is the minimum of what
                // can fit and how many are available.
                //
                const int bl_per_sm = min( bl_per_sm_available, max_bl_per_sm );

                // Based on the number of blocks, compute the number of warps.
                //
                const int act_wps = num_wps * bl_per_sm;

                if ( wp_cnt == wp_start )
                  printf("Kernel %s:\n", info.ki[kernel].name);

                printf("%2d %2d wp %6.0f µs %3.0f GF %3.0f GB/s %s\n",
                       num_wps, act_wps,
                       this_elapsed_time_s * 1e6,
                       thpt_compute_gflops, thpt_data_gbps,
                       &stars[max(0,stars_len-int(frac*max_st_len))]);
              }
            else
              {
                printf("K %-15s %2d wp %11.3f µs %8.3f GFLOPS %8.3f GB/s\n",
                       info.ki[kernel].name,
                       (thd_per_block + 31 ) >> 5,
                       this_elapsed_time_s * 1e6,
                       thpt_compute_gflops, thpt_data_gbps);
              }

            // Copy output array from GPU to CPU.
            //
            CE( cudaMemcpy
                ( app.h_out, app.d_out, out_size_bytes, cudaMemcpyDefault) );

            int err_count = 0;
            for ( int i=0; i<app.array_size; i++ )
              {
                if ( fabs( app.h_out_check[i] - app.h_out[i] ) > 1e-5 )
                  {
                    err_count++;
                    if ( err_count < 5 )
                      printf("Error at vec %d: %.7f != %.7f (correct)\n",
                             i, app.h_out[i], app.h_out_check[i] );
                  }
              }
            if ( err_count )
              printf("Total errors %d\n", err_count);
          }
      }
  }
}
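
///
/// Usage Sketch
///
//  Illustrative command lines based on the argument handling in main;
//  the executable name "simple-cuda" is a placeholder for whatever the
//  build produces.
//
//    ./simple-cuda            One block per SM, 1024 threads per block.
//    ./simple-cuda 0 0        Block count equal to the number of SMs; a
//                             thread count of 0 requests a sweep over
//                             block sizes (1 to 4 warps, then multiples
//                             of 4 warps, up to each kernel's limit).
//    ./simple-cuda -2 256 64  Two blocks per SM, 256 threads per block,
//                             and an array of 64 * 2^20 four-component
//                             vectors.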