Closed
Description
In `tile_matmul_lto_dispatch_func()`, we should be able to save some compilation time by skipping the compilation of the LTOs for the adjoint operations when `enable_backward` is `False` for the module.
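A minimal Python sketch of the proposed guard, under stated assumptions. Everything here is hypothetical: `FakeBuilder`, `build_lto_dot_stub`, and `dispatch_lto_sketch` stand in for the module builder, `warp.build.build_lto_dot()`, and `tile_matmul_lto_dispatch_func()`. The idea that the module's `enable_backward` setting can be read from an `options` dict on the builder is an assumption for illustration, not taken from Warp's code.

```python
from dataclasses import dataclass, field


@dataclass
class FakeBuilder:
    # Hypothetical stand-in for the module builder. Assumption: the module's
    # options are readable as a plain dict with an "enable_backward" key.
    options: dict = field(default_factory=lambda: {"enable_backward": True})
    compiled: list = field(default_factory=list)


def build_lto_dot_stub(name, builder):
    # Hypothetical stand-in for warp.build.build_lto_dot(): records which
    # LTO would be compiled and returns placeholder (function, LTO) names.
    builder.compiled.append(name)
    return f"fun_{name}", f"lto_{name}"


def dispatch_lto_sketch(builder):
    """Always build the forward LTO; build the adjoint LTOs only when needed."""
    _, lto_forward = build_lto_dot_stub("forward", builder)

    lto_backward_A = lto_backward_B = None
    if builder.options.get("enable_backward", True):
        # adjA += adjC * B^T
        _, lto_backward_A = build_lto_dot_stub("backward_A", builder)
        # adjB += A^T * adjC
        _, lto_backward_B = build_lto_dot_stub("backward_B", builder)

    return lto_forward, lto_backward_A, lto_backward_B


forward_only = FakeBuilder(options={"enable_backward": False})
dispatch_lto_sketch(forward_only)
print(forward_only.compiled)  # ['forward']

with_backward = FakeBuilder()
dispatch_lto_sketch(with_backward)
print(with_backward.compiled)  # ['forward', 'backward_A', 'backward_B']
```

Returning `None` for the skipped adjoint LTOs is only one possible choice; it assumes nothing downstream consumes them when the backward pass is disabled.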
Lines 6547 to 6576 in 3635b9a
```python
# adjA += adjC * B^T - Transpose ~= flipped layout
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
    M,
    K,
    N,
    out.type.dtype,
    b.type.dtype,
    a.type.dtype,
    out.type.layout,
    tile_flip_layout(b.type.layout),
    a.type.layout,
    arch,
    num_threads,
    builder,
)
# adjB += A^T * adjC - Transpose ~= flipped layout
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
    K,
    N,
    M,
    a.type.dtype,
    out.type.dtype,
    b.type.dtype,
    tile_flip_layout(a.type.layout),
    out.type.layout,
    b.type.layout,
    arch,
    num_threads,
    builder,
)
```
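The two comments in this excerpt carry the reasoning behind the backward builds: for C = A * B, the adjoints are adjA = adjC * B^T (M x K) and adjB = A^T * adjC (K x N), and a transpose can be expressed by flipping an operand's layout rather than moving data. These shapes appear to match the (M, K, N) and (K, N, M) triples passed to the two calls. The pure-Python check below illustrates both claims; the helpers (`matmul`, `read`, `loss`, `finite_diff_grad`, `close`) are illustrative and not Warp APIs, and it checks the math only, not `build_lto_dot()` itself.

```python
import random


def matmul(x, y):
    # Naive dense matmul on lists of rows: (m x k) @ (k x n) -> (m x n).
    inner, cols = len(y), len(y[0])
    return [[sum(row[p] * y[p][j] for p in range(inner)) for j in range(cols)] for row in x]


def transpose(x):
    return [list(col) for col in zip(*x)]


def read(buffer, rows, cols, layout):
    # View a flat buffer as a rows x cols matrix in row- or column-major layout.
    if layout == "row_major":
        return [[buffer[i * cols + j] for j in range(cols)] for i in range(rows)]
    return [[buffer[i + j * rows] for j in range(cols)] for i in range(rows)]


def loss(a, b, adj_c):
    # Scalar L = sum(adj_c * (a @ b)), so dL/dC == adj_c.
    c = matmul(a, b)
    return sum(g * v for g_row, c_row in zip(adj_c, c) for g, v in zip(g_row, c_row))


def finite_diff_grad(f, x, eps=1e-6):
    # Central differences of f() w.r.t. each entry of x (mutated, then restored).
    grad = [[0.0] * len(x[0]) for _ in x]
    for i in range(len(x)):
        for j in range(len(x[0])):
            orig = x[i][j]
            x[i][j] = orig + eps
            f_plus = f()
            x[i][j] = orig - eps
            f_minus = f()
            x[i][j] = orig
            grad[i][j] = (f_plus - f_minus) / (2 * eps)
    return grad


def close(x, y, tol=1e-6):
    return all(abs(p - q) <= tol for rx, ry in zip(x, y) for p, q in zip(rx, ry))


random.seed(0)
M, K, N = 3, 4, 2
a = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(M)]
b = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(K)]
adj_c = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(M)]

adj_a = matmul(adj_c, transpose(b))  # (M x N) @ (N x K) -> M x K
adj_b = matmul(transpose(a), adj_c)  # (K x M) @ (M x N) -> K x N


def f():
    return loss(a, b, adj_c)


print(close(adj_a, finite_diff_grad(f, a)), close(adj_b, finite_diff_grad(f, b)))  # True True

# "Transpose ~= flipped layout": A's row-major buffer, read as column-major K x M, is A^T.
flat_a = [v for row in a for v in row]
print(read(flat_a, K, M, "col_major") == transpose(a))  # True
```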