Fixing LLM Inference Bottlenecks with Custom CUDA Kernels (Simple Kernel)
Table of Contents (Blog Series)
- Part 1: Profiling
- Part 2: Simple Fused Kernel (this post)
- Part 3: Advanced Fused Kernel (Coming soon)
1. Introduction
This blog post is the second one in the “Fixing LLM Inference Bottlenecks with Custom CUDA Kernels” series and assumes you have read the first post, which is about profiling. I highly recommend reading it if you haven’t already, as you otherwise won’t have the background or context to fully understand this one.
In this blog post we will take a closer look at the profiled results table as well as the benchmarked time, identify a “simple” bottleneck, write a custom CUDA kernel plus a torch binding, fix the identified bottleneck, and verify the increase in inference speed.
When I say “simple” here, I am talking about a CUDA kernel for a rather insignificant op that may not give us a huge speed-up, but it will be easier to follow (I hope) and will still deliver a positive speed-up. Essentially, I want you (the reader) to leave this blog with net-positive knowledge of something that works meaningfully (however small that “meaning” may be).
2. Searching for a Simple Bottleneck
Continuing where we left off in the profiling blog, let’s re-run the profiling, and display more than 5 ops this time.
# Print the profiled results in a table
print(
prof.key_averages(group_by_input_shape=True).table(
sort_by="cuda_time_total", row_limit=20,
)
)
Upon running this code snippet, you will see the raw results as follows:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls Input Shapes
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
ampere_sgemm_128x64_nn 0.00% 0.000us 0.00% 0.000us 0.000us 17.517ms 46.50% 17.517ms 364.935us 48 []
aten::linear 0.00% 1.814us 0.32% 121.142us 121.142us 0.000us 0.00% 7.253ms 7.253ms 1 [[1, 1000, 768], [50257, 768], []]
aten::matmul 0.01% 5.315us 0.31% 115.422us 115.422us 0.000us 0.00% 7.253ms 7.253ms 1 [[1, 1000, 768], [768, 50257]]
aten::mm 0.07% 26.606us 0.28% 104.650us 104.650us 7.253ms 19.25% 7.253ms 7.253ms 1 [[1000, 768], [768, 50257]]
ampere_sgemm_128x64_tn 0.00% 0.000us 0.00% 0.000us 0.000us 7.253ms 19.25% 7.253ms 7.253ms 1 []
aten::scaled_dot_product_attention 0.16% 61.564us 0.86% 321.059us 26.755us 0.000us 0.00% 6.391ms 532.558us 12 [[1, 12, 1000, 64], [1, 12, 1000, 64], [1, 12, 1000, 64], [], [], [], [], []]
aten::_scaled_dot_product_efficient_attention 0.12% 46.681us 0.69% 259.495us 21.625us 0.000us 0.00% 6.391ms 532.558us 12 [[1, 12, 1000, 64], [1, 12, 1000, 64], [1, 12, 1000, 64], [], [], [], [], []]
aten::_efficient_attention_forward 0.21% 79.827us 0.45% 167.215us 13.935us 6.391ms 16.97% 6.391ms 532.558us 12 [[1, 1000, 12, 64], [1, 1000, 12, 64], [1, 1000, 12, 64], [], [], [], [], [], []
fmha_cutlassF_f32_aligned_64x64_rf_sm80(PyTorchMemEf... 0.00% 0.000us 0.00% 0.000us 0.000us 6.391ms 16.97% 6.391ms 532.558us 12 []
aten::addmm 0.32% 120.418us 0.59% 220.200us 18.350us 5.780ms 15.34% 5.780ms 481.667us 12 [[768], [1000, 3072], [3072, 768], [], []]
aten::addmm 0.28% 103.600us 0.38% 142.193us 11.849us 5.723ms 15.19% 5.723ms 476.915us 12 [[3072], [1000, 768], [768, 3072], [], []]
aten::addmm 0.34% 128.653us 0.55% 206.349us 17.196us 4.515ms 11.98% 4.515ms 376.219us 12 [[2304], [1000, 768], [768, 2304], [], []]
aten::mul 0.42% 157.371us 2.50% 935.477us 25.985us 1.774ms 4.71% 1.774ms 49.281us 36 [[1, 1000, 3072], []]
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.774ms 4.71% 1.774ms 49.281us 36 []
aten::addmm 0.36% 135.718us 0.55% 206.691us 17.224us 1.530ms 4.06% 1.530ms 127.475us 12 [[768], [1000, 768], [768, 768], [], []]
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.529ms 4.06% 1.529ms 41.323us 37 []
aten::add 0.18% 68.608us 0.24% 90.358us 7.530us 1.113ms 2.96% 1.113ms 92.768us 12 [[1, 1000, 3072], [1, 1000, 3072], []]
aten::mul 0.11% 40.528us 0.17% 63.433us 5.286us 1.098ms 2.92% 1.098ms 91.526us 12 [[1, 1000, 3072], [1, 1000, 3072]]
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 1.098ms 2.92% 1.098ms 91.526us 12 []
aten::pow 0.20% 73.946us 0.29% 106.669us 8.889us 634.662us 1.68% 634.662us 52.888us 12 [[1, 1000, 3072], []]
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Now there is a lot of signal (or noise, depending on whom you ask) in that table, but let’s identify some lone ops that stand out.
One potential lead is the point-wise ops, of which we see a lot here. More specifically these: aten::pow, aten::mul, aten::add (actually, there are two aten::mul ops in the table, more on that later).
But wait, what is a point-wise op again? Essentially, any op that applies to each of its input elements separately and independently is point-wise.
Let’s take the addition operation as an example. If you have two arrays A and B (both 1-dimensional, with the same length) and you add them, A + B = C, the addition takes each element in A and adds it to the corresponding element in B. This is a point-wise op because in order to produce C[i], the addition only needs to care about A[i] and B[i] and nothing else.
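As a tiny, concrete illustration in PyTorch (the tensors here are made up just for the example):
import torch

# Two 1-D tensors of the same length
A = torch.tensor([1.0, 2.0, 3.0])
B = torch.tensor([10.0, 20.0, 30.0])

# Point-wise addition: each C[i] depends only on A[i] and B[i]
C = A + B
print(C)  # tensor([11., 22., 33.])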
Why do we care about point-wise ops here? Because in the most vanilla implementation of models (including LLMs), they are the lowest-hanging fruit, ripe for fixing.
But before we jump to writing our kernel, where do these point-wise ops come from? By this I mean: which op do they map to in the 🤗 transformers implementation of GPT2? The answer is the GELU activation function. Not all of them, but most are from this activation function, which is used in the GPT2 model’s MLP (or feedforward network) layer.
How do I know? Let me explain, along with how much CUDA time is spent on them combined.
3. GELU activation bottleneck
Now, I assume you already know what this “GELU guy” is and where it comes from; if not, please read the original GPT paper. As a quick refresher of what GELU does, here’s the equation:
\[\text{GELU}(x) = 0.5x \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \left( x + 0.044715 x^3 \right) \right)\right)\]
Okay, when I said the ops are from the GELU activation function, I didn’t mean torch.nn.functional.gelu. The reason is that the 🤗 transformers implementation of GPT2 defines its own GELU, which is numerically very similar to torch’s implementation but nevertheless implemented separately.
Here’s the implementation from the transformers package for our reference (source):
import math

import torch
from torch import Tensor, nn

class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
Now, how do we map those point-wise ops to parts of this activation function? Let me explain in a tabular manner:
| Expression | Corresponding Op | Why? |
|---|---|---|
| torch.pow(input, 3.0) | aten::pow | Each element is raised to the power of 3 |
| 0.044715 * (…) | aten::mul | Multiplying a scalar by a tensor |
| input + (…) | aten::add | Point-wise addition of the original input with the ‘transformed’ input |
| math.sqrt(2/π) * (…) | aten::mul | Multiplying a scalar by a tensor |
| 1.0 + (…) | aten::add | Point-wise addition of a scalar with the tensor (after broadcasting the scalar) |
| input * (…) | aten::mul | Point-wise multiplication of the input with the ‘transformed’ input |
| 0.5 * (…) | aten::mul | Multiplying a scalar by a tensor |
I have written the table in an inside-out manner, so the most “inside” operation from the Python code is at the top of the table, and so on. Take some time to read each expression in the table, find where it is in the Python code above, and then read the corresponding CUDA op and the reason that specific op is triggered.
Now, if you read this table and mapped it to the Python code carefully, you must’ve noticed I left out one op: torch.tanh(...). It is also a point-wise op and it also incurs overhead, but since it doesn’t show up in the top-20 ops of the profiled table, I decided not to list it (for consistency). However, you would be right in thinking that it also counts here.
Crucial Learnings: Assuming you are running vanilla GPT2 from transformers without any optimizations or kernel fusions (i.e., without using things like torch.compile, Triton-Fusion or other packages), the PyTorch dispatcher will launch each of the aforementioned point-wise operations as a separate CUDA kernel. This means that every time the GELU activation function in GPT2 is calculated, PyTorch launches 8 separate CUDA kernels (the 7 from the table + aten::tanh).
You can actually see how many times each of those ops was called (in the # of Calls column of the profiled table) and verify the above conclusion yourself.
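If you want to see this in isolation, here is one way you could profile a single GELU call on a dummy tensor (this snippet is my own illustration, assuming NewGELUActivation is importable from transformers.activations; the shape matches the MLP hidden state from our profile):
import torch
from torch.profiler import profile, ProfilerActivity
from transformers.activations import NewGELUActivation

act = NewGELUActivation().cuda()
x = torch.randn(1, 1000, 3072, device="cuda")

# Profile one GELU call to see the separate point-wise kernels it launches
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    act(x)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))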
Let’s also quickly calculate how much CUDA time is spent by those operations (from the profiled table):
| GELU Op | # of launches (12 layers) | CUDA time (ms) |
|---|---|---|
| aten::pow | 12 | 0.635 |
| aten::mul (scalar × tensor -> 36 calls) | 36 | 1.774 |
| aten::add | 12 | 1.113 |
| aten::mul (point-wise op input * tanh_part) | 12 | 1.098 |
| Subtotal (GELU alone) | 72 | ≈ 4.62 |
Note: All of these kernel launches are heavily memory-bound and not compute-bound (which is what they should ideally aim to be).
So on each forward pass (inference mode), we spend 4.62 ms just performing this un-optimized GELU operation. This may seem like an insignificant number, but remember that we only ran a single 1000-token sample, and only once.
Assuming we deploy our GPT2 model for inference on an RTX 4070 GPU (which is where I ran these numbers) for a month (720 hours) and just run the same 1000-token input sequence again and again (far smaller than real-life inference workloads):
- We waste: 4.62 ms / 38 ms (total CUDA time) ≈ 0.122, or 12.2% of our total CUDA time (or total GPU time).
- Which is 0.122 * 720 hours ≈ 88, i.e. 88 hours of GPU compute per month. Let’s take a lower estimate of 85 hours.
- With an average rate of $0.15/hr for an RTX 4070, we waste 0.15 * 85 ≈ $12.75 per month.
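If you prefer to see that back-of-the-envelope math as code, here is the same arithmetic replayed in Python (the numbers are simply the ones measured above):
# Back-of-the-envelope cost of the unfused GELU, using the numbers from the profile
gelu_ms = 0.635 + 1.774 + 1.113 + 1.098      # CUDA time of the GELU point-wise ops (≈ 4.62 ms)
total_ms = 38.0                              # total CUDA time for one forward pass
wasted_fraction = gelu_ms / total_ms         # ≈ 0.122, i.e. ~12.2%

hours_per_month = 720
wasted_hours = wasted_fraction * hours_per_month   # ≈ 88 hours (take 85 as a lower estimate)
rate_per_hour = 0.15                               # $/hr for an RTX 4070
print(round(85 * rate_per_hour, 2))                # ≈ $12.75 wasted per month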
And this $12.75 is for one deployment on a rather conservative inference workload. Imagine wasting many times this amount on such a small operation!
4. Fixing the bottleneck
Now that we have identified the bottleneck and understood where it is coming from, let’s start fixing it. In total, we will be writing 3 files:
- A CUDA Kernel
- A C++ Binding (glue between our kernel and the Python code)
- Python code for patching and running this op in our GPT2 model for inference
4.1 Writing the CUDA Kernel (fused_gelu.cu)
I assume you know the basics of CUDA, at least how to write a hello-world CUDA kernel. If not, check out this blog by NVIDIA.
First, we start by including the necessary headers:
#include <cuda_runtime.h>
#include <math.h>
extern "C"{
// Rest of our upcoming code will go here
}
We include math.h in order to use math functions like tanhf, and including cuda_runtime.h gives us access to the CUDA runtime API.
What does this extern "C" thing do?
It tells the CUDA compiler not to change the function names so that our C++ binding (wrapper) can find them later. It’s basically asking the compiler to preserve C naming rules. This is crucial: without it, the compiler would mangle the name into a C++-style name and our binding wouldn’t be able to find it.
Now, let’s define our custom GELU activation function that will be called in the kernel:
__device__ __forceinline__
float gelu(float x){
const float two_by_pi = 0.7978845608028654f; // literally: √(2/π)
return 0.5f * x * (1.0f + tanhf(two_by_pi * (x + 0.044715f * x * x * x)));
}
A few things to note here:
- two_by_pi is √(2/π), which we have hardcoded to avoid repeatedly computing the constant.
- We use tanhf to compute the hyperbolic tangent of a single float (each thread calls it on its own element).
- __device__ marks a function that executes on the GPU (but is not a kernel) and returns something (remember, a kernel doesn’t return anything).
- __forceinline__ tells the compiler to forcibly “inline” this function. Simply put: the function’s body is copied into the calling kernel during compilation to remove the function-call overhead (since our function is rather simple).
The rest of this GELU implementation is literally the transformers PyTorch GELU implementation from Section 3, just transliterated into CUDA C with best practices.
Moving on, we will now write the actual kernel that reads one input element from memory, performs GELU on it (remember, inlined) and writes the result back to memory.
__global__
void fused_gelu_kernel(const float *__restrict__ input, float *__restrict__ output, int total){
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < total){
float val = input[idx]; // access the input element from memory only once
output[idx] = gelu(val); // apply gelu and then write this back
}
}
The actual kernel is very straightforward: we compute the index this specific thread is responsible for, check that it is in bounds by comparing it with the total number of elements, read that element, apply the GELU function on it, and write it back into the output array at the same index we read it from.
Quick refresher on CUDA (if you are confused): each thread in the blocks assigned to this computation executes one instance of this kernel in parallel on the GPU. This is the real power of GPUs: we can process all elements of the input array in parallel at the same time, because the operation is point-wise and applying GELU to one element of the input array is independent of all the other elements.
To visualize this kernel in your mind: imagine that our input array has 512 elements; there will then be 512 individual threads all running in parallel, each computing one element of the array and writing it back to memory. No two threads will compute the same element, and none of them will compute something from garbage memory either (that is what our if (idx < total){...} check ensures).
These threads are grouped into blocks, which are arranged along the x, y and z axes of the grid. To get the index a thread is responsible for, we combine the thread’s position within its block (threadIdx.x), the block’s position within the grid (blockIdx.x) and the number of threads per block (blockDim.x). For example, thread 5 of block 3 with 256 threads per block gets idx = 5 + 3 * 256 = 773. Using these three values, every thread computes a unique idx within the grid.
Some points to note from the kernel:
- __restrict__ is present on both the input and output pointers because we know for a fact that they will never point to the same memory location (the input array is not the output array), so __restrict__ tells the compiler it can optimize with the guarantee that these two pointers never alias.
- total is the number of elements in the input array.
But wait: isn’t the input array the hidden state of the model, which is 3-dimensional? How are we treating it as a 1D array?
Because in memory it is stored as one contiguous array of elements, and since GELU is a point-wise op we don’t have to worry about anything more than the element at hand. That’s why this is a “simple” kernel ;)
This isn’t always the case, by the way; in most cases one has to do indexing parkour inside the kernel to get the element(s) required from memory when writing custom kernels. We will cover this too, in future blogs.
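A quick way to convince yourself of this flat-memory view from the PyTorch side (the shape below is just the MLP hidden state from our profile):
import torch

# Hidden state shaped like GPT2's MLP activation input: (batch, seq_len, 4 * d_model)
hidden = torch.randn(1, 1000, 3072, device="cuda")

# A contiguous tensor is one flat buffer in memory; numel() is exactly the
# `total` our kernel receives, no matter how many dimensions the tensor has.
assert hidden.is_contiguous()
print(hidden.numel())   # 3072000
flat = hidden.view(-1)  # same memory, just viewed as 1-D
print(flat.shape)       # torch.Size([3072000])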
Now let’s write a function that will launch this kernel with the required number of threads and blocks:
void launch_fused_gelu(const float *input, float *output, int total, cudaStream_t stream){
const int threads = 256; // total threads in a block
const int blocks = (total + threads - 1) / threads; // total blocks in the grid
fused_gelu_kernel<<< blocks, threads, 0, stream >>>(input, output, total); // launch the kernel
}
Some points:
- We compute the number of blocks required to cover the total number of elements (and then some) with a ceiling division, so we don’t have to hardcode a big number (see the quick check after this list).
- The thread count is per block, meaning we launch enough 256-thread blocks to cover all the elements we are going to GELU.
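For the 1 x 1000 x 3072 hidden state from our profile, the launch configuration works out as follows (just the launcher’s ceiling division replayed in Python):
total = 1 * 1000 * 3072                      # number of elements to GELU
threads = 256                                # threads per block
blocks = (total + threads - 1) // threads    # ceiling division
print(blocks)                                # 12000 blocks -> 12000 * 256 = 3072000 threads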
Now, let’s write the binding that will glue everything together.
4.2 Writing the C++ Binding (fused_gelu_binding.cpp)
The C++ binding will serve as the glue between our kernel and the Python code where we run the inference on our LLM.
In the binding, we will do the following:
- Sanity checks on the input tensor
- Allocating the output tensor
- Running the kernel, which populates the output tensor
- Registering the fused GELU op as a PyTorch op (multiple steps here)
Let’s start:
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
extern "C" void launch_fused_gelu(const float*, float*, int, cudaStream_t);
Here we:
- Include the torch extension header in C++ for creating the output tensor and performing CUDA sanity checks.
- Include the CUDA context header for cudaStream_t, which lets us execute this kernel on the same CUDA stream as the rest of the PyTorch operations during our LLM inference.
- Forward-declare the kernel launcher inside extern "C" ..., which is needed because the CUDA object file will be linked against this binding later.
Now we will make a launcher function in C++ that performs sanity checks, declares the output tensor and launches the CUDA kernel we wrote before:
torch::Tensor fused_gelu_launcher(torch::Tensor input){
TORCH_CHECK(input.is_cuda(), "Input tensor must be on the GPU");
TORCH_CHECK(input.dtype() == torch::kFloat32, "Use fp32");
auto output = torch::empty_like(input);
launch_fused_gelu(
input.data_ptr<float>(),
output.data_ptr<float>(),
input.numel(),
c10::cuda::getCurrentCUDAStream()
);
return output;
}
Above, we:
- Set the return type of the launcher to a PyTorch tensor (torch::Tensor), since it returns the output tensor.
- Check that the input tensor is on the GPU and has float32 dtype.
- Initialize an empty output tensor with the same shape as input.
- Run the CUDA kernel (via the extern declaration above), passing the pointers to input and output, the number of elements in input (the total in our CUDA kernel implementation) and the current CUDA stream, and return the populated output tensor.
This C++ function is purely a launcher that does sanity checks and wraps around the actual kernel.
Finally, we register our op, first as an op schema in the new namespace we create for our op:
TORCH_LIBRARY(fg, m){
m.def("fused_gelu(Tensor input) -> Tensor");
}
This is so that we could call our op from PyTorch as: torch.ops.fg.fused_gelu(x). fg is the namespace name and fused_gelu is the function name.
Then we register our C++ launcher function as the CUDA implementation of this schema:
TORCH_LIBRARY_IMPL(fg, CUDA, m){
m.impl("fused_gelu", TORCH_FN(fused_gelu_launcher));
}
In the above two code-blocks, we first defined the schema and then added its implementation.
Finally, we just need to generate an empty Python binding so torch.utils.cpp_extension can import it:
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// (no additional Python bindings needed
// as this just generates PyInit_<name>)
}
And we are done in terms of binding. Now all we need to do is load this custom op we just defined in our Python file, patch this fused gelu op in our LLM implementation and run inference!
4.3 Patching the Custom Op + Inference (fused_gelu_gpt2_inference.py)
We are almost at the home-stretch now! Let’s start by importing necessary packages:
import time
import types
import torch
import numpy as np
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.cpp_extension import load
And load our fused op and binding in PyTorch:
fused = load(
name="fused_gelu",
sources=["csrc/fused_gelu.cu", "csrc/bindings/fused_gelu_binding.cpp"],
extra_cuda_cflags=["--expt-relaxed-constexpr", "--use_fast_math"],
extra_cflags=["-O2"],
verbose=False,
)
Understanding this snippet:
- The load(...) function does two things here:
  - It tells PyTorch which translation units to build.
  - It tells the C++ compiler (for example: g++) and the CUDA compiler (nvcc) which corresponding flags to use.
- Why do we pass both the binding and the CUDA file? Because the function needs to see both files so it can:
  - Compile the .cu file with nvcc,
  - Compile the .cpp file with the host compiler (e.g. g++),
  - Link the two object files (plus the CUDA runtime libraries) into a single shared object that Python can import.
- What do those flags mean?
  - --expt-relaxed-constexpr: future-proofs the build by letting device code call constexpr host functions, which is handy if we later pull in header-only utilities.
  - --use_fast_math: tells nvcc to replace IEEE-compliant math functions (tanhf, powf, sinf, etc.) with faster, less accurate approximations (hardware intrinsics such as __powf and __sinf). Intrinsics run 2-3x faster than the stock math functions but allow larger relative errors (we will come back to this later; you can also see the effect in the sanity check after this list).
  - -O2: this C++ flag (the two above were CUDA flags) tells the host compiler to perform standard host-side optimizations (-O3 is a more aggressive version of this flag).
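Before wiring the op into the model, here is a quick sanity check you could run at this point (it is not part of the original scripts): it compares our op against PyTorch’s tanh-approximate GELU, which implements the same formula.
import torch

x = torch.randn(1, 1000, 3072, device="cuda", dtype=torch.float32)

ref = torch.nn.functional.gelu(x, approximate="tanh")  # same tanh formulation
out = torch.ops.fg.fused_gelu(x)                       # our freshly loaded op

# Expect tiny differences (because of --use_fast_math), not bit-exact equality
print((out - ref).abs().max().item())
print(torch.allclose(out, ref, atol=1e-5, rtol=1e-5))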
Let’s now initialize the GPT2 model and tokenizer, and reuse our dummy input text from the last blog for inference:
model_name = "gpt2"
device = "cuda"
text = "A quick brown fox jumped upon a lazy dog." * 100
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(text, return_tensors="pt").to(device)
Finally we come to the interesting part: patching our custom op into the model! First we define a custom torch function for our op:
class FusedGELU(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# As said in last section, we can now use our custom op
# as if it was from pytorch itself!
return torch.ops.fg.fused_gelu(x)
@staticmethod
def backward(ctx, grad_out):
raise RuntimeError("Backward fn not implemented")
We use torch.autograd.Function here because it’s the recommended way in PyTorch to wrap a C++/CUDA op into the autograd graph. We only define the forward function, where we just call our fused op, since our focus here is LLM inference. As an exercise, you can try to first derive the backward function (by hand, if you are into that sort of thing) and then implement it as a custom kernel.
Let’s define a function that performs the “forward” pass for the MLP module. The idea is to replace the feedforward network (MLP) layer of a GPT2 block with our custom forward function so we can apply our fused kernel there, since it isn’t possible to intercept just the activation part of the original Huggingface transformers implementation.
def fused_mlp_forward(self, hidden_states):
x = torch.nn.functional.linear(
input=hidden_states, weight=self.c_fc.weight.t(), bias=self.c_fc.bias
)
x = FusedGELU.apply(x) # apply our custom GELU op on the hidden_states (x)
return self.c_proj(x)
This of course will involve us adding the linear layer ourselves before applying the GELU.
Note: You may be wondering why we are transposing the weights in the linear layer computation (i.e. this part: ...weight=self.c_fc.weight.t()...), since the linear layer already does the transpose: \(y = xW^T+b\).
Well, you are right in thinking this, but the reason we transpose the weights is that the Huggingface transformers GPT2 implementation doesn’t actually use a linear layer for the MLP. It uses Conv1D (source) instead, and Conv1D stores its weights in a different layout (in-features x out-features, whereas a linear layer expects out-features x in-features), which would otherwise be incompatible with the matmul inside linear. So we transpose the weights to make them compatible.
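If you want to see this layout difference for yourself, here is a quick shape check (assuming the gpt2 weights are already downloaded):
import torch
from transformers import AutoModelForCausalLM

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
c_fc = gpt2.transformer.h[0].mlp.c_fc    # the Conv1D layer in the first block

# Conv1D stores its weight as (in_features, out_features)...
print(c_fc.weight.shape)                 # torch.Size([768, 3072])

# ...while F.linear expects (out_features, in_features), hence the .t()
x = torch.randn(1, 1000, 768)
y = torch.nn.functional.linear(x, c_fc.weight.t(), c_fc.bias)
print(y.shape)                           # torch.Size([1, 1000, 3072])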
Finally, we monkey patch our custom MLP forward function into the model:
def patch(module):
if isinstance(module, transformers.models.gpt2.modeling_gpt2.GPT2MLP):
module.forward = types.MethodType(fused_mlp_forward, module)
# Visits every sub-module in the model and patches the GPT2MLP ones
model.apply(patch)
Here, the monkey patch only happens when the module in question is the GPT2MLP module (which is the MLP module).
And this is it! We created a custom CUDA kernel and its binding, and patched it into the model successfully! Now let’s test whether the outputs of this model match the baseline vanilla GPT2 model during inference (to verify we didn’t accidentally contaminate the forward pass):
baseline_model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()
fused_output = model.generate(**inputs)
baseline_output = baseline_model.generate(**inputs)
print(fused_output.squeeze().tolist() == baseline_output.squeeze().tolist())
The output comes out to be True.
Now, of course, this is not the best way to check for mismatched forward passes, and I encourage you to compare the final logits of both outputs yourself. Here’s a hint on what you will see: a very small percentage of logits mismatch (<5%). Why? Remember the --use_fast_math flag we used when loading the op in PyTorch? As I explained before, it lets us use intrinsic functions for faster calculation at the cost of some accuracy. This is the most probable cause of our slight logit mismatch.
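Here is a sketch of that stricter check (not in the original script), comparing the full logits of the fused and baseline models we already have in scope:
import torch

with torch.no_grad():
    fused_logits = model(**inputs).logits
    baseline_logits = baseline_model(**inputs).logits

# Maximum absolute difference and the fraction of logits outside a tight tolerance
diff = (fused_logits - baseline_logits).abs()
print("max abs diff:", diff.max().item())

mismatch = ~torch.isclose(fused_logits, baseline_logits, atol=1e-5, rtol=1e-5)
print("mismatching logits: {:.2f}%".format(mismatch.float().mean().item() * 100))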
However, the mismatch is so minute that it doesn’t degrade model quality in any meaningful way (for now). That being said, the accuracy-vs-speed trade-off is one you have to make depending on your use-case and your appetite for taking accuracy hits in order to increase speed.
5. Speed-up and Conclusion
Now that we have a supposedly faster model thanks to our fused implementation, let’s do the profiling once again to see how much time our fused op takes and verify that it is faster:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls Input Shapes
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
ampere_sgemm_128x64_nn 0.00% 0.000us 0.00% 0.000us 0.000us 17.492ms 52.08% 17.492ms 364.414us 48 []
aten::linear 0.00% 1.584us 0.34% 115.798us 115.798us 0.000us 0.00% 7.283ms 7.283ms 1 [[1, 1000, 768], [50257, 768], []]
aten::matmul 0.01% 4.830us 0.33% 112.214us 112.214us 0.000us 0.00% 7.283ms 7.283ms 1 [[1, 1000, 768], [768, 50257]]
aten::mm 0.06% 19.450us 0.31% 105.226us 105.226us 7.283ms 21.68% 7.283ms 7.283ms 1 [[1000, 768], [768, 50257]]
ampere_sgemm_128x64_tn 0.00% 0.000us 0.00% 0.000us 0.000us 7.283ms 21.68% 7.283ms 7.283ms 1 []
aten::scaled_dot_product_attention 0.57% 190.905us 1.28% 430.773us 35.898us 0.000us 0.00% 6.336ms 528.034us 12 [[1, 12, 1000, 64], [1, 12, 1000, 64], [1, 12, 1000, 64], [], [], [], [], []]
aten::_scaled_dot_product_efficient_attention 0.17% 58.427us 0.71% 239.868us 19.989us 0.000us 0.00% 6.336ms 528.034us 12 [[1, 12, 1000, 64], [1, 12, 1000, 64], [1, 12, 1000, 64], [], [], [], [], []]
aten::_efficient_attention_forward 0.17% 57.751us 0.41% 139.729us 11.644us 6.336ms 18.87% 6.336ms 528.034us 12 [[1, 1000, 12, 64], [1, 1000, 12, 64], [1, 1000, 12, 64], [], [], [], [], [], []
fmha_cutlassF_f32_aligned_64x64_rf_sm80(PyTorchMemEf... 0.00% 0.000us 0.00% 0.000us 0.000us 6.336ms 18.87% 6.336ms 528.034us 12 []
aten::addmm 0.23% 76.907us 0.44% 146.848us 12.237us 5.777ms 17.20% 5.777ms 481.413us 12 [[768], [1000, 3072], [3072, 768], [], []]
aten::linear 0.08% 26.127us 0.52% 175.984us 14.665us 0.000us 0.00% 5.716ms 476.351us 12 [[1, 1000, 768], [3072, 768], [3072]]
aten::addmm 0.24% 80.694us 0.35% 118.809us 9.901us 5.716ms 17.02% 5.716ms 476.351us 12 [[3072], [1000, 768], [768, 3072], [], []]
aten::addmm 0.35% 117.105us 0.55% 185.574us 15.465us 4.511ms 13.43% 4.511ms 375.885us 12 [[2304], [1000, 768], [768, 2304], [], []]
aten::addmm 0.29% 97.916us 0.57% 192.381us 16.032us 1.518ms 4.52% 1.518ms 126.534us 12 [[768], [1000, 768], [768, 768], [], []]
FusedGELU 0.34% 116.281us 0.59% 199.785us 16.649us 0.000us 0.00% 1.043ms 86.908us 12 [[1, 1000, 3072]]
fg::fused_gelu 0.06% 21.698us 0.25% 83.504us 6.959us 1.043ms 3.11% 1.043ms 86.908us 12 [[1, 1000, 3072]]
fused_gelu_kernel 0.00% 0.000us 0.00% 0.000us 0.000us 1.043ms 3.11% 1.043ms 86.908us 12 []
aten::contiguous 0.05% 17.121us 0.88% 295.589us 8.211us 0.000us 0.00% 542.139us 15.059us 36 [[1, 12, 1000, 64], []]
aten::clone 0.12% 41.356us 0.83% 278.468us 7.735us 0.000us 0.00% 542.139us 15.059us 36 [[1, 12, 1000, 64], []]
aten::copy_ 0.25% 84.645us 0.46% 156.669us 4.352us 542.139us 1.61% 542.139us 15.059us 36 [[1, 12, 1000, 64], [1, 12, 1000, 64], []]
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
As you can see, all those aten:: point-wise ops are gone and are replaced by our custom fg::fused_gelu op!
But wait Tanay, why do I see three of our op’s entries in the profiled results?
That’s because, in the PyTorch profiler output, every “layer” of the call-stack that touches our custom op is recorded separately, so we see three distinct rows even though they all belong to one logical GELU step. Here’s a rough flow to better understand this:
FusedGELU ← Python wrapper (parent node)
└─ fg::fused_gelu ← C++ op (child)
└─ fused_gelu_kernel ← CUDA kernel (grand-child)
Speed-up: Our fg::fused_gelu kernel now takes 1.043 ms, down from the 4.62 ms the unoptimized GELU ops took. That is a 4.62 - 1.04 = 3.58 ms improvement per forward pass and 72 - 12 = 60 fewer CUDA kernel launches.
In terms of overall inference speed, we went from 38 ms down to 34.4 ms, which is approximately a 9-10% speed boost (accounting for the variance between different runs). That is like saving ~9% of GPU hours with just 2 hours of work (+ debugging). Pretty good for a “simple” kernel, but we can always do better!
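If you want to reproduce the end-to-end numbers yourself, here is a minimal timing sketch using CUDA events (my own illustrative harness; the exact benchmarking setup from Part 1 may differ):
import torch

def time_forward_ms(m, inputs, iters=20):
    # Warm up so lazy initialization and caching don't skew the measurement
    with torch.no_grad():
        for _ in range(3):
            m(**inputs)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    with torch.no_grad():
        for _ in range(iters):
            m(**inputs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # average ms per forward pass

print("fused   :", time_forward_ms(model, inputs))
print("baseline:", time_forward_ms(baseline_model, inputs))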
You can find all three files we coded in this blog here.
In the next blog, we will implement a more “advanced” kernel to understand these concepts better! If you have comments or suggestions about this blog, or you want to get in touch, you can email me at heyytanay@gmail.com or reach me on LinkedIn or Twitter.