[BUG] Non-contiguous tile with Any typing in wp.func #749

@etaoxing

Bug Description

I run into a Warp NVRTC compilation error when passing a non-contiguous tile (the result of `wp.tile_broadcast`) into a `wp.func` whose parameter is annotated `tile: Any`.

Error log: tile_fn_error.log
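A NumPy analogy may help show what "non-contiguous" means here. It assumes that `wp.tile_broadcast` behaves like `np.broadcast_to`, returning a view whose broadcast axis has stride 0 instead of a packed copy; that is an assumption for illustration, not something taken from Warp's source. The names mirror the repro below, but everything runs on the host with NumPy only.

```python
import numpy as np

TILE_M, TILE_K = 1, 10  # same tile shape as the repro

# Stand-in for one tile loaded by wp.tile_load in the repro.
a = np.random.rand(TILE_M, TILE_K).astype(np.float32)

# Analogue of wp.tile_max: a single-element result.
a_max_1 = np.array([a.max()], dtype=np.float32)

# Analogue of wp.tile_broadcast: a read-only view, no data is copied.
a_max = np.broadcast_to(a_max_1, (TILE_M, TILE_K))
print("broadcast strides:", a_max.strides)           # the TILE_K axis has stride 0
print("C-contiguous:", a_max.flags["C_CONTIGUOUS"])   # False: not a packed layout

# Materializing the view produces an ordinary packed array.
a_max_packed = np.ascontiguousarray(a_max)
print("packed strides:", a_max_packed.strides)
print("C-contiguous:", a_max_packed.flags["C_CONTIGUOUS"])
```

In this analogy, every element along the broadcast axis aliases the same storage, which is the layout difference a generic (`Any`) function has to handle when it receives a broadcast tile.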

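```python
from typing import Any

import torch
import warp as wp

DIM_BATCH = 2
DIM_IN = 10
SHIFTED = True

TILE_M = wp.constant(1)
TILE_K = wp.constant(DIM_IN)


@wp.func
def func2(tile: Any):
    # Identity function with a generic (`Any`) tile parameter.
    return tile


@wp.kernel
def kernel(x: wp.array2d(dtype=wp.float32), y: wp.array2d(dtype=wp.float32)):
    i = wp.tid()
    # Load one (TILE_M, TILE_K) tile; each block of the tiled launch handles one tile.
    a = wp.tile_load(x, shape=(TILE_M, TILE_K), offset=(i * TILE_M, 0))
    if wp.static(SHIFTED):  # the error occurs both inside this conditional and when the code is moved outside it
        # Broadcasting the single-element max back to the tile shape gives a non-contiguous tile.
        a_max = wp.tile_broadcast(wp.tile_max(a), shape=(TILE_M, TILE_K))
        a_max2 = func2(a_max)  # pass the broadcast tile through the Any-typed func
        a -= a_max2  # or: a2 = a - a_max2


# NVRTC is used when compiling for CUDA devices, so allocate the inputs on the GPU.
x = torch.rand((DIM_BATCH, DIM_IN), device="cuda")
y = torch.rand((DIM_BATCH, DIM_IN), device="cuda")
x, y = wp.from_torch(x), wp.from_torch(y)

kernel_dims = (DIM_BATCH,)
inputs, outputs = [x], [y]

wp.launch_tiled(
    kernel,
    dim=kernel_dims,
    inputs=inputs,
    outputs=outputs,
    device=x.device,
    block_dim=64,
)
```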
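The kernel above never writes to `y`, so the repro only exercises compilation. For checking results once it compiles, a host-side NumPy reference of what the `SHIFTED` branch is meant to compute (subtracting each tile's maximum from that tile) may be useful. This is a sketch under assumptions: it supposes the shifted tile would be stored back into `y` (for example with `wp.tile_store`), which the original kernel does not do, and `shifted_reference` is a hypothetical helper, not part of Warp.

```python
import numpy as np

DIM_BATCH, DIM_IN = 2, 10
TILE_M, TILE_K = 1, DIM_IN  # same tile shape as the repro


def shifted_reference(x: np.ndarray) -> np.ndarray:
    """Hypothetical host-side reference: subtract each tile's max from that tile."""
    out = np.empty_like(x)
    for i in range(x.shape[0] // TILE_M):
        rows = slice(i * TILE_M, (i + 1) * TILE_M)
        # Same region as wp.tile_load(x, shape=(TILE_M, TILE_K), offset=(i * TILE_M, 0)).
        tile = x[rows, 0:TILE_K]
        # tile_max, broadcast to the tile shape, then subtract.
        out[rows, 0:TILE_K] = tile - tile.max()
    return out


x = np.random.rand(DIM_BATCH, DIM_IN).astype(np.float32)
y_ref = shifted_reference(x)
print(y_ref.max(axis=1))  # with TILE_M = 1, each row's maximum becomes 0 after the shift
```

With `TILE_M = 1` each tile is a single row, so the reference reduces to `x - x.max(axis=1, keepdims=True)`.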

System Information

Warp nightly 1.8.0.dev20250523 (df2df7d)

Labels: bug, tile

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions