Performance optimization: Skip compilation of LTOs for adjoint matmuls #644

@shi-eric

Description

In `tile_matmul_lto_dispatch_func()`, we should be able to save some compilation time by skipping compilation of the LTOs for the adjoint (backward) matmuls when `enable_backward` is `False` for the module.
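For context, a forward product $C = AB$ needs two extra matrix products in the backward pass, one per input adjoint, and each is compiled as its own LTO in the snippet quoted below. The shapes match the `(M, K, N)` and `(K, N, M)` size arguments passed to `build_lto_dot`, which appear to follow the pattern (rows of result, columns of result, inner dimension):

```latex
\begin{aligned}
C &= A B, && A \in \mathbb{R}^{M \times K},\; B \in \mathbb{R}^{K \times N},\; C \in \mathbb{R}^{M \times N} \\
\bar{A} &\mathrel{+}= \bar{C}\, B^{\top}, && (M \times N)(N \times K) \to M \times K \\
\bar{B} &\mathrel{+}= A^{\top} \bar{C}, && (K \times M)(M \times N) \to K \times N
\end{aligned}
```

With `enable_backward=False` only the forward product is needed, so the two adjoint LTOs per matmul would be compiled without ever being used.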

`warp/warp/builtins.py`, lines 6547 to 6576 in 3635b9a:

```python
# adjA += adjC * B^T - Transpose ~= flipped layout
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
    M,
    K,
    N,
    out.type.dtype,
    b.type.dtype,
    a.type.dtype,
    out.type.layout,
    tile_flip_layout(b.type.layout),
    a.type.layout,
    arch,
    num_threads,
    builder,
)
# adjB += A^T * adjC - Transpose ~= flipped layout
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
    K,
    N,
    M,
    a.type.dtype,
    out.type.dtype,
    b.type.dtype,
    tile_flip_layout(a.type.layout),
    out.type.layout,
    b.type.layout,
    arch,
    num_threads,
    builder,
)
```
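A minimal sketch of the proposed guard, assuming the dispatch function can see the module's `enable_backward` setting as a plain boolean and that the forward product is built by a third `build_lto_dot` call with `(M, N, K, ...)` sizes. `build_matmul_ltos`, `TileType` and `fake_build_lto_dot` are hypothetical names for illustration, and `tile_flip_layout` here is a simplified stand-in for Warp's helper; this is not Warp's actual code.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

# (function_name, lto_bytes) pair, mirroring the (fun_*, lto_*) tuples above.
LtoResult = Tuple[str, bytes]


@dataclass
class TileType:
    """Hypothetical stand-in for a tile type: element dtype plus memory layout."""
    dtype: str
    layout: str  # "rowmajor" or "colmajor"


def tile_flip_layout(layout: str) -> str:
    """Hypothetical stand-in for tile_flip_layout(): transpose ~= flipped layout."""
    return "colmajor" if layout == "rowmajor" else "rowmajor"


def build_matmul_ltos(
    build_lto_dot: Callable[..., LtoResult],
    a: TileType,
    b: TileType,
    out: TileType,
    M: int,
    N: int,
    K: int,
    arch: int,
    num_threads: int,
    builder: Any,
    enable_backward: bool,
) -> Tuple[LtoResult, Optional[LtoResult], Optional[LtoResult]]:
    # Forward C = A * B is always needed (argument order assumed, not taken from Warp).
    fwd = build_lto_dot(M, N, K, a.dtype, b.dtype, out.dtype,
                        a.layout, b.layout, out.layout, arch, num_threads, builder)
    if not enable_backward:
        # No adjoint code is generated, so skip compiling the two backward LTOs.
        return fwd, None, None
    # adjA += adjC * B^T
    bwd_a = build_lto_dot(M, K, N, out.dtype, b.dtype, a.dtype,
                          out.layout, tile_flip_layout(b.layout), a.layout,
                          arch, num_threads, builder)
    # adjB += A^T * adjC
    bwd_b = build_lto_dot(K, N, M, a.dtype, out.dtype, b.dtype,
                          tile_flip_layout(a.layout), out.layout, b.layout,
                          arch, num_threads, builder)
    return fwd, bwd_a, bwd_b


# Usage with a fake builder that only records how many LTOs would be compiled.
calls: List[tuple] = []


def fake_build_lto_dot(*args: Any) -> LtoResult:
    calls.append(args)
    return (f"dot_{len(calls)}", b"")


a = TileType("float32", "rowmajor")
b = TileType("float32", "rowmajor")
out = TileType("float32", "rowmajor")

build_matmul_ltos(fake_build_lto_dot, a, b, out, 64, 64, 32, 80, 128, None, enable_backward=False)
print(len(calls))  # 1
build_matmul_ltos(fake_build_lto_dot, a, b, out, 64, 64, 32, 80, 128, None, enable_backward=True)
print(len(calls))  # 4
```

Returning `None` placeholders keeps the tuple shape at the call site unchanged; in this sketch, any later step that links `lto_backward_A`/`lto_backward_B` would need to skip them when they are `None`.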
