Flash Attention in ~100 lines of CUDA

This post presents a minimal re-implementation of Flash Attention in CUDA and PyTorch. It is intended as a simpler alternative to the official implementation, which can be intimidating for CUDA beginners. The code is concise: the forward pass takes roughly 100 lines in flash.cu. The included benchmark, run on a GPU, shows a significant speed-up over manual attention.

There are some caveats. There is no backward pass. Matrix multiplications are slow because of a thread-per-row simplification, in which each thread handles one full row. Q, K, and V are stored in float32 rather than the float16 used by the original. The block size is fixed at 32. Planned improvements include adding a backward pass, speeding up the matrix multiplications, and setting the block size dynamically.
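To make the core idea concrete, here is a minimal pure-Python sketch of the tiled forward pass with an online softmax that Flash Attention relies on. It is not the repository's flash.cu kernel. It runs on the CPU, processes one query row at a time (loosely mirroring the thread-per-row approach), and defers the division by the running sum `l` to the end, which may differ from how the kernel updates its output. The function names `naive_attention` and `flash_attention_forward` are hypothetical, and the default `block_size=32` only echoes the fixed block size mentioned above.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference: softmax(Q K^T / sqrt(d)) V, materializing every score."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(len(V))) / total
                    for c in range(len(V[0]))])
    return out


def flash_attention_forward(Q, K, V, block_size=32):
    """Tiled forward pass with an online softmax (hypothetical sketch).

    Each query row keeps a running max m, a running normalizer l and an
    unnormalized output accumulator. K and V are visited one block at a
    time, so the full N x N score matrix is never stored.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:  # one query row at a time, like one thread per row
        m = -math.inf
        l = 0.0
        acc = [0.0] * len(V[0])
        for start in range(0, len(K), block_size):
            kb = K[start:start + block_size]
            vb = V[start:start + block_size]
            s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in kb]
            m_new = max(m, max(s))
            # Rescale the old statistics to the new max (exp(-inf) == 0.0).
            alpha = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(p[j] * vb[j][c] for j in range(len(vb)))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once at the end
    return out


if __name__ == "__main__":
    random.seed(0)
    N, d = 70, 8  # 70 is not a multiple of 32, so the last block is partial
    Q, K, V = ([[random.uniform(-1, 1) for _ in range(d)] for _ in range(N)]
               for _ in range(3))
    ref = naive_attention(Q, K, V)
    got = flash_attention_forward(Q, K, V)
    err = max(abs(a - b) for ra, rb in zip(ref, got) for a, b in zip(ra, rb))
    print("max abs difference:", err)
```

Because each row only keeps `m`, `l`, and one accumulator, the working memory for scores grows with the block size rather than with the sequence length. That property is what lets a GPU kernel keep each tile in fast on-chip memory.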

https://github.com/tspeterkim/flash-attention-minimal
