-
FlashAttention applies a tiling scheme (similar to tiled matmuls) to attention in transformer models
-
instead of doing
- matmul, mask, softmax, dropout, matmul as separate kernels
- do all of it in a single fused kernel
-
we never have enough GPU memory for the intermediates, so we need tiling
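A minimal PyTorch sketch of the contrast (shapes, dtype, and the causal mask are illustrative assumptions, not from the talk): the unfused path launches one kernel per step and materializes the N x N matrices, while F.scaled_dot_product_attention runs the whole computation as one fused kernel when a FlashAttention-style backend is available.

```python
import math
import torch
import torch.nn.functional as F

B, H, N, d = 2, 8, 1024, 64                       # batch, heads, seq len, head dim (illustrative)
q = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)

# Unfused: every step is its own kernel, and the N x N matrices are materialized.
scores = q @ k.transpose(-2, -1) / math.sqrt(d)   # matmul
mask = torch.triu(torch.ones(N, N, device="cuda", dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))  # mask
probs = torch.softmax(scores, dim=-1)             # softmax
probs = F.dropout(probs, p=0.1, training=True)    # dropout
out_unfused = probs @ v                           # matmul

# Fused: one kernel, no materialized N x N attention matrix
# (dispatches to a FlashAttention-style kernel when available).
out_fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
```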
-
7 years ago you could still saturate the GPU by launching these as multiple separate kernels
-
Attention as classification
- d-dimensional input activations q
- c × d final-layer weight matrix K
- logits = q K^T
- class probabilities p = softmax(logits)
- rearrange this view:
- d-dimensional input activations q
- d-dimensional class embeddings K_i (the rows of K)
- scalar products give the logits l_i = q · K_i
- class probabilities from a (scaled) softmax
- attention becomes a classification problem: which row of V should we pick (see the sketch below)
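A tiny PyTorch sketch of that view (sizes are illustrative): the keys play the role of class embeddings, the softmax gives class probabilities, and the output is a probability-weighted mix of the rows of V.

```python
import torch

d, c = 64, 10                      # head dim and number of "classes" (keys); illustrative sizes
q = torch.randn(d)                 # one query = one "input activation"
K = torch.randn(c, d)              # keys = the "final layer" / class embeddings
V = torch.randn(c, d)              # values = what we return for each class

logits = q @ K.T / d**0.5          # scalar products q . K_i, scaled
p = torch.softmax(logits, dim=-1)  # "class probabilities" over the c keys
out = p @ V                        # soft selection: a p-weighted mix of the rows of V
```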
-
multi-head attention
- multiple attentions of this type (smaller d) simultaneously
- heads are independent -- fully parallelizable
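A short sketch of the head split (sizes illustrative, using PyTorch's fused attention call for the per-head attention): the heads become a batch dimension, so they all run independently and in parallel.

```python
import torch
import torch.nn.functional as F

B, N, d_model, n_heads = 2, 128, 512, 8
d_head = d_model // n_heads        # each head works in a smaller dimension

x = torch.randn(B, N, 3 * d_model) # pretend this is the output of a fused QKV projection
q, k, v = x.chunk(3, dim=-1)

def split_heads(t):
    # (B, N, d_model) -> (B, n_heads, N, d_head): heads become a batch dimension,
    # so all heads run as one batched (fully parallel) attention call
    return t.view(B, N, n_heads, d_head).transpose(1, 2)

out = F.scaled_dot_product_attention(split_heads(q), split_heads(k), split_heads(v))
out = out.transpose(1, 2).reshape(B, N, d_model)   # merge heads back
```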
-
attention on the gpu
- distribute the heads across thread blocks to use all the SMs
- one head per block on the GPU -- within a head there is too much interdependence to split it across blocks
-
assumptions: Q, K, V have the same shape
- seq length N
- head dimension of d
- P = softmax(sQK^T, dim=-1), O = PV, scaling factor s = 1/sqrt(d)
- try to avoid materializing P or other intermediate matrices (the naive reference below materializes the full P)
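A naive reference for these formulas (a sketch, not the kernel): it materializes the full N x N matrix P, which is exactly what the fused implementation avoids.

```python
import torch

def naive_attention(Q, K, V):
    # Q, K, V: (N, d). The N x N matrix P is fully materialized here,
    # which is exactly what the tiled/fused kernel avoids.
    N, d = Q.shape
    s = 1.0 / d**0.5
    P = torch.softmax(s * (Q @ K.T), dim=-1)   # (N, N) intermediate
    return P @ V                               # O = P V, shape (N, d)
```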
-
the largest constraint is the dimension of each head, which drives register usage
- the demand for shared memory is so high that some data gets moved into registers instead
-
tiling strategy
- the contraction dimension of the matmul and the softmax dimension coincide
- the softmax needs all of its inputs (the whole row) before it can normalize -- see the tiled sketch below
- relies on huge tiles
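A plain PyTorch sketch of that constraint on one row of logits (tile size is illustrative): because the normalization needs the whole row, a straightforward tiled softmax has to sweep the tiles several times before it can emit any output.

```python
import torch

def tiled_softmax_multipass(logits, tile=128):
    # logits: (N,). Softmax normalizes over the whole row, so a tiled version
    # cannot emit any output until it has seen every tile: one sweep for the
    # global max, one for the denominator, one to normalize.
    N = logits.numel()
    m = max(logits[i:i + tile].max() for i in range(0, N, tile))                   # sweep 1: global max
    s = sum(torch.exp(logits[i:i + tile] - m).sum() for i in range(0, N, tile))    # sweep 2: denominator
    out = torch.empty_like(logits)
    for i in range(0, N, tile):                                                    # sweep 3: normalize
        out[i:i + tile] = torch.exp(logits[i:i + tile] - m) / s
    return out
```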
-
FA2 rearranged the tile order compared to FA1, to avoid writing intermediates back to DRAM
-
stabilized softmax
- softmax: p_i = exp(l_i) / sum_j exp(l_j); always between 0 and 1
- both numerator and denominator can get very large; not numerically stable
- stabilized softmax: p_i = exp(l_i - m) / sum_j exp(l_j - m)
  expand the fraction by exp(-m), setting m = max_i(l_i)
- this makes all the exponents at most 0, so the terms stay small
- online stabilized softmax: whenever a new maximum m appears, rescale what has been accumulated so far by exp(m_old - m_new) (sketch below)
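A minimal sketch of the online stabilized softmax fused with the PV accumulation, for a single query with K and V processed in tiles (plain PyTorch, illustrative tile size); it shows the rescaling-on-new-max idea, not the actual kernel.

```python
import torch

def online_softmax_attention(q, K, V, tile=128):
    # One query row, K/V tiled along the sequence dimension. Running max m,
    # running denominator s, and running output o are rescaled by exp(m - m_new)
    # whenever a tile brings a new maximum -- the online stabilized softmax.
    d = q.shape[-1]
    m = torch.tensor(float("-inf"))
    s = torch.tensor(0.0)
    o = torch.zeros(d)
    for start in range(0, K.shape[0], tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]
        logits = (q @ Kt.T) / d**0.5            # partial row of scaled logits
        m_new = torch.maximum(m, logits.max())
        alpha = torch.exp(m - m_new)            # rescale everything accumulated so far
        p = torch.exp(logits - m_new)           # unnormalized probabilities for this tile
        s = s * alpha + p.sum()
        o = o * alpha + p @ Vt
        m = m_new
    return o / s                                # normalize once at the end

q, K, V = torch.randn(64), torch.randn(1024, 64), torch.randn(1024, 64)
out = online_softmax_attention(q, K, V)
# matches torch.softmax(q @ K.T / 64**0.5, dim=-1) @ V up to rounding
```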
-
this is a very numerically sensitive computation
- even the rounding differences between a fused multiply-add and an unfused multiply-then-add are enough to change the result
-
masks cause a non-rectangular block layout
-
FA2 uses CUTLASS to get onto the tensor cores
- matrix-multiplication primitives
- very large C++ files; compiling takes around 40 GB of memory
- tile size options -- 64 or 128
-
nice to see someone do this live
- implemented it in PyTorch
-
Numba kernels can be called on PyTorch tensors, via the CUDA array interface (__cuda_array_interface__)
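A minimal example (the kernel itself is an illustrative toy): a Numba CUDA kernel launched directly on a PyTorch CUDA tensor, which works because the tensor exposes __cuda_array_interface__.

```python
import torch
from numba import cuda

@cuda.jit
def scale_kernel(x, factor):
    i = cuda.grid(1)
    if i < x.shape[0]:
        x[i] *= factor

x = torch.ones(1 << 20, device="cuda")
threads = 256
blocks = (x.shape[0] + threads - 1) // threads
# PyTorch CUDA tensors expose __cuda_array_interface__, so Numba can take them directly
scale_kernel[blocks, threads](x, 2.0)
```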
-
sm has 16k single precision registers; 200 registers per thread
-
48 KB of shared memory
-
showing a native CUDA compile
- cuda-python, NVIDIA's Python bindings
- provides bindings for NVRTC
- compile CUDA code into a kernel inline, at runtime
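A hedged sketch with cuda-python's NVRTC bindings (the kernel source and the architecture flag are illustrative assumptions): compile a CUDA C string to PTX at runtime; loading and launching the PTX through the driver API is left out.

```python
from cuda import nvrtc

source = b"""
extern "C" __global__ void scale(float *x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
}
"""

# Compile the CUDA C source to PTX at runtime (no separate nvcc invocation)
err, prog = nvrtc.nvrtcCreateProgram(source, b"scale.cu", 0, [], [])
opts = [b"--gpu-architecture=compute_80"]          # illustrative target architecture
err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)
err, ptx_size = nvrtc.nvrtcGetPTXSize(prog)
ptx = b" " * ptx_size
err, = nvrtc.nvrtcGetPTX(prog, ptx)
# the PTX can then be loaded and launched through the driver API (cuModuleLoadData etc.)
```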
-
register spills can be checked in Godbolt (Compiler Explorer)
-
Thunder: a source-to-source PyTorch compiler (thunder.jit)
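Roughly how the entry point is used (the model and input here are placeholders, not from the talk):

```python
import torch
import thunder

model = torch.nn.Linear(512, 512)
jitted = thunder.jit(model)        # compiled, traceable version of the module
out = jitted(torch.randn(8, 512))
```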