Closed
Bug Description
I was playing around with the new `tile_lower_solve` in nightly and found a bug that appears when solving with a transposed right-hand side followed by a `tile_matmul` operation.
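Written out, the reproducer solves $L X^\top = A^\top$ for $X$ (`A_sol`) and then forms $C = X X^\top$. The second line below is the expected product; the third is the product with both operands transposed, which is what the Warp `C_tile` output turns out to match:

```math
\begin{aligned}
L\,X^\top &= A^\top \quad\Rightarrow\quad X = A\,L^{-\top} \\
C_{\text{expected}} &= X X^\top = A\,(L L^\top)^{-1} A^\top \\
C_{\text{observed}} &= X^\top X = L^{-1} A^\top A\, L^{-\top}
\end{aligned}
```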
Here is a minimal example to reproduce the bug:
```python
import numpy as np
import scipy.linalg as la
import warp as wp

wp.init()


@wp.kernel
def test_bug(A: wp.array(dtype=wp.float64, ndim=2), L: wp.array(dtype=wp.float64, ndim=2), C: wp.array(dtype=wp.float64, ndim=2)):
    tid, block_tid = wp.tid()
    A_tile = wp.tile_load(A, shape=(3, 3), offset=(0, 0))
    L_tile = wp.tile_load(L, shape=(3, 3), offset=(0, 0))

    # Solve L @ A_sol_T = A^T, i.e. A_sol = A @ L^{-T}
    A_tile_T = wp.tile_transpose(A_tile)
    A_sol_T = wp.tile_lower_solve(L_tile, A_tile_T)
    A_sol = wp.tile_transpose(A_sol_T)

    # Expected: C_tile == A_sol @ A_sol^T
    C_tile = wp.tile_matmul(A_sol, A_sol_T)
    # Expected: C_tile_unexpected == A_sol^T @ A_sol (only printed, not stored)
    C_tile_unexpected = wp.tile_matmul(A_sol_T, A_sol)
    wp.tile_store(C, C_tile, offset=(0, 0))

    if block_tid == 0:
        print(A_sol)
        print(A_sol_T)
        print(C_tile)
        print(C_tile_unexpected)


np.random.seed(0)
A = np.random.rand(3, 3)
L = np.array([[1.0, 0.0, 0.0],
              [1.0, 2.0, 0.0],
              [1.0, 2.0, 3.0]])

# Host-side reference
A_sol = la.solve_triangular(L, A.T, lower=True).T
print(A_sol)
print(A_sol.T)
print(A_sol @ A_sol.T)

A_warp = wp.from_numpy(A, dtype=wp.float64, device='cuda')
L_warp = wp.from_numpy(L, dtype=wp.float64, device='cuda')
C_warp = wp.zeros((3, 3), dtype=wp.float64, device='cuda')
wp.launch_tiled(test_bug, dim=[1], inputs=[A_warp, L_warp, C_warp], block_dim=64)
wp.synchronize()  # wait for the kernel so its device-side prints are flushed
```
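This outputs the following on my machine (the first three matrices are the SciPy reference `A_sol`, `A_sol.T` and `A_sol @ A_sol.T`; the last four are the kernel's prints of `A_sol`, `A_sol_T`, `C_tile` and `C_tile_unexpected`):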
```
[[ 0.5488135   0.08318793 -0.03747533]
 [ 0.54488318 -0.06061419  0.07407977]
 [ 0.43758721  0.22709289  0.02396325]]
[[ 0.5488135   0.54488318  0.43758721]
 [ 0.08318793 -0.06061419  0.22709289]
 [-0.03747533  0.07407977  0.02396325]]
[[0.30952089 0.29122072 0.25814713]
 [0.29122072 0.30605958 0.22644405]
 [0.25814713 0.22644405 0.24362799]]
Module __main__ bd7c09d load on device 'cuda:0' took 44.06 ms (cached)
[[0.548814 0.0831879 -0.0374753]
 [0.544883 -0.0606142 0.0740798]
 [0.437587 0.227093 0.0239633]] = tile(shape=(3,3), storage=shared)
[[0.548814 0.544883 0.437587]
 [0.0831879 -0.0606142 0.227093]
 [-0.0374753 0.0740798 0.0239633]] = tile(shape=(3,3), storage=shared)
[[0.789577 0.112 0.0302839]
 [0.112 0.0621655 -0.0021659]
 [0.0302839 -0.0021659 0.00746645]] = tile(shape=(3,3), storage=shared)
[[0.309521 0.291221 0.258147]
 [0.291221 0.30606 0.226444]
 [0.258147 0.226444 0.243628]] = tile(shape=(3,3), storage=shared)
```
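Note that `A_sol` and `A_sol_T` themselves are correct, but the result in `C` is `C = A_sol.T @ A_sol` instead of `C = A_sol @ A_sol.T`, i.e., `tile_matmul` multiplies the transposes of its inputs. Conversely, `C_tile_unexpected = tile_matmul(A_sol_T, A_sol)` prints the values that SciPy gives for `A_sol @ A_sol.T`.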
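A quick host-side check, using only NumPy, of which product the printed Warp `C_tile` corresponds to. This is a sketch under two assumptions: `np.linalg.solve` stands in for `scipy.linalg.solve_triangular` (same result for this nonsingular `L`), and the `warp_*` arrays are the 6-significant-digit values copied by hand from the kernel output above, so the comparisons use a loose tolerance.

```python
import numpy as np

# Same inputs as the reproducer (seed 0, 3x3)
np.random.seed(0)
A = np.random.rand(3, 3)
L = np.array([[1.0, 0.0, 0.0],
              [1.0, 2.0, 0.0],
              [1.0, 2.0, 3.0]])

# Reference: solve L @ X^T = A^T, so A_sol = X = A @ L^{-T}
A_sol = np.linalg.solve(L, A.T).T

expected = A_sol @ A_sol.T  # what tile_matmul(A_sol, A_sol_T) should return
swapped = A_sol.T @ A_sol   # product with both operands transposed

# Assumption: values transcribed from the kernel's printed A_sol and C_tile
warp_A_sol = np.array([[0.548814, 0.0831879, -0.0374753],
                       [0.544883, -0.0606142, 0.0740798],
                       [0.437587, 0.227093, 0.0239633]])
warp_C = np.array([[0.789577, 0.112, 0.0302839],
                   [0.112, 0.0621655, -0.0021659],
                   [0.0302839, -0.0021659, 0.00746645]])

solve_ok = np.allclose(warp_A_sol, A_sol, atol=1e-5)
matches_expected = np.allclose(warp_C, expected, atol=1e-5)
matches_swapped = np.allclose(warp_C, swapped, atol=1e-5)

print("tile_lower_solve result matches reference:", solve_ok)
print("C_tile matches A_sol @ A_sol.T:", matches_expected)
print("C_tile matches A_sol.T @ A_sol:", matches_swapped)
```

With the values from this report, the solve check and the swapped-product check both pass while the expected-product check fails, which localizes the problem to `tile_matmul` on transposed tiles rather than to `tile_lower_solve` itself.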
System Information
```
Warp 1.8.0.dev20250602 initialized:
   Git commit: e11e5afc2b1076b94f976d7b16642df3fafff52e
   CUDA Toolkit 12.8, Driver 12.7
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA RTX 2000 Ada Generation" (16 GiB, sm_89, mempool enabled)
   Kernel cache:
     /root/.cache/warp/1.8.0.dev20250602
```