Description
In many applications, certain kernels do not need to execute if a specific condition is met. For example, if there are no contacts in a physics simulator, all contact processing kernels can be skipped. Since graph capturing is typically used to launch many kernels with low overhead, it is difficult to dynamically adjust to situations like having no contacts. The current solution still involves launching all kernels, but they exit immediately, which is somewhat wasteful. Conditional CUDA graph nodes could enable completely skipping these kernels.
Context
The API could look as follows:
def capture_insert_if_else(condition: warp.array(dtype=int, ndim=1), stream: Stream, on_true, on_false, **kwargs)
def capture_insert_while(condition: warp.array(dtype=int, ndim=1), stream: Stream, while_body, **kwargs)
kwargs are the arguments required by the on_true/on_false and while_body functions/lambdas. condition is an array holding a single integer. If its value is zero, the condition is false; any other value means true.
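To make the intended semantics concrete, here is a minimal eager-mode sketch in plain Python. It is not Warp code and captures no graph: `eager_if_else`, `eager_while`, and `double_until_limit` are hypothetical names invented for this illustration, the condition is a one-element list standing in for the single-integer array, and the `stream` parameter is left out. In a captured graph the condition would be read on the device when the graph runs, not while it is being captured.

```python
def eager_if_else(condition, on_true, on_false, **kwargs):
    """Hypothetical eager stand-in for the proposed capture_insert_if_else."""
    # Non-zero means true, zero means false, as described above.
    if condition[0] != 0:
        on_true(**kwargs)
    else:
        on_false(**kwargs)


def eager_while(condition, while_body, **kwargs):
    """Hypothetical eager stand-in for the proposed capture_insert_while."""
    # The body is responsible for eventually writing 0 into the condition.
    while condition[0] != 0:
        while_body(**kwargs)


def double_until_limit(c, array, limit):
    """Example loop body: double the value and update the loop condition."""
    array[0] *= 2
    c[0] = 1 if array[0] < limit else 0


cond = [1]
data = [1]
eager_while(cond, double_until_limit, c=cond, array=data, limit=1000)
print(data[0])  # 1024
```

The keyword arguments are forwarded unchanged to whichever callable runs, which is why the loop body receives the condition array itself (`c=cond`) and can terminate the loop by writing to it.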
Usage could look as follows. All launch_-prefixed functions are ordinary Python functions that can launch as many Warp kernels as needed, and they can insert conditional graph nodes themselves:
wp.capture_begin(stream=stream)

testWhile = False
if testWhile:
    wp.capture_insert_while(
        condition=whileCondition,
        stream=stream,
        while_body=launch_multiply_by_two_until_limit,
        c=whileCondition,
        array=array,
        limit=1000,
        s=stream,
        device=device,
    )
else:
    launch_multiply_by_seven(array, s=stream, device=device)
    wp.capture_insert_if_else(
        condition=condition,
        stream=stream,
        on_true=lambda: wp.capture_insert_if_else(
            condition=condition2,
            stream=stream,
            on_true=launch_multiply_by_two,
            on_false=launch_multiply_by_thirteen,
            array=array,
            s=stream,
            device=device,
        ),
        on_false=lambda: wp.capture_insert_if_else(
            condition=condition2,
            stream=stream,
            on_true=launch_multiply_by_three,
            on_false=launch_multiply_by_eleven,
            array=array,
            s=stream,
            device=device,
        ),
    )
    launch_multiply_by_five(array, s=stream, device=device)

graph = wp.capture_end(stream=stream)
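As a rough illustration of what the nested if/else example would compute, the sketch below replays it eagerly in plain Python. Everything here is an assumption made for illustration: `if_else` is a hypothetical stand-in for the proposed `capture_insert_if_else`, the `launch_` functions are assumed to multiply every element by the factor in their name, and the array is a one-element list. Because each branch uses a different prime factor, the final value identifies which path ran.

```python
def if_else(condition, on_true, on_false, **kwargs):
    """Hypothetical eager stand-in for the proposed capture_insert_if_else."""
    (on_true if condition[0] != 0 else on_false)(**kwargs)


def make_multiplier(factor):
    """Build a stand-in for a launch_ function that scales the array."""
    def launch(array, s=None, device=None):
        array[0] *= factor
    return launch


launch_multiply_by_two = make_multiplier(2)
launch_multiply_by_three = make_multiplier(3)
launch_multiply_by_five = make_multiplier(5)
launch_multiply_by_seven = make_multiplier(7)
launch_multiply_by_eleven = make_multiplier(11)
launch_multiply_by_thirteen = make_multiplier(13)


def run(cond_value, cond2_value):
    """Replay the if/else branch of the usage example for given conditions."""
    array = [1]
    condition, condition2 = [cond_value], [cond2_value]
    launch_multiply_by_seven(array)
    if_else(
        condition,
        on_true=lambda: if_else(
            condition2, launch_multiply_by_two, launch_multiply_by_thirteen, array=array
        ),
        on_false=lambda: if_else(
            condition2, launch_multiply_by_three, launch_multiply_by_eleven, array=array
        ),
    )
    launch_multiply_by_five(array)
    return array[0]


results = {(c, c2): run(c, c2) for c in (1, 0) for c2 in (1, 0)}
print(results)  # {(1, 1): 70, (1, 0): 455, (0, 1): 105, (0, 0): 385}
```

With conditional graph nodes, a single captured graph would cover all four outcomes, and only the kernels on the selected path would execute at each launch.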