- applies a tiling scheme (similar to matmuls) in transformer models
- instead of doing matmul, mask, softmax, dropout, matmul as a sequence of separate kernels, it is a single fused kernel
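For contrast, a minimal PyTorch sketch of that unfused sequence (function name, mask convention, and dropout rate here are illustrative, not the library's code):

```python
import torch
import torch.nn.functional as F

def unfused_attention(q, k, v, mask, p_drop=0.1):
    s = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # matmul: logits
    s = s.masked_fill(mask, float("-inf"))            # mask
    p = torch.softmax(s, dim=-1)                      # softmax
    p = F.dropout(p, p_drop)                          # dropout
    return p @ v                                      # matmul with V
```

Each line above is its own kernel launch with its own round trip through GPU memory; the fused kernel keeps the intermediates on-chip.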
 
- we never have enough GPU memory, so we need to use tiling
- 7 years ago, you could saturate the GPU by sending multiple kernels
- Attention as classification
- d-dimensional input activations q
- c-by-d final layer K
- logits = qK^T
- class probabilities p = softmax(logits)
- rearrange this:
- d-dimensional input activations q
- d-dimensional class embeddings k_i (the rows of K)
- scalar product with each row to get the logits q·k_i
- class probabilities via softmax (with scaling)
 
- attention becomes a classification problem: for each query, which row of V should we pick
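A small sketch of the analogy (shapes and names are made up for illustration):

```python
import torch

d, c = 64, 10
q = torch.randn(d)                  # d-dimensional input activations
K = torch.randn(c, d)               # c-by-d final layer: one "class embedding" per row
p = torch.softmax(K @ q, dim=-1)    # logits q·k_i, then softmax -> class probabilities

# attention does the same thing per query row, with K built from the sequence
# and the extra 1/sqrt(d) scaling
N = 128
Q, Ks = torch.randn(N, d), torch.randn(N, d)
scores = torch.softmax(Q @ Ks.T / d**0.5, dim=-1)  # row i: distribution over which row of V to pick
```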
 
- multi-head attention
- multiple attentions of this type (smaller d) simultaneously
- heads are independent -- fully parallelizable
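A minimal sketch of the head split in PyTorch (shapes are illustrative):

```python
import torch

B, N, n_heads, d_head = 2, 128, 8, 64
x = torch.randn(B, N, n_heads * d_head)

# split the model dimension into heads: (B, n_heads, N, d_head)
q = x.view(B, N, n_heads, d_head).transpose(1, 2)
k, v = torch.randn_like(q), torch.randn_like(q)

# every (batch, head) pair is independent, so this is fully parallelizable
o = torch.softmax(q @ k.transpose(-2, -1) / d_head**0.5, dim=-1) @ v
o = o.transpose(1, 2).reshape(B, N, n_heads * d_head)  # merge heads back
```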
 
- attention on the GPU
- distribute these to use all the SMs
- one head per block on the GPU -- there is too much interdependence within a head to split it across blocks
 
- assumptions: Q, K, V have the same shape
- seq length N
- head dimension of d
- P = softmax(s·QK^T, dim=-1), O = PV, scaling factor s = 1/sqrt(d)
- try to avoid materializing P and other intermediate matrices
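A naive version under these assumptions, for contrast -- it materializes the full N-by-N matrix P, which is exactly what FlashAttention avoids (a sketch, not the library code):

```python
import torch

def reference_attention(Q, K, V):
    s = 1.0 / Q.shape[-1] ** 0.5
    P = torch.softmax(s * Q @ K.transpose(-2, -1), dim=-1)  # (N, N): huge for long sequences
    return P @ V                                            # O = P V
```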
 
 
- largest constraint is the dimension of each head, which affects registers
- there is such a need for shared memory that data gets moved into registers
 
- tiling strategy
- the contraction dimension and the softmax dimension coincide
- the softmax needs its whole input (the full row of logits) before it can normalize -- see the sketch below
- relies on huge tiles
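A sketch of why that is awkward: the logits can be produced tile by tile, but a plain softmax only works once the whole row is available (tile sizes here are arbitrary):

```python
import torch

N, d, Bq, Bk = 1024, 64, 128, 128
Q, K = torch.randn(N, d), torch.randn(N, d)

for i in range(0, N, Bq):                 # one Q tile at a time
    row_tiles = []
    for j in range(0, N, Bk):             # walk over K tiles: contraction dim == softmax dim
        row_tiles.append(Q[i:i+Bq] @ K[j:j+Bk].T / d**0.5)   # (Bq, Bk) tile of logits
    # the softmax over dim=-1 needs every tile of the row before it can normalize,
    # which is what the online stabilized softmax below gets around
    P_rows = torch.softmax(torch.cat(row_tiles, dim=-1), dim=-1)
```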
 
- FA2 rearranged the tile order to avoid writing things to DRAM, compared to FA1
- stabilized softmax
- softmax gives exp(l_i) / sum_j exp(l_j); always between 0 and 1
- both numerator and denominator get large; not numerically stable
- stabilized softmax: exp(l_i - m) / sum_j exp(l_j - m); expand the fraction by exp(-m), setting m = max_i l_i
- this makes the terms much smaller
- whenever a new m appears, update the running output by multiplying with a correction factor -- the online stabilized softmax
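A minimal PyTorch sketch of that online stabilized softmax applied to attention: keep a running max m, a running denominator l, and an output accumulator that is rescaled whenever m changes (tile size and names are illustrative):

```python
import torch

def online_softmax_attention(Q, K, V, Bk=128):
    N, d = Q.shape
    s = 1.0 / d ** 0.5
    m = torch.full((N, 1), float("-inf"))   # running max per query row
    l = torch.zeros(N, 1)                   # running softmax denominator
    O = torch.zeros(N, d)                   # unnormalized output accumulator
    for j in range(0, N, Bk):               # stream over K/V tiles
        S = (Q @ K[j:j + Bk].T) * s
        m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
        alpha = torch.exp(m - m_new)        # correction factor for the old max
        P = torch.exp(S - m_new)            # stabilized exponentials for this tile
        l = alpha * l + P.sum(dim=-1, keepdim=True)
        O = alpha * O + P @ V[j:j + Bk]
        m = m_new
    return O / l

Q, K, V = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax(Q @ K.T / 64 ** 0.5, dim=-1) @ V
print((online_softmax_attention(Q, K, V) - ref).abs().max())  # difference around 1e-6
```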
 
- this is a very sensitive computation
- the rounding difference between a fused multiply-add and an unfused multiply then add is already enough to matter
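A tiny NumPy illustration of the effect (the numbers are contrived, not from the kernel): an FMA rounds once at the end, an unfused multiply then add rounds twice.

```python
import numpy as np

a = np.float32(1 + 2**-13)
unfused = np.float32(a * a) - np.float32(1.0)            # a*a is rounded to float32 first
fused = np.float32(np.float64(a) * np.float64(a) - 1.0)  # emulates FMA: a single rounding
print(unfused, fused, float(fused) - float(unfused))     # results differ by 2**-26
```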
 
- masks cause a non-rectangular block layout
- FA2 uses CUTLASS to use the tensor cores
- matrix multiplication primitives
- a very large C++ file; ~40 GB when compiling
- tile size options -- 64 or 128
 
- nice to see someone do this live
- implemented it in PyTorch
 
- can call Numba kernels on PyTorch tensors via the CUDA array interface
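A minimal sketch (the kernel and names are made up): Numba accepts anything exposing __cuda_array_interface__, which CUDA PyTorch tensors do, so they can be passed to a Numba kernel directly.

```python
import torch
from numba import cuda

@cuda.jit
def scale(x, s):
    i = cuda.grid(1)
    if i < x.shape[0]:
        x[i] *= s

t = torch.arange(8, dtype=torch.float32, device="cuda")
scale[1, 32](t, 2.0)   # the tensor is handed over via __cuda_array_interface__, no copy
print(t)
```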
- an SM has 16K single-precision registers; ~200 registers per thread
- 48 KB of shared memory
- showing a native CUDA compile
- cuda-python, providing bindings for NVRTC
- compile CUDA code into a kernel inline
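A rough sketch along the lines of the cuda-python NVRTC examples (exact signatures and the error-handling convention vary by version; the kernel source and arch flag here are illustrative assumptions):

```python
from cuda import nvrtc

src = rb"""
extern "C" __global__ void scale(float *x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
}
"""

# compile the CUDA source to PTX at runtime via NVRTC
err, prog = nvrtc.nvrtcCreateProgram(src, b"scale.cu", 0, [], [])
opts = [b"--gpu-architecture=compute_80"]   # pick the arch of the target GPU
err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)
err, ptx_size = nvrtc.nvrtcGetPTXSize(prog)
ptx = b" " * ptx_size
err, = nvrtc.nvrtcGetPTX(prog, ptx)
# the PTX can then be loaded and launched through the driver API bindings
```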
 
- register spills can be checked in Godbolt
- Thunder: a source-to-source PyTorch compiler (thunder.jit)
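A minimal usage sketch (the function being compiled is made up for illustration):

```python
import torch
import thunder

def attn(q, k, v):
    p = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return p @ v

jattn = thunder.jit(attn)   # traces the function and generates an optimized version
q, k, v = (torch.randn(8, 128, 64) for _ in range(3))
out = jattn(q, k, v)
```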