Fused Activations & Simdgroup Roles¶
This example shows two patterns: simple element-wise activations, and using simdgroup_role
to run different computations on different simdgroup subsets within a single kernel.
Simple Activations¶
Element-wise kernels follow the same pattern as vector add (load, compute, store):
import metile

@metile.kernel
def gelu(X, Out, N, BLOCK: metile.constexpr):
    pid = metile.program_id(0)
    offs = pid * BLOCK + metile.arange(0, BLOCK)
    mask = offs < N
    x = metile.load(X + offs, mask=mask)
    # GELU approximation: x / (1 + exp(-1.702 * x))
    out = x / (1.0 + metile.exp(-1.702 * x))
    metile.store(Out + offs, out, mask=mask)

@metile.kernel
def silu(X, Out, N, BLOCK: metile.constexpr):
    pid = metile.program_id(0)
    offs = pid * BLOCK + metile.arange(0, BLOCK)
    mask = offs < N
    x = metile.load(X + offs, mask=mask)
    # SiLU (Swish): x / (1 + exp(-x))
    out = x / (1.0 + metile.exp(-x))
    metile.store(Out + offs, out, mask=mask)
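To sanity-check kernel output on the host, both activations can be reproduced with plain NumPy. The formulas below are exactly the ones used in the kernels (the sigmoid-based 1.702 GELU approximation); the helper names are illustrative and not part of metile:

```python
import numpy as np

def gelu_ref(x):
    # Sigmoid-approximated GELU, same formula as the gelu kernel:
    # x / (1 + exp(-1.702 * x))
    return x / (1.0 + np.exp(-1.702 * x))

def silu_ref(x):
    # SiLU (Swish): x * sigmoid(x) == x / (1 + exp(-x)),
    # same formula as the silu kernel
    return x / (1.0 + np.exp(-x))

x = np.array([-6.0, 0.0, 6.0])
g = gelu_ref(x)  # ~0 for large negative inputs, ~x for large positive inputs
```

Comparing the device buffer against these references with `np.allclose` is a quick correctness test for the kernels above.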
Fused GEMM + Activation¶
When an activation follows a dot operation, the compiler fuses it into the GEMM epilogue.
The activation runs on register-resident accumulator data, so there is no extra global-memory round-trip:
@metile.kernel
def matmul_gelu(A, B, C, M, N, K,
                BLOCK_M: metile.constexpr, BLOCK_N: metile.constexpr,
                BLOCK_K: metile.constexpr):
    pid_m = metile.program_id(0)
    pid_n = metile.program_id(1)
    acc = metile.zeros((BLOCK_M, BLOCK_N), dtype="f32")
    for k in metile.tile_range(0, K, BLOCK_K):
        a = metile.tile_load(A, pid_m * BLOCK_M, k, K, (BLOCK_M, BLOCK_K))
        b = metile.tile_load(B, k, pid_n * BLOCK_N, N, (BLOCK_K, BLOCK_N))
        acc = metile.dot(a, b, acc)
    # Fused GELU epilogue, runs on accumulator registers
    acc = acc / (1.0 + metile.exp(-1.702 * acc))
    metile.tile_store(C, pid_m * BLOCK_M, pid_n * BLOCK_N, N, acc, (BLOCK_M, BLOCK_N))
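The fused kernel should produce the same result as an ordinary matmul followed by the same GELU approximation. A NumPy reference for verification (a minimal sketch; the function name is an assumption, not part of metile):

```python
import numpy as np

def matmul_gelu_ref(a, b):
    # GEMM followed by the sigmoid-approximated GELU, applied to the
    # full accumulator exactly as the fused epilogue does
    acc = a @ b
    return acc / (1.0 + np.exp(-1.702 * acc))

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8)).astype(np.float32)
b = rng.standard_normal((8, 3)).astype(np.float32)
out = matmul_gelu_ref(a, b)  # shape (4, 3)
```

Because the epilogue is element-wise on the accumulator, fusing it changes neither the result nor the tiling, only when the activation runs.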
Simdgroup Roles¶
Apple GPUs organize threads into 32-thread simdgroups. A threadgroup can contain
multiple simdgroups. With simdgroup_role, you can assign different work to different
simdgroup subsets, useful for computing multiple outputs in a single dispatch:
@metile.kernel
def exp_sqrt(X, out_exp, out_sqrt, N, BLOCK: metile.constexpr):
    pid = metile.program_id(0)
    offs = pid * BLOCK + metile.arange(0, BLOCK)
    mask = offs < N
    with metile.simdgroup_role(role=0, num_roles=2):
        # First half of simdgroups compute exp
        x = metile.load(X + offs, mask=mask)
        metile.store(out_exp + offs, metile.exp(x), mask=mask)
    with metile.simdgroup_role(role=1, num_roles=2):
        # Second half compute sqrt(abs(x))
        x = metile.load(X + offs, mask=mask)
        metile.store(out_sqrt + offs, metile.sqrt(metile.abs(x)), mask=mask)
With num_roles=2, the threadgroup’s simdgroups are split in half. Role 0 computes
exponentials while role 1 computes square roots, simultaneously, in the same kernel launch.
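One plausible simdgroup-to-role mapping is a contiguous, equal-sized split, which matches the "first half / second half" description above. The exact assignment is up to metile; this sketch only illustrates the splitting arithmetic, and the function name is hypothetical:

```python
def simdgroup_role_of(simdgroup_id, num_simdgroups, num_roles):
    # Contiguous split (an assumption, not metile's documented behavior):
    # with 8 simdgroups and 2 roles, simdgroups 0-3 take role 0
    # and simdgroups 4-7 take role 1
    per_role = num_simdgroups // num_roles
    return simdgroup_id // per_role

roles = [simdgroup_role_of(i, 8, 2) for i in range(8)]
# roles == [0, 0, 0, 0, 1, 1, 1, 1]
```

Whatever the concrete mapping, the key property is that each simdgroup executes exactly one role's body, so the two code paths proceed concurrently rather than sequentially.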
GEGLU (Gated GELU)¶
A practical use of simdgroup roles for computing the gate and up projections of GEGLU in parallel:
@metile.kernel
def geglu(X_gate, X_up, Out, N, BLOCK: metile.constexpr):
    pid = metile.program_id(0)
    offs = pid * BLOCK + metile.arange(0, BLOCK)
    mask = offs < N
    with metile.simdgroup_role(role=0, num_roles=2):
        # GELU(gate), written to Out
        gate = metile.load(X_gate + offs, mask=mask)
        gate = gate / (1.0 + metile.exp(-1.702 * gate))
        metile.store(Out + offs, gate, mask=mask)
    with metile.simdgroup_role(role=1, num_roles=2):
        # NOTE: reads the GELU(gate) values stored by role 0, so this
        # relies on role 0's stores being visible before role 1's loads
        # (an ordering guarantee or barrier between the roles)
        up = metile.load(X_up + offs, mask=mask)
        gate = metile.load(Out + offs, mask=mask)
        metile.store(Out + offs, gate * up, mask=mask)
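On the host, GEGLU reduces to `gelu(gate) * up` element-wise, which makes a convenient reference for checking the kernel's output (a minimal NumPy sketch; the helper name is illustrative):

```python
import numpy as np

def geglu_ref(x_gate, x_up):
    # GEGLU: gate the up projection with GELU(gate), using the same
    # 1.702 sigmoid approximation as the kernel
    gate = x_gate / (1.0 + np.exp(-1.702 * x_gate))
    return gate * x_up

out = geglu_ref(np.array([0.0, 10.0]), np.array([5.0, 2.0]))
```

A zero gate fully suppresses the corresponding up value, while a large positive gate passes it through almost unchanged.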
Concepts Introduced¶
Element-wise activation patterns
metile.exp for activation functions
Fused GEMM epilogues: zero-cost post-GEMM operations
metile.simdgroup_role: split work across simdgroup subsets
Multiple outputs from a single kernel