Layer Normalization
Row-wise layer normalization: Out = ((X - mean) / sqrt(var + eps)) * W + B.
Kernel
import metile
@metile.kernel
def layernorm(X, W, B, Out, N, BLOCK: metile.constexpr):
row = metile.program_id(0)
# Pass 1: compute mean
_sum = 0.0
for i in metile.tile_range(0, N, BLOCK):
cols = i + metile.arange(0, BLOCK)
mask = cols < N
x = metile.load(X + row * N + cols, mask=mask)
_sum = _sum + x
mean = metile.sum(_sum) / N
# Pass 2: compute variance
_var = 0.0
for i in metile.tile_range(0, N, BLOCK):
cols = i + metile.arange(0, BLOCK)
mask = cols < N
x = metile.load(X + row * N + cols, mask=mask)
diff = x - mean
_var = _var + diff * diff
var = metile.sum(_var) / N
# Pass 3: normalize, scale, shift
inv_std = 1.0 / metile.sqrt(var + 1e-5)
for i in metile.tile_range(0, N, BLOCK):
cols = i + metile.arange(0, BLOCK)
mask = cols < N
x = metile.load(X + row * N + cols, mask=mask)
w = metile.load(W + cols, mask=mask)
b = metile.load(B + cols, mask=mask)
out = (x - mean) * inv_std * w + b
metile.store(Out + row * N + cols, out, mask=mask)
Launching
import numpy as np
rows, hidden = 128, 512
X = metile.Buffer(data=np.random.randn(rows, hidden).astype(np.float32))
W = metile.Buffer(data=np.ones(hidden, dtype=np.float32))
B = metile.Buffer(data=np.zeros(hidden, dtype=np.float32))
Out = metile.Buffer.zeros((rows * hidden,))
layernorm[(rows,)](X, W, B, Out, hidden, BLOCK=256)
Concepts Introduced
Three-pass algorithm (mean, variance, normalize)
Scalar accumulators across tiled loops
metile.sum: reduction
metile.sqrt: element-wise square root
Loading separate weight/bias arrays (shared across all rows)