
[BUG] Gradient arrays in nested structs are not reset by tape.zero() #807

@gdaviet


Bug Description

`tape.zero()` should zero out the gradients of every gradient array that was accumulated into during the backward pass. It currently does this correctly for arrays that are top-level fields of a `wp.struct`, but gradient arrays that are fields of a nested struct are not reset.

```python
import warp as wp


@wp.struct
class B:
    arr_doubly_nested: wp.array(dtype=float)


@wp.struct
class A:
    b: B
    arr_nested: wp.array(dtype=float)


@wp.kernel
def fn(a: A, loss: wp.array(dtype=float)):
    i = wp.tid()
    wp.atomic_add(loss, 0, a.arr_nested[i])
    wp.atomic_add(loss, 0, a.b.arr_doubly_nested[i])


a = A()
a.b.arr_doubly_nested = wp.ones(shape=(1,), dtype=float, requires_grad=True)
a.arr_nested = wp.ones(shape=(1,), dtype=float, requires_grad=True)

loss = wp.zeros(shape=(1,), dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch(fn, dim=1, inputs=(a, loss))

print(loss)  # [2]

tape.backward(loss)
# both grads are correct
print(a.arr_nested.grad)  # [1]
print(a.b.arr_doubly_nested.grad)  # [1]

tape.zero()
# Both grad arrays *should* be reset, but only a.arr_nested is
print(a.arr_nested.grad)  # [0]
print(a.b.arr_doubly_nested.grad)  # [1]
```
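The behavior above is consistent with `tape.zero()` inspecting only the top-level fields of a struct argument. It does not appear to recurse into fields that are themselves structs. The pure-Python sketch below contrasts a flat traversal with a recursive one. It does not use Warp's actual internals. `FakeArray`, `zero_grads_flat`, and `zero_grads_recursive` are hypothetical stand-ins, and Python dataclasses play the role of `wp.struct`.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import List, Optional


# Hypothetical stand-in for a differentiable array: values plus an optional grad buffer.
@dataclass
class FakeArray:
    data: List[float]
    grad: Optional[List[float]] = None


def zero_grads_flat(struct):
    """Mimics the reported behavior: only top-level array fields are reset."""
    for f in fields(struct):
        value = getattr(struct, f.name)
        if isinstance(value, FakeArray) and value.grad is not None:
            value.grad[:] = [0.0] * len(value.grad)


def zero_grads_recursive(struct):
    """Descends into nested struct fields so every reachable grad buffer is reset."""
    for f in fields(struct):
        value = getattr(struct, f.name)
        if isinstance(value, FakeArray):  # check arrays first: FakeArray is itself a dataclass
            if value.grad is not None:
                value.grad[:] = [0.0] * len(value.grad)
        elif is_dataclass(value):  # a nested "struct": recurse into its fields
            zero_grads_recursive(value)


# Hypothetical mirrors of the structs B and A from the reproducer.
@dataclass
class B:
    arr_doubly_nested: FakeArray


@dataclass
class A:
    b: B
    arr_nested: FakeArray


a = A(b=B(FakeArray([1.0], grad=[1.0])), arr_nested=FakeArray([1.0], grad=[1.0]))
zero_grads_flat(a)
print(a.arr_nested.grad, a.b.arr_doubly_nested.grad)  # [0.0] [1.0]
zero_grads_recursive(a)
print(a.arr_nested.grad, a.b.arr_doubly_nested.grad)  # [0.0] [0.0]
```

The flat version reproduces the symptom in the issue: the doubly nested gradient survives. The recursive version resets both. A real fix would presumably apply the same recursion wherever Warp collects gradient arrays from struct launch arguments.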

System Information

No response


Labels: bug
