Fixing LLM Inference Bottlenecks with Custom CUDA Kernels (Profiling)
Table of Contents (Blog Series)
- Part 1: Profiling (this post)
- Part 2: Simple Fused Kernel (Coming soon)
- Part 3: Advanced Fused Kernel (Coming soon)
1. Introduction
The primary focus of many popular open-source frameworks’ model implementations (like 🤗 transformers) is on readability, flexibility, and ease of use rather than on fast inference. As such, though a model implementation (e.g., openai-community/gpt2) may be easy to read and understand, its inference is not exactly optimized, at least when you run it out of the box without packages like Hugging Face Accelerate or vLLM.
Luckily for us, this provides a learning opportunity for profiling a model’s inference performance and then fixing bottlenecks by writing custom fused kernels.
In this series of blog posts, I aim to explain, from first principles, why one would ever need to write fused kernels, and how to write one and integrate it into a model to get noticeable inference speedups.
All that is assumed of you is some understanding of transformer internals, as well as a system with CUDA installed (in order to write and compile your own CUDA kernels). You should also have GPU-enabled PyTorch and 🤗 transformers installed.
If you are stuck in endless error hell when trying to install CUDA on Linux, you can check out my blog on Installing CUDA and CUDA toolkit on Ubuntu the right way.
2. What is a fused kernel?
A fused kernel is a high-performance GPU kernel that merges multiple operations into a single kernel to minimize memory movement. But why would we want to fuse operations into one monolithic kernel?
Well, for that you need to understand memory-bound vs compute-bound operations:
- Memory-Bound operations: Performance of such operations is limited by how fast the data can be moved between the HBM (High Bandwidth Memory, the main memory of the GPU) and registers/cache.
- Compute-Bound operations: Performance of such operations is limited by how fast the processor (the GPU) can do math.
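A common back-of-the-envelope way to tell the two apart is arithmetic intensity: FLOPs performed per byte moved to or from HBM. If an op's intensity is below the GPU's peak FLOP/s divided by its memory bandwidth (the "ridge point" of the roofline model), it is memory bound. The sketch below is an idealized calculation assuming FP32 data and that every operand is read from or written to HBM exactly once; PEAK_FLOPS and PEAK_BW are assumed, roughly A100-class figures rather than measurements, so substitute your own GPU's specs.

```python
# Idealized roofline check: FP32 data, every operand read/written exactly once.
# PEAK_FLOPS and PEAK_BW are assumed A100-class figures; replace with your GPU's specs.
PEAK_FLOPS = 19.5e12  # assumed peak FP32 throughput, FLOP/s
PEAK_BW = 1.5e12      # assumed HBM bandwidth, bytes/s
BYTES = 4             # bytes per FP32 element

RIDGE = PEAK_FLOPS / PEAK_BW  # FLOP/byte needed before an op becomes compute bound


def elementwise_add_intensity(n: int) -> float:
    """c = a + b over n elements: n FLOPs; reads 2n elements and writes n."""
    return n / (3 * n * BYTES)


def matmul_intensity(m: int, k: int, n: int) -> float:
    """C[m, n] = A[m, k] @ B[k, n]: 2*m*k*n FLOPs; reads A and B, writes C."""
    flops = 2 * m * k * n
    bytes_moved = (m * k + k * n + m * n) * BYTES
    return flops / bytes_moved


def classify(intensity: float) -> str:
    return "compute bound" if intensity > RIDGE else "memory bound"


cases = {
    "element-wise add (1M elements)": elementwise_add_intensity(1_000_000),
    "1024 x 1024 x 1024 matmul": matmul_intensity(1024, 1024, 1024),
    "1 x 768 @ 768 x 768 (one token)": matmul_intensity(1, 768, 768),
}
print(f"ridge point: {RIDGE:.1f} FLOP/byte")
for label, ai in cases.items():
    print(f"{label:34s} {ai:8.2f} FLOP/byte -> {classify(ai)}")
```

Under these assumptions the element-wise add sits far below the ridge point and the large square matmul far above it. Note that the single-token matmul is memory bound too: each weight is read once but used for very little math.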
In deep learning inference, many operations (especially element-wise ones such as activations, normalization and residual additions) are memory bound: more time is spent transferring intermediate tensors in and out of HBM than on the actual computation.
Hence, fused kernels are important: ideally, they load data from memory only once (or at least fewer times than a non-fused implementation), keep intermediate results on-chip in registers or shared memory, and write only the final result back, so the operation becomes less memory bound (or not memory bound at all).
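A rough way to quantify the benefit is to count HBM traffic for a chain of element-wise ops. Unfused, each op runs as its own kernel that reads its input from HBM and writes its output back; fused, the input is read once and only the final result is written. The sketch below assumes FP32, unary ops, and no cache reuse between kernels (a simplification); the three-op chain on a 1000 x 3072 activation (GPT-2's FFN width) is a hypothetical example.

```python
BYTES = 4  # FP32


def unfused_traffic(n_elements: int, n_ops: int) -> int:
    """Bytes moved when each of n_ops kernels reads n elements and writes n elements."""
    return n_ops * 2 * n_elements * BYTES


def fused_traffic(n_elements: int, n_ops: int) -> int:
    """Bytes moved by one fused kernel: read the input once, write the result once.

    n_ops does not change the traffic: intermediates stay in registers.
    """
    return 2 * n_elements * BYTES


# Hypothetical chain of 3 element-wise ops on a 1000 x 3072 activation
# (ignoring small extra inputs such as a bias vector).
n = 1000 * 3072
ops = 3
print(f"unfused: {unfused_traffic(n, ops) / 1e6:.1f} MB")
print(f"fused:   {fused_traffic(n, ops) / 1e6:.1f} MB")
```

For a memory-bound chain, runtime scales roughly with bytes moved, so cutting traffic by a factor of n_ops is where the speedup of a fused kernel comes from.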
LLM inference frameworks like vLLM use fused kernels extensively for faster inference.
If you want to learn more about memory- vs compute-bound operations along with transformer inference arithmetic, check out this excellent blog post by kipply.
3. Setup + Identifying Bottlenecks
Before you can write a fused kernel, you must first identify the existing bottlenecks in your model.
For this and the following posts, we will be using the openai-community/gpt2 implementation from 🤗 transformers. Of course, more optimized implementations of this model exist, but from a pedagogical perspective, this one is perfect for our inference autopsy.
There are many ways of identifying the bottlenecks of a transformers model (my personal preference being NVIDIA Nsight Systems, nsys); we will, however, be using PyTorch’s built-in profiler so that you can follow along even without owning a GPU, for example on a GPU runtime in Google Colab.
3.1 Profiling GPT2 to identify bottlenecks
Import all necessary packages and initialize the model and dummy input text.
import torch
from transformers import AutoModel, AutoTokenizer
device = "cuda:0"
model_name = "gpt2"
model = AutoModel.from_pretrained(model_name).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# ~1,000-token input sequence for better profiling (GPT-2's context window is 1,024 tokens)
text = "A quick brown fox jumped upon a lazy dog." * 100
inputs = tokenizer(text, return_tensors="pt").to(device)
Once the model and tokenizer have been downloaded and the input has been tokenized, we “warm up” the GPU by running inference on the model a few times.
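# Warm up: run several forward passes so one-time setup costs are paid up front
with torch.no_grad():
    for _ in range(20):
        model(**inputs)
# CUDA launches are asynchronous; wait for all warm-up work to finish
torch.cuda.synchronize()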
Why this warm-up? There are several reasons for it, including:
- GPUs dynamically adjust clock speeds and power usage. If the GPU has been idle for a while, it may still be in a low-power state with reduced clock speeds during the first few runs, which would skew our profiling. Warm-up makes sure the GPU is running at its full performance state.
- The first GPU call triggers CUDA context creation (which is slow) and memory pool setup for the allocator. This one-time cost can add hundreds of milliseconds, which would skew our profiling had we run it without warm-up.
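To see the effect warm-up guards against, you can time each call individually and compare the first one with the rest. The helper below is a minimal sketch, not part of any library: measure_latencies and fake_forward are hypothetical names. Because CUDA kernel launches are asynchronous, on a GPU you would pass sync=torch.cuda.synchronize so the clock stops only after the GPU work finishes; the runnable demo uses a CPU-only stand-in whose first call sleeps to mimic a one-time setup cost.

```python
import time
from typing import Callable, List, Optional


def measure_latencies(
    fn: Callable[[], object],
    n_runs: int,
    sync: Optional[Callable[[], None]] = None,
) -> List[float]:
    """Return the wall-clock latency (in seconds) of each of n_runs calls to fn.

    CUDA kernel launches are asynchronous, so on a GPU pass
    sync=torch.cuda.synchronize; otherwise the timer only measures the launch.
    """
    latencies = []
    for _ in range(n_runs):
        if sync is not None:
            sync()  # make sure earlier GPU work does not leak into this measurement
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for this call's GPU work to finish
        latencies.append(time.perf_counter() - start)
    return latencies


# Hypothetical CPU-only stand-in for a model call: the first call pays a
# one-time setup cost, loosely mimicking CUDA context creation and allocator setup.
_setup_done = False


def fake_forward() -> int:
    global _setup_done
    if not _setup_done:
        time.sleep(0.05)  # simulated one-time initialization cost
        _setup_done = True
    return sum(range(10_000))  # simulated steady-state work


lat = measure_latencies(fake_forward, n_runs=5)
print(f"first call: {lat[0] * 1e3:.2f} ms, slowest later call: {max(lat[1:]) * 1e3:.2f} ms")
```

With the GPT-2 setup, a call such as measure_latencies(lambda: model(**inputs), 20, sync=torch.cuda.synchronize) inside torch.no_grad(), run in a fresh process, should show a similar first-call gap; its size depends on your GPU and driver.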
Once the warm-up is done, we will perform the actual profiling using the profiler module of the torch package (torch.profiler):
with torch.profiler.profile(record_shapes=True) as prof:
    with torch.no_grad():
        _ = model(**inputs)
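After profiling, we can print the averaged results as a table sorted by cuda_time_total (the total time spent in CUDA kernels by each op, including its child ops, summed over all of its calls), since we are only interested in identifying the GPU bottlenecks here. Passing group_by_input_shape=True additionally splits the rows by input shape, which works because we profiled with record_shapes=True.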
# Print the profiled results in a table
print(
    prof.key_averages(group_by_input_shape=True).table(
        sort_by="cuda_time_total", row_limit=5
    )
)
Upon execution, you will see profiled results similar to the following (the exact kernels and timings depend on your GPU):
----------------------  ---  ----------  ---------  -----------  ----------  -------------
Name                    ...  # of Calls  Self CUDA  Self CUDA %  CUDA total  CUDA time avg
----------------------  ---  ----------  ---------  -----------  ----------  -------------
ampere_sgemm_128x64_nn  ...          36   10.898ms       33.27%    10.898ms      302.728us
aten::linear            ...           1    0.000us        0.00%     6.866ms        6.866ms
aten::matmul            ...           1    0.000us        0.00%     6.866ms        6.866ms
aten::mm                ...           1    6.866ms       20.96%     6.866ms        6.866ms
ampere_sgemm_64x64_tn   ...           1    6.866ms       20.96%     6.866ms        6.866ms
----------------------  ---  ----------  ---------  -----------  ----------  -------------
Self CPU time total: 32.360ms
Self CUDA time total: 32.754ms
3.2 Interpreting this table
Let’s take a short detour to interpret this table properly. Each row in this table is either a CUDA kernel or a PyTorch op that was invoked during the GPT2 forward pass, and the rows are sorted in descending order of the total CUDA time each op takes (including any child ops).
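The self vs. total distinction is easiest to see on a call tree. In PyTorch’s profiler, a kernel’s GPU time counts toward the Self CUDA of the op that launched it and toward the CUDA total of every ancestor op. The toy model below (an illustration, not the profiler’s real data structures; Op is a hypothetical class) rebuilds the aten::linear → aten::matmul → aten::mm chain with the 6.866 ms kernel from the table, and reproduces why those rows share the same CUDA total while only aten::mm has non-zero Self CUDA.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Op:
    """Toy profiler node (hypothetical): an op, the kernels it launches directly, and its child ops."""
    name: str
    kernel_ms: List[float] = field(default_factory=list)
    children: List["Op"] = field(default_factory=list)

    def self_cuda_ms(self) -> float:
        # "Self CUDA": only kernels launched directly by this op.
        return sum(self.kernel_ms)

    def cuda_total_ms(self) -> float:
        # "CUDA total": own kernels plus everything launched by descendants.
        return self.self_cuda_ms() + sum(c.cuda_total_ms() for c in self.children)


# aten::mm launches the ampere_sgemm_64x64_tn kernel (6.866 ms in the table);
# aten::matmul and aten::linear are dispatch wrappers with no kernels of their own.
mm = Op("aten::mm", kernel_ms=[6.866])
matmul = Op("aten::matmul", children=[mm])
linear = Op("aten::linear", children=[matmul])

for op in (linear, matmul, mm):
    print(f"{op.name:14s} self={op.self_cuda_ms():.3f} ms  total={op.cuda_total_ms():.3f} ms")
```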
3.2.1 What the columns mean
The columns tell you, for each op:
- Name: Name of the kernel / op.
- # of Calls: Number of times this op was called.
- Self CUDA: Total time spent executing CUDA kernels launched directly by this op (not including any child ops it called).
  - For example: ampere_sgemm_128x64_nn is a matrix multiplication kernel, and it ran for 10.898 ms in total.
  - This is also why we see no Self CUDA time for the aten::linear and aten::matmul ops: both are purely high-level dispatch wrappers that don’t execute any CUDA kernels themselves. Instead, they immediately call into lower-level ops such as aten::mm.
- Self CUDA %: This op’s Self CUDA time as a percentage of the total GPU time (the “Self CUDA time total” at the bottom of the table), again not including any child ops.
- CUDA total: The sum of CUDA time in this op plus in any child ops it dispatched.
  - For example: for high-level ops like aten::linear, this includes the time spent in its underlying aten::mm / SGEMM calls.
- CUDA time avg: The average CUDA time per call of this op, i.e., its CUDA total divided by its # of Calls.
Note: There are other, CPU-related columns in this output table too, but they are omitted for brevity since our focus is on fixing GPU bottlenecks.
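The derived columns can be checked by hand. The short sketch below recomputes Self CUDA % (an op’s Self CUDA divided by the Self CUDA time total at the bottom of the table) and CUDA time avg (CUDA total divided by # of Calls) from the numbers in the profiler table; the helper names are only for illustration.

```python
# Numbers copied from the profiler table.
SELF_CUDA_TIME_TOTAL_MS = 32.754


def self_cuda_pct(self_cuda_ms: float, total_ms: float = SELF_CUDA_TIME_TOTAL_MS) -> float:
    # "Self CUDA %": this op's Self CUDA as a share of all GPU time in the trace.
    return 100.0 * self_cuda_ms / total_ms


def cuda_time_avg_us(cuda_total_ms: float, n_calls: int) -> float:
    # "CUDA time avg": CUDA total spread evenly over the op's calls (ms -> us).
    return 1000.0 * cuda_total_ms / n_calls


rows = {
    # name: (# of Calls, Self CUDA ms, CUDA total ms)
    "ampere_sgemm_128x64_nn": (36, 10.898, 10.898),
    "aten::mm": (1, 6.866, 6.866),
}

for name, (calls, self_ms, total_ms) in rows.items():
    print(
        f"{name:24s} Self CUDA % = {self_cuda_pct(self_ms):.2f}%  "
        f"CUDA time avg = {cuda_time_avg_us(total_ms, calls):.1f} us"
    )
```

This reproduces the 33.27% and 20.96% in the table. The average comes out to about 302.7 us rather than the printed 302.728us, most likely because the table shows rounded totals while the profiler averages the unrounded times.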
3.2.2 [Optional] High-level description of these operations
Let’s also take a moment to understand, at a higher level, these top 5 most CUDA-time-consuming ops:
- ampere_sgemm_128x64_nn: This is a low-level CUDA kernel for single-precision (FP32) matrix multiplication (SGEMM), optimized for the NVIDIA Ampere architecture (which is my GPU’s architecture).
  - Here, 128x64 refers to the tile size (128 rows x 64 columns) that each thread block works on, which is how the kernel splits the work for parallelism.
  - nn means that this kernel is used when neither matrix A nor matrix B is transposed (No transpose on matrix A, No transpose on matrix B).
- aten::linear: This is the PyTorch-level op for a linear layer, \(y = xW^T + b\).
  - It is used in the GPT2 blocks for the linear layers (attention projections, FFN, etc.).
  - As we covered earlier, this op doesn’t launch any CUDA kernels itself, but its child ops do.
- aten::matmul: Basically a “catch-all” PyTorch op (torch.matmul).
  - It handles the batching and broadcasting logic on the CPU side before calling aten::mm (for 2D matrices).
- aten::mm: This is the PyTorch 2D matrix-multiply operator (torch.mm), used when both inputs are exactly 2D.
  - This is the final dispatch point for the GEMM. In our case (on CUDA), it directly invokes one of the ampere_sgemm_* kernels, depending on the matrix shapes.
- ampere_sgemm_64x64_tn: Another low-level CUDA kernel for matrix multiplication, but this one uses a 64x64 tile.
  - The suffix tn means Transpose on matrix A and No transpose on matrix B.
  - This variant is chosen because it gives the best performance for those specific matrix shapes.
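These kernel names pack several facts into one string. The sketch below decodes them with a regular expression; the pattern is an assumption inferred from the two names in our table, not a documented cuBLAS naming scheme, and parse_gemm_kernel / num_tiles are hypothetical helpers. It also estimates how many output tiles a kernel must cover, treating the tile as rows x columns of the output and assuming one thread block per tile with no split-K, which is a simplification.

```python
import math
import re
from typing import NamedTuple

# Assumed pattern, inferred from the kernel names in our table
# (not an official cuBLAS naming specification).
KERNEL_RE = re.compile(
    r"^(?P<arch>[a-z]+)_(?P<prec>[sdh])gemm_(?P<rows>\d+)x(?P<cols>\d+)_(?P<op_a>[nt])(?P<op_b>[nt])$"
)
# BLAS-style precision prefixes: s = single, d = double, h = half.
PRECISION = {"s": "FP32", "d": "FP64", "h": "FP16"}


class GemmKernel(NamedTuple):
    arch: str
    precision: str
    tile_rows: int
    tile_cols: int
    transpose_a: bool
    transpose_b: bool


def parse_gemm_kernel(name: str) -> GemmKernel:
    """Hypothetical helper: split a GEMM kernel name into its parts."""
    match = KERNEL_RE.match(name)
    if match is None:
        raise ValueError(f"unrecognized kernel name: {name!r}")
    return GemmKernel(
        arch=match["arch"],
        precision=PRECISION[match["prec"]],
        tile_rows=int(match["rows"]),
        tile_cols=int(match["cols"]),
        transpose_a=match["op_a"] == "t",
        transpose_b=match["op_b"] == "t",
    )


def num_tiles(kernel: GemmKernel, out_rows: int, out_cols: int) -> int:
    """Hypothetical helper: output tiles needed, assuming one thread block per tile and no split-K."""
    return math.ceil(out_rows / kernel.tile_rows) * math.ceil(out_cols / kernel.tile_cols)


k = parse_gemm_kernel("ampere_sgemm_128x64_nn")
print(k)
# A 1000 x 2304 output (GPT-2's c_attn output shape for a 1000-token input):
# ceil(1000 / 128) * ceil(2304 / 64) = 8 * 36 = 288 tiles
print(num_tiles(k, 1000, 2304))
print(parse_gemm_kernel("ampere_sgemm_64x64_tn"))
```

The tile count is a rough proxy for how much parallel work a single GEMM exposes to the GPU; the real kernel selection and launch configuration are decided inside the library.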
Now that you understand how to profile your model and make sense of the results table, in the next two posts we will use this table to identify bottlenecks and then write custom CUDA kernels that fuse multiple ops to relieve them.