-
Notifications
You must be signed in to change notification settings - Fork 349
Closed
Description
Bug Description
`tape.zero()` should zero out the gradients of all arrays that were accumulated into during the backward pass. It currently handles arrays passed as top-level fields of a `wp.struct` correctly, but gradient arrays that are fields of a nested struct are not reset.
import warp as wp
# Inner struct of the repro: its array field becomes doubly nested when B is
# embedded in A (accessed as a.b.arr_doubly_nested).
@wp.struct
class B:
    arr_doubly_nested: wp.array(dtype=float)
# Outer struct of the repro: one array field directly on A (arr_nested) and a
# nested struct B carrying a second array.
@wp.struct
class A:
    b: B
    arr_nested: wp.array(dtype=float)
# Accumulate element i of both struct arrays into loss[0], so that a backward
# pass through this kernel writes a gradient into each of the two arrays.
@wp.kernel
def fn(a: A, loss: wp.array(dtype=float)):
    i = wp.tid()
    wp.atomic_add(loss, 0, a.arr_nested[i])
    wp.atomic_add(loss, 0, a.b.arr_doubly_nested[i])
# Set up gradient-tracking arrays at two depths: one directly on A and one on
# the nested struct A.b.
a = A()
a.b.arr_doubly_nested = wp.ones(shape=(1,), dtype=float, requires_grad=True)
a.arr_nested = wp.ones(shape=(1,), dtype=float, requires_grad=True)
loss = wp.zeros(shape=(1,), dtype=float, requires_grad=True)

# Record the forward launch on the tape so tape.backward() can replay it.
tape = wp.Tape()
with tape:
    wp.launch(fn, dim=1, inputs=(a, loss))
print(loss)  # [2]

tape.backward(loss)
# Both gradients are correct after the backward pass.
print(a.arr_nested.grad)  # [1]
print(a.b.arr_doubly_nested.grad)  # [1]

tape.zero()
# Both gradient arrays *should* be reset here, but only a.arr_nested is;
# the gradient of the array on the nested struct keeps its value.
print(a.arr_nested.grad)  # [0]
print(a.b.arr_doubly_nested.grad)  # [1]
System Information
No response
Metadata
Metadata
Assignees
Labels
bug — Something isn't working