put_along_axis with reduce='mul' gives wrong results: correct on CPU, wrong on GPU #52446

@JunnYu

Describe the Bug

The multiply reduction is not implemented here; the code only has a TODO:

// TODO(huangxu96) platform::CudaAtomicMul(*self_data, *src_data);

import paddle
def scatter_reduce(input: paddle.Tensor, 
                   axis: int, 
                   index: paddle.Tensor, 
                   src: paddle.Tensor, 
                   reduce: str) -> paddle.Tensor:
    # Handled here: "sum", "prod", "mean". "amax" and "amin" are not implemented.
    if reduce == "sum":
        input.put_along_axis_(indices=index, values=src, axis=axis, reduce="add")
    elif reduce == "mean":
        input.put_along_axis_(indices=index, values=src, axis=axis, reduce="add")
        dst_div = (
            paddle.ones_like(input)
            .put_along_axis(
                indices=index, values=paddle.to_tensor(1.0, dtype=input.dtype), axis=axis, reduce="add"
            )
        )
        input = input / dst_div
    elif reduce == "prod":
        input = input.put_along_axis(indices=index, values=src, axis=axis, reduce="mul")
    else:
        raise NotImplementedError("only support mode in ['sum', 'prod', 'mean']!")
    return input
paddle.set_device('cpu')
src = paddle.to_tensor([1., 2., 3., 4., 5., 6.])
index = paddle.to_tensor([0, 1, 0, 1, 2, 1])
input = paddle.to_tensor([1., 2., 3., 4.])
out = scatter_reduce(input, 0, index, src, reduce="prod")
print(out)
# Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
#        [3. , 96., 15., 4. ])

paddle.set_device('gpu')
src = paddle.to_tensor([1., 2., 3., 4., 5., 6.])
index = paddle.to_tensor([0, 1, 0, 1, 2, 1])
input = paddle.to_tensor([1., 2., 3., 4.])
out = scatter_reduce(input, 0, index, src, reduce="prod")
print(out)
# Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [1. , 4. , 15., 4. ])
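
For comparison, the expected values can be checked without Paddle at all. The following is only a minimal sketch: `scatter_prod_reference` is a hypothetical helper name, not part of Paddle, and it assumes the intended semantics are "multiply every `src[i]` into `input[index[i]]`, keeping the original value of `input`". It uses NumPy's unbuffered `np.multiply.at`, which applies repeated indices one after another instead of dropping them.

```python
import numpy as np


def scatter_prod_reference(inp, index, src):
    """Hypothetical reference: out[index[i]] *= src[i] for every i."""
    out = np.array(inp, dtype=np.float32)  # copy, so the caller's data is untouched
    # ufunc.at is unbuffered: duplicate indices accumulate instead of overwriting.
    np.multiply.at(out, np.asarray(index), np.asarray(src, dtype=np.float32))
    return out


expected = scatter_prod_reference(
    [1., 2., 3., 4.],
    [0, 1, 0, 1, 2, 1],
    [1., 2., 3., 4., 5., 6.],
)
print(expected)  # values: 3, 96, 15, 4
```

These values agree with the CPU output above (index 0: 1*1*3, index 1: 2*2*4*6, index 2: 3*5, index 3 untouched), which suggests the GPU result is the incorrect one.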

Additional Supplementary Information

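It would also be good to add native support for reduce='mean'; at the moment it has to be composed from several operators, as in the scatter_reduce function above.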
