[BUG] tile_matmul returns wrong result after tile_lower_solve with transpose #768

@RSchwan

Bug Description

I was playing around with the new tile_lower_solve in nightly and found a bug that appears when solving with a transposed right-hand side and then passing the result to tile_matmul.
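For reference, the NumPy/SciPy check in the reproduction below treats `tile_lower_solve(L, B)` as returning X with `L @ X = B`, the same contract as `scipy.linalg.solve_triangular(L, B, lower=True)`. That contract is an assumption read off the example, not a quote from the Warp docs. A minimal forward-substitution sketch in plain NumPy (not Warp's implementation; `forward_substitution` is a hypothetical helper) that mirrors the transpose, solve, transpose sequence of the kernel and can be used to sanity-check the intermediate tiles:

```python
import numpy as np


def forward_substitution(L, B):
    """Solve L @ X = B for X, where L is square and lower-triangular.

    Hypothetical NumPy reference sketch; B may hold several right-hand-side columns.
    """
    L = np.asarray(L, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    n = L.shape[0]
    X = np.zeros_like(B)
    for i in range(n):
        # Row i of X depends only on rows 0..i-1, which are already solved.
        X[i] = (B[i] - L[i, :i] @ X[:i]) / L[i, i]
    return X


# Same inputs as the reproduction: solve with the transposed RHS A.T, then
# transpose back (mirrors tile_transpose -> tile_lower_solve -> tile_transpose).
np.random.seed(0)
A = np.random.rand(3, 3)
L = np.array([[1.0, 0.0, 0.0],
              [1.0, 2.0, 0.0],
              [1.0, 2.0, 3.0]])
A_sol = forward_substitution(L, A.T).T
print(A_sol)
```

The explicit loop is only for readability; it computes the same A_sol that `la.solve_triangular(L, A.T, lower=True).T` prints in the output below.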

Here is a minimal example to reproduce the bug:

import numpy as np
import scipy.linalg as la

import warp as wp

wp.init()

@wp.kernel
def test_bug(A: wp.array(dtype=wp.float64, ndim=2), L: wp.array(dtype=wp.float64, ndim=2), C: wp.array(dtype=wp.float64, ndim=2)):
    tid, block_tid = wp.tid()
    # Load the 3x3 input A and the lower-triangular matrix L into tiles
    A_tile = wp.tile_load(A, shape=(3, 3), offset=(0, 0))
    L_tile = wp.tile_load(L, shape=(3, 3), offset=(0, 0))
    # Solve L @ X = A^T, then transpose back: A_sol = (L^-1 A^T)^T
    A_tile_T = wp.tile_transpose(A_tile)
    A_sol_T = wp.tile_lower_solve(L_tile, A_tile_T)
    A_sol = wp.tile_transpose(A_sol_T)
    # Expected: A_sol @ A_sol^T (compare with the NumPy reference below)
    C_tile = wp.tile_matmul(A_sol, A_sol_T)
    # Expected: A_sol^T @ A_sol
    C_tile_unexpected = wp.tile_matmul(A_sol_T, A_sol)
    wp.tile_store(C, C_tile, offset=(0, 0))
    # Print the intermediate tiles for inspection
    if block_tid == 0:
        print(A_sol)
        print(A_sol_T)
        print(C_tile)
        print(C_tile_unexpected)

np.random.seed(0)
A = np.random.rand(3, 3)
L = np.array([[1.0, 0.0, 0.0],
              [1.0, 2.0, 0.0],
              [1.0, 2.0, 3.0]])
# NumPy/SciPy reference for the same computation
A_sol = la.solve_triangular(L, A.T, lower=True).T
print(A_sol)
print(A_sol.T)
print(A_sol @ A_sol.T)  # expected value of C_tile

A_warp = wp.from_numpy(A, dtype=wp.float64, device='cuda')
L_warp = wp.from_numpy(L, dtype=wp.float64, device='cuda')
C_warp = wp.zeros((3, 3), dtype=wp.float64, device='cuda')
wp.launch_tiled(test_bug, dim=[1], inputs=[A_warp, L_warp, C_warp], block_dim=64)

which outputs the following on my machine:

[[ 0.5488135   0.08318793 -0.03747533]
 [ 0.54488318 -0.06061419  0.07407977]
 [ 0.43758721  0.22709289  0.02396325]]
[[ 0.5488135   0.54488318  0.43758721]
 [ 0.08318793 -0.06061419  0.22709289]
 [-0.03747533  0.07407977  0.02396325]]
[[0.30952089 0.29122072 0.25814713]
 [0.29122072 0.30605958 0.22644405]
 [0.25814713 0.22644405 0.24362799]]
Module __main__ bd7c09d load on device 'cuda:0' took 44.06 ms  (cached)
[[0.548814 0.0831879 -0.0374753]
 [0.544883 -0.0606142 0.0740798]
 [0.437587 0.227093 0.0239633]] = tile(shape=(3,3), storage=shared)
[[0.548814 0.544883 0.437587]
 [0.0831879 -0.0606142 0.227093]
 [-0.0374753 0.0740798 0.0239633]] = tile(shape=(3,3), storage=shared)
[[0.789577 0.112 0.0302839]
 [0.112 0.0621655 -0.0021659]
 [0.0302839 -0.0021659 0.00746645]] = tile(shape=(3,3), storage=shared)
[[0.309521 0.291221 0.258147]
 [0.291221 0.30606 0.226444]
 [0.258147 0.226444 0.243628]] = tile(shape=(3,3), storage=shared)

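Note that, for some reason, the result in C is C = A_sol.T @ A_sol instead of C = A_sol @ A_sol.T, i.e., tile_matmul appears to multiply the transposes of its inputs. Conversely, C_tile_unexpected = wp.tile_matmul(A_sol_T, A_sol) prints the values expected for A_sol @ A_sol.T.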
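Because the two operands here are transposes of each other, the wrong result is consistent with more than one failure mode. A small NumPy check (the helper `classify_matmul` is hypothetical, not part of Warp) compares the `C_tile` values printed by Warp against the candidate products of X = A_sol and Y = A_sol.T. It matches both `Y @ X` (swapped operands) and `X.T @ Y.T` (transposed operands) but not `X @ Y`, so a follow-up test with two unrelated operands, where `Y @ X` and `X.T @ Y.T = (Y @ X).T` generally differ, would be needed to tell the two apart:

```python
import numpy as np

# Recompute A_sol as in the NumPy reference of the reproduction.
np.random.seed(0)
A = np.random.rand(3, 3)
L = np.array([[1.0, 0.0, 0.0],
              [1.0, 2.0, 0.0],
              [1.0, 2.0, 3.0]])
A_sol = np.linalg.solve(L, A.T).T  # same values as la.solve_triangular(L, A.T, lower=True).T

# C_tile values copied from the Warp output for wp.tile_matmul(A_sol, A_sol_T)
C_warp = np.array([[0.789577, 0.112, 0.0302839],
                   [0.112, 0.0621655, -0.0021659],
                   [0.0302839, -0.0021659, 0.00746645]])


def classify_matmul(result, X, Y, atol=1e-5):
    """Hypothetical helper: list the candidate products of X and Y that `result` matches."""
    candidates = {
        "X @ Y": X @ Y,          # what tile_matmul(X, Y) should return
        "Y @ X": Y @ X,          # operands swapped
        "X.T @ Y.T": X.T @ Y.T,  # both operands transposed
        "Y.T @ X.T": Y.T @ X.T,  # swapped and transposed, equals (X @ Y).T
    }
    return [name for name, value in candidates.items()
            if np.allclose(result, value, atol=atol)]


print(classify_matmul(C_warp, A_sol, A_sol.T))
```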

System Information

Warp 1.8.0.dev20250602 initialized:
   Git commit: e11e5afc2b1076b94f976d7b16642df3fafff52e
   CUDA Toolkit 12.8, Driver 12.7
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA RTX 2000 Ada Generation" (16 GiB, sm_89, mempool enabled)
   Kernel cache:
     /root/.cache/warp/1.8.0.dev20250602

Labels: bug (Something isn't working), tile
