-
Notifications
You must be signed in to change notification settings - Fork 349
Closed
Labels
Milestone
Description
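Add `wp.tile_squeeze()`, which behaves like `np.squeeze()`: it removes the length-one dimensions from a tile (e.g. a `(TILE_M, 1)` tile becomes a `(TILE_M,)` tile).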
Test
@wp.kernel
def test_tile_squeeze_kernel(
    x: wp.array2d(dtype=float),
    y: wp.array(dtype=float),
):
    # Read a (TILE_M, 1) tile from the top-left corner of the 2-D input;
    # its trailing axis has length one.
    column_tile = wp.tile_load(x, shape=(TILE_M, 1), offset=(0, 0))
    # Drop the singleton axis (np.squeeze-style) and write the resulting
    # tile into the 1-D output starting at index 0.
    wp.tile_store(y, wp.tile_squeeze(column_tile), offset=(0,))
# Drive test_tile_squeeze_kernel end to end.
# Squeeze a (TILE_M, 1) tile of ones into a (TILE_M,) output, backpropagate a
# unit output gradient, and verify both the forward and backward results.
# Without the checks, a broken tile_squeeze would pass silently.
import numpy as np

device = "cuda:0"

x = wp.ones((TILE_M, 1), dtype=float, device=device, requires_grad=True)
y = wp.zeros((TILE_M,), dtype=float, device=device, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch_tiled(test_tile_squeeze_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)

# Forward: squeezing a (TILE_M, 1) tile of ones should yield TILE_M ones.
np.testing.assert_allclose(y.numpy(), np.ones(TILE_M))

# Seed dL/dy = 1 and propagate it back through the recorded launch.
y.grad = wp.ones_like(y)
tape.backward()

# Backward: squeeze only removes a length-one axis, so the unit gradient is
# expected to pass through unchanged, shaped like x, i.e. (TILE_M, 1).
np.testing.assert_allclose(x.grad.numpy(), np.ones((TILE_M, 1)))