Language Reference¶
meTile provides a Python eDSL (embedded domain-specific language) for writing GPU kernels. Functions
decorated with @metile.kernel are traced and compiled to Metal shaders. They are not executed
as regular Python.
This page documents every construct available inside a @metile.kernel function.
Kernel Definition¶
@metile.kernel
def my_kernel(ptr_a, ptr_b, N, BLOCK: metile.constexpr):
...
Parameters are either:
- Pointers: numpy arrays or metile.Buffer objects become device float* in Metal
- Scalars: Python ints/floats become constant int& or constant float&
- Constexprs: annotated with metile.constexpr, baked into the shader at compile time
Constexprs are passed as keyword arguments at launch:
my_kernel[grid](a, b, N, BLOCK=256)
Launching Kernels¶
kernel[grid](*args, **constexprs)
grid is a tuple of 1, 2, or 3 integers specifying the number of program instances
(threadgroups) along each axis.
kernel[(N,)](...) # 1D grid
kernel[(M, N)](...) # 2D grid
kernel[(X, Y, Z)](...) # 3D grid
Program Identity¶
- metile.program_id(axis)¶
Returns the index of the current program instance along the given axis.
pid_x = metile.program_id(0)  # threadgroup X index
pid_y = metile.program_id(1)  # threadgroup Y index
Index Generation¶
- metile.arange(start, size)¶
Creates a tile of size consecutive integers starting at start.
idx = metile.arange(0, 256)  # [0, 1, 2, ..., 255]
- metile.cdiv(a, b)¶
Ceiling division. Useful for computing grid sizes.
grid_size = metile.cdiv(N, BLOCK) # ceil(N / BLOCK)
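The semantics of cdiv can be sketched in plain Python (this is a CPU reference for the arithmetic, not the metile implementation):

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division: smallest integer >= a / b (for positive b)."""
    return (a + b - 1) // b

# Typical use: one program instance per BLOCK-sized chunk of N elements.
N, BLOCK = 1000, 256
grid_size = cdiv(N, BLOCK)  # 4 threadgroups cover all 1000 elements
```

The last partial block is handled by masking out-of-bounds lanes in the kernel, as shown under Element-wise Memory Access below.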
Element-wise Memory Access¶
For element-wise and row-wise kernels (activations, softmax, reductions), use pointer
arithmetic with load and store:
- metile.load(ptr, mask=None)¶
Load elements from memory. Masked-off elements read zero.
offs = pid * BLOCK + metile.arange(0, BLOCK)
mask = offs < N
x = metile.load(X + offs, mask=mask)
- metile.store(ptr, value, mask=None)¶
Store elements to memory. Masked-off elements are skipped.
metile.store(Out + offs, result, mask=mask)
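metile kernels do not run as regular Python, but the documented masking semantics can be sketched with numpy on the CPU. The helper names masked_load and masked_store are illustrative only, not part of the metile API:

```python
import numpy as np

def masked_load(buf, offs, mask):
    # Masked-off lanes read zero, matching metile.load's documented behavior.
    out = np.zeros(len(offs), dtype=buf.dtype)
    out[mask] = buf[offs[mask]]
    return out

def masked_store(buf, offs, value, mask):
    # Masked-off lanes are skipped; memory there is left untouched.
    buf[offs[mask]] = value[mask]

X = np.arange(10, dtype=np.float32)
pid, BLOCK, N = 1, 8, 10
offs = pid * BLOCK + np.arange(BLOCK)  # [8..15], only 8 and 9 are in bounds
mask = offs < N
x = masked_load(X, offs, mask)         # [8, 9, 0, 0, 0, 0, 0, 0]

Out = np.zeros(10, dtype=np.float32)
masked_store(Out, offs, x, mask)       # writes only Out[8] and Out[9]
```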
Tile Memory Access¶
For matrix operations (GEMM), use tile-level loads and stores that map to simdgroup or tensor_ops hardware:
- metile.tile_load(ptr, row_offset, col_offset, stride, shape)¶
Load a 2D tile from row-major memory.
- Parameters:
ptr – base pointer to the matrix
row_offset – row index of tile’s top-left corner
col_offset – column index of tile’s top-left corner
stride – leading dimension (number of columns in the full matrix)
shape – (rows, cols) of the tile to load
# Load a (BLOCK_M, BLOCK_K) tile of A with top-left corner (pid_m * BLOCK_M, k)
a = metile.tile_load(A, pid_m * BLOCK_M, k, K, (BLOCK_M, BLOCK_K))
- metile.tile_store(ptr, row_offset, col_offset, stride, value, shape)¶
Store a 2D tile to row-major memory.
metile.tile_store(C, pid_m * BLOCK_M, pid_n * BLOCK_N, N, acc, (BLOCK_M, BLOCK_N))
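The addressing described by these parameters can be expressed as a CPU reference over a flat row-major buffer (tile_load_ref is an illustrative name, not a metile function):

```python
import numpy as np

def tile_load_ref(ptr, row_offset, col_offset, stride, shape):
    """CPU reference: load a (rows, cols) tile from a flat row-major buffer.

    Element (r, c) of the tile comes from
    ptr[(row_offset + r) * stride + col_offset + c].
    """
    rows, cols = shape
    out = np.empty(shape, dtype=ptr.dtype)
    for r in range(rows):
        base = (row_offset + r) * stride + col_offset
        out[r] = ptr[base:base + cols]
    return out

# A 4x6 matrix stored flat, so its stride (leading dimension) is 6.
A = np.arange(24, dtype=np.float32)
tile = tile_load_ref(A, 1, 2, 6, (2, 3))  # rows 1-2, cols 2-4
```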
- metile.zeros(shape, dtype='f32')¶
Create a zero-initialized tile. Used to initialize accumulators.
acc = metile.zeros((BLOCK_M, BLOCK_N), dtype="f32")
Matrix Multiply¶
- metile.dot(a, b, acc)¶
Tile-level matrix multiply-accumulate: acc += a @ b. The compiler maps this to
simdgroup_multiply_accumulate (M1-M3) or matmul2d tensor_ops (M4+) depending on hardware.
acc = metile.zeros((128, 128), dtype="f32")
for k in metile.tile_range(0, K, BLOCK_K):
    a = metile.tile_load(A, pid_m * 128, k, K, (128, BLOCK_K))
    b = metile.tile_load(B, k, pid_n * 128, N, (BLOCK_K, 128))
    acc = metile.dot(a, b, acc)
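What the K-loop above computes can be checked against numpy on the CPU. This is a sketch of the arithmetic for one output tile, not the metile implementation:

```python
import numpy as np

def gemm_tile_ref(A, B, pid_m, pid_n, BLOCK_M, BLOCK_N, BLOCK_K):
    """CPU sketch of the metile.dot accumulation loop for one (pid_m, pid_n) tile.

    Assumes the problem sizes divide evenly by the block sizes.
    """
    K = A.shape[1]
    acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)      # metile.zeros
    for k in range(0, K, BLOCK_K):                            # metile.tile_range
        a = A[pid_m*BLOCK_M:(pid_m+1)*BLOCK_M, k:k+BLOCK_K]   # tile_load of A
        b = B[k:k+BLOCK_K, pid_n*BLOCK_N:(pid_n+1)*BLOCK_N]   # tile_load of B
        acc += a @ b                                          # metile.dot
    return acc

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 16)).astype(np.float32)
B = rng.standard_normal((16, 8)).astype(np.float32)
tile = gemm_tile_ref(A, B, 0, 0, 4, 4, 4)  # top-left 4x4 block of A @ B
```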
Control Flow¶
- metile.tile_range(start, end, step)¶
A tiling loop. Equivalent to range(start, end, step) but tells the compiler this is a
tile-level iteration (e.g., the K-loop in GEMM).
for k in metile.tile_range(0, K, BLOCK_K):
    ...
Math Operations¶
All math ops are element-wise and work on both scalars and tiles:
| Function | Description |
|---|---|
| metile.exp(x) | Exponential |
| metile.log(x) | Natural logarithm |
| metile.sqrt(x) | Square root |
| metile.abs(x) | Absolute value |
| metile.tanh(x) | Hyperbolic tangent |
| metile.where(cond, a, b) | Select |
| metile.maximum(a, b) | Element-wise maximum |
| metile.minimum(a, b) | Element-wise minimum |
Standard Python arithmetic works inside kernels: +, -, *, /, <, >, etc.
Reductions¶
- metile.sum(x)¶
Sum-reduce a tile to a scalar.
- metile.max(x)¶
Max-reduce a tile to a scalar.
- metile.min(x)¶
Min-reduce a tile to a scalar.
These compile to simdgroup shuffle reductions on the GPU.
# Two-pass softmax, pass 1: find the row maximum
m = -1e38
for i in metile.tile_range(0, N, BLOCK):
    cols = i + metile.arange(0, BLOCK)
    x = metile.load(X + row * N + cols, mask=cols < N)
    m = metile.maximum(m, x)   # element-wise running max per lane
m = metile.max(m)              # reduce across the tile to a scalar
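The full two-pass pattern (tile-wise max, then normalized exponentials) can be sketched with numpy. One caveat baked into this sketch: masked-off lanes are filled with -1e38 rather than the zero that metile.load would return, since zero would bias the maximum for all-negative rows:

```python
import numpy as np

def softmax_ref(row, BLOCK=4):
    """CPU sketch of two-pass softmax over one row of length N."""
    N = len(row)
    # Pass 1: element-wise running max over BLOCK-sized tiles.
    m = np.full(BLOCK, -1e38, dtype=np.float32)
    for i in range(0, N, BLOCK):
        cols = i + np.arange(BLOCK)
        mask = cols < N
        # Masked lanes read -1e38 so they never win the max.
        x = np.where(mask, row[np.minimum(cols, N - 1)], np.float32(-1e38))
        m = np.maximum(m, x)
    m = m.max()  # metile.max: reduce the tile to a scalar
    # Pass 2: numerically stable normalized exponentials.
    e = np.exp(row - m)
    return e / e.sum()

row = np.array([1.0, 2.0, 3.0, 0.5, -1.0], dtype=np.float32)
p = softmax_ref(row)
```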
Advanced: Simdgroup Operations¶
For low-level control over Apple GPU simdgroups:
- metile.simdgroup_role(role, num_roles, body, num_sgs=0)¶
Execute different code on different simdgroup subsets within a threadgroup. Enables producer/consumer patterns.
with metile.simdgroup_role(role=0, num_roles=2):
    # Only the first half of simdgroups run this
    ...
with metile.simdgroup_role(role=1, num_roles=2):
    # Only the second half run this
    ...
- metile.simd_shuffle_xor(value, mask)¶
Exchange data between lanes within a simdgroup using XOR addressing.
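XOR addressing is what makes butterfly (tree) reductions possible: at step m, lane i exchanges with lane i ^ m, and after log2(32) steps every lane holds the full result. A CPU sketch, where an array of 32 values stands in for the simdgroup's lanes:

```python
import numpy as np

def butterfly_sum(lanes):
    """CPU sketch of a simdgroup sum built from simd_shuffle_xor.

    lanes[i] models the value held by lane i; shuffling with mask m
    gives each lane the value held by its partner lane (i ^ m).
    """
    vals = np.asarray(lanes, dtype=np.float64).copy()
    n = len(vals)  # 32 on Apple GPU simdgroups
    m = 1
    while m < n:
        partner = np.arange(n) ^ m   # lane i reads from lane i ^ m
        vals = vals + vals[partner]  # add the shuffled-in value
        m *= 2
    return vals  # every lane now holds the total sum

out = butterfly_sum(np.arange(32))
```

This is the shape of the "simdgroup shuffle reductions" that metile.sum/max/min compile to.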
- metile.simd_broadcast(value, lane)¶
Broadcast a value from one lane to all lanes in a simdgroup.
- metile.simd_lane_id()¶
Returns the current thread’s lane index within its simdgroup (0-31).
- metile.thread_id()¶
Returns the thread’s position within the threadgroup.
- metile.barrier()¶
Threadgroup memory barrier. Forces all threads to reach this point before proceeding.
Allocate threadgroup (shared) memory.
Tile Scheduling¶
- metile.tile_swizzle(pid_m, pid_n, pattern='morton', block_size=2)¶
Apply a tile scheduling pattern for better cache locality in 2D grids. Supported
patterns: "morton" (Z-order), "diagonal".
pid_m, pid_n = metile.tile_swizzle(
    metile.program_id(0),
    metile.program_id(1),
    pattern="morton",
    block_size=2,
)