[REQ] Extend graph capture support to allow conditional branches #597

@nvtw

Description

In many applications, certain kernels do not need to execute if a specific condition is met. For example, if there are no contacts in a physics simulator, all contact processing kernels can be skipped. Since graph capturing is typically used to launch many kernels with low overhead, it is difficult to dynamically adjust to situations like having no contacts. The current solution still involves launching all kernels, but they exit immediately, which is somewhat wasteful. Conditional CUDA graph nodes could enable completely skipping these kernels.

Context

The API could look as follows:
def capture_insert_if_else(condition: warp.array(dtype=int, ndim=1), stream: Stream, on_true, on_false, **kwargs): ...
def capture_insert_while(condition: warp.array(dtype=int, ndim=1), stream: Stream, while_body, **kwargs): ...

kwargs are the arguments required by the on_true/on_false and while_body functions/lambdas. condition is an integer array holding a single element. If its value is zero, the condition is false; otherwise it is true.
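The intended semantics can be modelled on the host in plain Python. The sketch below is a reference model under stated assumptions, not Warp code: eager_if_else, eager_while and multiply_by_two_until_limit are hypothetical names, a one-element list stands in for the integer condition array, and the branches run immediately instead of being recorded as conditional graph nodes. The stream parameter is left out because nothing is captured.

```python
from typing import Any, Callable, MutableSequence


def eager_if_else(
    condition: MutableSequence[int],
    on_true: Callable[..., Any],
    on_false: Callable[..., Any],
    **kwargs: Any,
) -> None:
    """Run on_true if condition[0] is nonzero, otherwise on_false."""
    branch = on_true if condition[0] != 0 else on_false
    # kwargs are forwarded unchanged to whichever branch is selected.
    branch(**kwargs)


def eager_while(
    condition: MutableSequence[int],
    while_body: Callable[..., Any],
    **kwargs: Any,
) -> None:
    """Call while_body repeatedly until it sets condition[0] to zero."""
    while condition[0] != 0:
        while_body(**kwargs)


def multiply_by_two_until_limit(c, array, limit):
    """Hypothetical loop body: double every element, then update the flag."""
    for i in range(len(array)):
        array[i] *= 2
    # Keep looping only while every element is still below the limit.
    c[0] = 1 if max(array) < limit else 0


flag = [1]
data = [1, 3]
eager_while(flag, multiply_by_two_until_limit, c=flag, array=data, limit=1000)
print(data)  # [512, 1536] after nine doublings
```

As in the proposal, the loop body is responsible for updating the condition; a body that never clears it would loop forever.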

Usage could look as follows. All launch_-prefixed functions are ordinary Python functions that can launch as many Warp kernels as needed and can themselves insert conditional graph nodes:

wp.capture_begin(stream=stream)
testWhile = False
if testWhile:
	wp.capture_insert_while(
		condition=whileCondition,
		stream=stream,
		while_body=launch_multiply_by_two_until_limit,
		c=whileCondition,
		array=array,
		limit=1000,
		s=stream,
		device=device,
	)
else:
	launch_multiply_by_seven(array, s=stream, device=device)
	wp.capture_insert_if_else(
		condition=condition,
		stream=stream,
		on_true=lambda: wp.capture_insert_if_else(
			condition=condition2,
			stream=stream,
			on_true=launch_multiply_by_two,
			on_false=launch_multiply_by_thirteen,
			array=array,
			s=stream,
			device=device,
		),
		on_false=lambda: wp.capture_insert_if_else(
			condition=condition2,
			stream=stream,
			on_true=launch_multiply_by_three,
			on_false=launch_multiply_by_eleven,
			array=array,
			s=stream,
			device=device,
		),
	)
	launch_multiply_by_five(array, s=stream, device=device)
graph = wp.capture_end(stream=stream)
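To make the expected behaviour of the nested if/else path concrete, the following host-side sketch computes the overall factor applied to each array element when testWhile is False. It assumes that each launch_multiply_by_N function multiplies every element by N, which the names suggest but the issue does not state; expected_scale is a hypothetical helper used only for this illustration.

```python
def expected_scale(condition: int, condition2: int) -> int:
    """Overall factor per element on the testWhile = False path."""
    if condition != 0:
        # Outer on_true branch: inner if/else picks 2 or 13.
        inner = 2 if condition2 != 0 else 13
    else:
        # Outer on_false branch: inner if/else picks 3 or 11.
        inner = 3 if condition2 != 0 else 11
    # multiply_by_seven runs before the conditional, multiply_by_five after it.
    return 7 * inner * 5


for c in (1, 0):
    for c2 in (1, 0):
        print(c, c2, expected_scale(c, c2))
```

Because every branch uses a distinct prime factor, the final array values identify which path the graph actually took, which makes this layout convenient as a test case.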
