-
Notifications
You must be signed in to change notification settings - Fork 349
Closed
Labels
Description
Bug Description
I seem to run into a Warp NVRTC compilation error when passing non-contiguous tiles (the result of `wp.tile_broadcast`) into a `wp.func` whose parameter is annotated `tile: Any`.
from typing import Any
import warp as wp
import torch
# Problem size: the input/output buffers are DIM_BATCH x DIM_IN.
DIM_BATCH, DIM_IN = 2, 10
# Compile-time switch for the tile_broadcast branch in `kernel` (read via wp.static).
SHIFTED = True
# Tile shape registered as Warp compile-time constants: one row of DIM_IN elements per tile.
TILE_M = wp.constant(1)
TILE_K = wp.constant(DIM_IN)
# Identity helper: returns its generically typed (`Any`) tile argument unchanged.
# Per the issue report, passing the non-contiguous result of wp.tile_broadcast
# through this `Any`-typed parameter triggers the NVRTC compilation error.
@wp.func
def func2(tile: Any):
    return tile
# Tiled repro kernel: each launch index `i` loads a TILE_M x TILE_K tile of `x`
# starting at row i * TILE_M. `y` is passed as the output argument but is not
# written in this minimal reproduction.
@wp.kernel
def kernel(x: wp.array2d(dtype=wp.float32), y: wp.array2d(dtype=wp.float32)):
    i = wp.tid()
    a = wp.tile_load(x, shape=(TILE_M, TILE_K), offset=(i * TILE_M, 0))
    # The error occurs both with this conditional and when its body is moved
    # outside of it.
    if wp.static(SHIFTED):
        # Maximum of the tile, broadcast back to the full tile shape; this
        # broadcast (non-contiguous) tile is what gets passed to func2.
        a_max = wp.tile_broadcast(wp.tile_max(a), shape=(TILE_M, TILE_K))
        a_max2 = func2(a_max)
        a -= a_max2  # alternatively: a2 = a - a_max2
# Random input/output buffers created in torch and wrapped as Warp arrays;
# the launch below runs on whatever device `x` ends up on.
x = wp.from_torch(torch.rand((DIM_BATCH, DIM_IN)))
y = wp.from_torch(torch.rand((DIM_BATCH, DIM_IN)))

# One launch index per batch row; each index processes one tile in `kernel`.
kernel_dims = (DIM_BATCH,)
inputs = [x]
outputs = [y]
wp.launch_tiled(
    kernel,
    dim=kernel_dims,
    inputs=inputs,
    outputs=outputs,
    device=x.device,
    block_dim=64,
)
System Information
On Warp nightly 1.8.0.dev20250523 (df2df7d).