Advanced NVIDIA CUDA Kernel Optimization Techniques: Handwritten PTX

As accelerated computing continues to drive application performance in all areas of AI and scientific computing, there's a renewed interest in GPU optimization techniques to ensure applications obtain the best possible performance. As an application developer, there are many ways to program GPUs, up and down the software stack. In this post, we introduce some of the different levels of the stack and dive into the very lowest: handwriting Parallel Thread Execution (PTX) code.

Accelerated computing software stack

Today, you can do a lot with a GPU without writing GPU-specific code. The low-level work has already been done for you by library developers and software engineers. For example, you can work at a high level in the stack, building complete AI workflows using blueprints. Or you can develop applications in a framework such as PyTorch, where you specify your model, and the appropriate GPU code and libraries are automatically used to execute your program.

You can also develop your application by using the full suite of NVIDIA CUDA-X libraries, which include domain-specific libraries spanning quantum computing, data processing, physics AI, gene sequencing, edge computing, drug discovery, and everything in between. If these domain-specific libraries don't contain all the functionality you need, you can program GPUs with compiler directives such as OpenACC, standard language parallelism such as the C++ stdpar algorithms, and the GPU-enabled C++ standard library libcu++.

In all the situations mentioned, you aren't writing GPU-specific code; you're relying on libraries or compiler directives that have been carefully designed, implemented, and optimized by expert engineers. However, there may be cases where you have to implement your own GPU code because a library doesn't currently exist for the functionality you need. Then, you can move further down the stack and write CUDA GPU code directly in high-level languages such as C++, Fortran, and Python. Finally, in rare cases, developers may choose to go deeper and write the extremely performance-sensitive portions of their code in PTX directly.

As with most performance optimization techniques, the more control you desire, the lower you need to go in the stack to extract performance. This tradeoff should be considered carefully: in addition to added development and debugging complexity, performance gains from hand-written low-level code may not port to other GPU architectures.

As we showed in an earlier post, PTX is the assembly language of GPUs. Writing PTX directly is a highly advanced optimization technique that is not necessary for most developers and should be considered a tool of last resort. Nevertheless, there are situations where the fine-grained control enabled by writing PTX directly yields performance improvements in specific applications. These situations are typically in very performance-sensitive portions of an application where every fraction of a percent of performance improvement has significant benefits. All of the available PTX instructions are documented in the PTX ISA. In this blog post, we'll dig deeper into an example where hand-written PTX is used to improve the performance of an important algorithm that appears in certain AI model implementations.

Writing PTX

Before we get into the example, we'll list some ways you can include handwritten PTX code in your application, that is, how to do it in principle. The example that follows demonstrates a real-world scenario and shows the resulting performance changes.

Inline PTX

One standard way to include PTX in your code is to use inline PTX. This is the method we'll show below, and detailed information about its syntax and semantics appears in the documentation. It is very similar to writing inline assembly code on a CPU.
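As a minimal sketch of what inline PTX looks like (our own illustration, not code from the CUTLASS example discussed later), the following CUDA C++ device function wraps a single fma.rn.f32 PTX instruction:

#include <cstdio>

// A minimal sketch of inline PTX in CUDA C++ (an illustration for this post):
// compute d = a * b + c with one round-to-nearest fused multiply-add instruction.
__device__ float fma_ptx(float a, float b, float c) {
  float d;
  asm("fma.rn.f32 %0, %1, %2, %3;"   // the PTX instruction template
      : "=f"(d)                      // output operand %0 (a .f32 register)
      : "f"(a), "f"(b), "f"(c));     // input operands %1, %2, %3
  return d;
}

__global__ void demo() {
  // Each thread computes threadIdx.x * 2 + 1 through the PTX path.
  printf("thread %d: %f\n", threadIdx.x, fma_ptx((float)threadIdx.x, 2.0f, 1.0f));
}

int main() {
  demo<<<1, 4>>>();
  cudaDeviceSynchronize();
  return 0;
}

The "=f" and "f" constraints tell the compiler which registers to allocate for the operands, much as inline assembly constraints do on a CPU.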
cuda::ptx namespace

Another option for including PTX in your code is to use libcu++, which provides the cuda::ptx namespace containing functions that map directly to PTX instructions. This makes it easy to use specific PTX instructions within a C++ application. For more information, see the cuda::ptx namespace documentation.

CUTLASS example

To illustrate how one might hand-write PTX code, we'll use a specific example from linear algebra. In general, if your operation can be expressed as a linear algebra operation, such as a GEMM, then NVIDIA cuBLAS is the recommended way to run it on GPUs. cuBLAS is already highly optimized for many sizes and shapes of matrices, and offers multiple numerical precisions to select from.

Sometimes what you want to do isn't fully expressed by the functionality in cuBLAS, or you want to do computations directly before or after a GEMM. Rather than calling some functions, then cuBLAS, and then more functions, you can sometimes improve performance by fusing the other operations with the GEMM operation. Fusing kernels has many benefits because it potentially enables further optimizations, such as using data more efficiently.

This is where the NVIDIA CUTLASS library comes in. CUTLASS includes a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. Because it enables more control and customization around GEMM and GEMM-like operations, CUTLASS requires a bit more code from the developer than cuBLAS. CUTLASS includes a good deal of hand-written PTX because it's designed to extract the best possible performance from each GPU architecture. This makes CUTLASS a ready example for illustrating handwritten PTX in action.

GEMM plus top_k and softmax

The particular operation we'll demonstrate is a fusion of a GEMM with the top_k and softmax algorithms. This is a common operation when running a mixture-of-experts neural network. We'll focus on the NVIDIA Hopper architecture. Because this is a commonly used operation, CUTLASS already has a special kernel for it with some inline PTX, so it'll be straightforward to demonstrate how CUTLASS incorporates hand-written PTX into its high-performance GPU code.

For this post, we use:
- CUTLASS version 3.9.2
- NVIDIA GH200 GPU
- Driver version 570.140
- CUDA Toolkit version 12.8

Following the build instructions on the CUTLASS website, we pass the build option -DCUTLASS_NVCC_ARCHS=90a to cmake to ensure that the full feature set of the Hopper architecture is enabled. The CUTLASS repository has many examples that demonstrate various capabilities on the latest architectures. Once cmake is finished, we navigate to the build directory (for example, build/examples/61_hopper_gemm_with_topk_and_softmax) where the example code is built and run. We execute make, and the executable is built and ready to run.

The application accepts a few different options as inputs, including the matrix sizes m, n, and k, the error tolerance epsilon, and the number of iterations to run to generate the benchmark numbers in GFlop/s. In this benchmark, analogous to LLM execution, m is the number of tokens, n is the number of experts, k is the embedding dimension of the experts, and the value for top_k is 2 (hard-coded in the test code).
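Conceptually, for each of the m rows of the GEMM output, the fused epilogue keeps only the top_k largest of the n expert scores and applies a softmax over them. The following unfused CUDA C++ sketch of that per-token step was written for this post as an illustration; it assumes non-selected scores are masked to zero and ignores ties, and it is not the CUTLASS implementation:

#include <cmath>

// Rough per-token reference for the top_k + softmax epilogue (an illustration
// written for this post, not CUTLASS code). scores holds the n expert scores
// for one token; out receives the softmax over the top_k largest scores, with
// the remaining entries set to zero.
__host__ __device__ void topk_softmax_row(const float* scores, float* out,
                                          int n, int top_k) {
  // Track the top_k largest scores in descending order (n is small, e.g. 8).
  float best[8];                              // assumes top_k <= 8
  for (int i = 0; i < top_k; ++i) best[i] = -INFINITY;
  for (int j = 0; j < n; ++j) {
    float v = scores[j];
    for (int i = 0; i < top_k; ++i) {         // insert v, shifting smaller values down
      if (v > best[i]) { float t = best[i]; best[i] = v; v = t; }
    }
  }
  // Masked softmax: exponentiate only the scores that made the cut, zero the rest.
  float threshold = best[top_k - 1];
  float sum = 0.0f;
  for (int j = 0; j < n; ++j) {
    float e = (scores[j] >= threshold) ? expf(scores[j] - best[0]) : 0.0f;
    out[j] = e;
    sum += e;
  }
  for (int j = 0; j < n; ++j) out[j] /= sum;
}

Fusing this step into the GEMM epilogue avoids writing the full m x n score matrix to global memory and reading it back in a second kernel.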
We obtain the following output by choosing m=1024, n=8 (the default), k=4096, iterations=1000000, and an epsilon of 1e-4:

$ ./61_hopper_gemm_with_topk_and_softmax --m=1024 --k=4096 --iterations=1000000 --eps=1e-4
  Disposition: Passed
  Relative error: 1.52478e-05
  Problem Size: 1024x8x4096x1
  Avg runtime: 0.011765 ms
  GFLOPS: 5704.11

In this benchmark example, the performance is 5,704 GFlop/s. We vary the number of tokens (the m parameter) up to 16,384 and generate the following table of performance.

     m    GFlop/s
 1,024      5,704
 2,048      9,551
 4,096     14,569
 8,192     19,794
16,384     21,476

Table 1. Performance of the benchmark code, which includes the use of inline PTX for the top_k and softmax functions

Removing the inline PTX

This benchmark example fuses GEMM with top_k and softmax, and calls functions that use inline PTX for the top_k function, provided the value of k is either 2 or 4. (Note: this is a different k than the matrix dimension k above.) It also uses inline PTX for the softmax function under certain conditions. Additionally, both the top_k and softmax functions have fallback routines, written in CUDA C++, for when those specific conditions aren't met. It's straightforward to change the internal functions for top_k and softmax to comment out the calls to the PTX functions and run the fallback CUDA C++ code instead. This will enable us to quantify the value of hand-written PTX in this case.

To remove the inline PTX from this example, we edited the file cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp to comment out the use of the inline PTX functions. Near the top of this file, you'll see some functions written with inline PTX whose names start with top_2 and top_4. As an example, here's the first PTX function you'll encounter:

CUTLASS_DEVICE
Array<float, 2> top_2_reduce_scalar(Array<float, 2> a, float scalar) {
  Array<float, 2> out;
  asm volatile(
      "{\n"
      "  .reg .f32 mx;\n"
      "  .reg .pred p;\n"
      "  max.f32 mx, %3, %4;\n"
      "  setp.gtu.f32 p, %2, %4;\n"
      "  selp.f32 %1, mx, %2, p;\n"
      "  selp.f32 %0, %2, %4, p;\n"
      "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar));
  return out;
}

There's no need to understand every detail of this code. The main point is to show what a short inline PTX function looks like. Further down in the source code, you'll also see a softmax function. These are the PTX functions we'll omit to see how the performance changes.
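For readers who do want to see what top_2_reduce_scalar computes, here is our own reading of the PTX above rewritten in plain CUDA C++ (a sketch for this post, not CUTLASS code): given a pair a sorted in descending order and a new value b, it produces the updated top-2 pair without any branches.

// Our reading of the top_2_reduce_scalar PTX above, expressed in CUDA C++
// (a sketch for this post, not CUTLASS code). a must be sorted descending;
// NaN handling differs slightly because setp.gtu treats unordered as true.
__device__ void top_2_reduce_scalar_ref(float a[2], float b) {
  float mx   = fmaxf(a[1], b);     // max.f32 mx, %3, %4
  bool  keep = a[0] > b;           // setp.gtu.f32 p, %2, %4  (does the current max survive?)
  float out1 = keep ? mx : a[0];   // selp.f32 %1, mx, %2, p
  float out0 = keep ? a[0] : b;    // selp.f32 %0, %2, %4, p
  a[0] = out0;
  a[1] = out1;
}

The predicate and select instructions make this branch-free, which is exactly what the generic C++ fallback shown next replaces with a loop and branching.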
In that same source file, you can find where those functions are called as part of if statements. We simply comment out the if statements that call the inline PTX functions and leave the else portion of each statement. This omits the calls to the inline PTX functions and executes the code written in CUDA C++ instead. For example, there's a function called add_element_to_desc_sorted_array, which in turn calls either a top_2 or top_4 PTX function if k=2 or k=4, respectively, or otherwise calls a C++ implementation of that algorithm. The code of this function is:

void add_element_to_desc_sorted_array(cutlass::Array<Element, N>& a, Element b) {
  if constexpr (N == 2 && is_same_v<Element, float>) {
    a = top_2_reduce_scalar(a, b);
  }
  else if constexpr (N == 4 && is_same_v<Element, float>) {
    a = top_4_reduce_scalar(a, b);
  }
  else {
    // slower generic path with branching, slower, and can cause register spill
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < N; ++k) {
      if (a[k] < b) {
        // Shift down
        CUTLASS_PRAGMA_UNROLL
        for (int l = N - 1; l > k; --l) {
          a[l] = a[l-1];
        }
        a[k] = b;
        break;
      }
    }
  }
}

To determine the effect of the hand-written PTX functions, we comment out the calls to those PTX functions and only allow the code to execute the C++ version of the algorithm, as follows:

void add_element_to_desc_sorted_array(cutlass::Array<Element, N>& a, Element b) {
  /* BEGIN COMMENT
  if constexpr (N == 2 && is_same_v<Element, float>) {
    a = top_2_reduce_scalar(a, b);
  }
  else if constexpr (N == 4 && is_same_v<Element, float>) {
    a = top_4_reduce_scalar(a, b);
  }
  else {
  END COMMENT */
    // slower generic path with branching, slower, and can cause register spill
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < N; ++k) {
      if (a[k] < b) {
        // Shift down
        CUTLASS_PRAGMA_UNROLL
        for (int l = N - 1; l > k; --l) {
          a[l] = a[l-1];
        }
        a[k] = b;
        break;
      }
    }
  //} COMMENT THE END OF THE ELSE
}

We make analogous changes to the functions merge_desc_sorted_arrays and masked_softmax to remove their if/else statements, eliminating the handwritten PTX functions top_2_reduce_scalar, top_4_reduce_scalar, top_2_reduce, top_4_reduce, and fast_masked_softmax from this example. The following are the performance results.

     m    GFlop/s
 1,024      4,998
 2,048      8,376
 4,096     13,267
 8,192     17,885
16,384     20,066

Table 2. Performance of the benchmark code without inline PTX for the top_k and softmax functions

When you compare these results to the results in Table 1, you'll see that performance improves by between 7% and 14% when the handwritten PTX is used instead of the CUDA C++ code. The take-home message is not the specific performance increases shown here, but rather that in certain carefully chosen situations, handwriting PTX can result in performance improvements. A careful analysis of the performance and portability tradeoffs should be undertaken to determine whether including handwritten PTX in your application is worthwhile.

This example code has been highly optimized, and we chose it because it contains handwritten PTX, written by NVIDIA CUTLASS engineers, that shows a demonstrable performance improvement. The example reinforces the guidance that, in the vast majority of cases, developers should leave the handwriting of PTX to the developers of CUTLASS, cuBLAS, and other GPU libraries, and instead build on the foundations of these libraries.

Summary

In this post, we showed an example of how CUTLASS uses handwritten PTX to improve the performance of a very specific fused GEMM operation used in some AI models. We don't want to give the impression that every developer should write PTX. The vast majority of developers shouldn't ever need to, and writing PTX by hand should be a tool of last resort. That being said, handwriting PTX is a technique available to all developers. It is an advanced and specialized technique that, when used appropriately, can be another tool in the advanced GPU programmer's toolbox.
It is one of the great strengths of the CUDA platform that developers can consume the NVIDIA stack at whatever level is appropriate for them, from the application level all the way down to writing assembly code (PTX), and everywhere in between.

Acknowledgments

Thanks to the following NVIDIA contributor: Ali Hassani