
@fangfangssj (Contributor) commented Apr 10, 2025

PR Category

User Experience

PR Types

Bug fixes

Description

Pcard-75624
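Prerequisite PR for adding complex-number support to the Pow operator.

The real-valued SVD backward pass follows https://j-towns.github.io/papers/svd-derivative.pdf
The complex-valued SVD backward pass follows https://arxiv.org/pdf/1909.02659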
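For reference, the real-valued formula from the Townsend note can be sketched in NumPy and checked against central finite differences. This is a minimal sketch, not Paddle's kernel: `svd_backward_real`, `loss`, `analytic_grad` and `numeric_grad` are hypothetical names, and the sketch assumes a real input, the thin SVD (`full_matrices=False`), and distinct, nonzero singular values. The loss mirrors the `U+S+V` case of the comparison script in this thread.

```python
import numpy as np


def svd_backward_real(U, S, Vh, gU, gS, gV):
    """Gradient of a scalar loss w.r.t. A, where A = U @ diag(S) @ Vh is the
    thin SVD and gU, gS, gV are the loss gradients w.r.t. U, S and V = Vh.T.
    Assumes distinct, nonzero singular values (Townsend's formula)."""
    m, k = U.shape
    n = Vh.shape[1]
    V = Vh.T
    S_mat, S_inv = np.diag(S), np.diag(1.0 / S)

    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal, 0 on the diagonal.
    diff = S[None, :] ** 2 - S[:, None] ** 2
    np.fill_diagonal(diff, np.inf)
    F = 1.0 / diff

    # Contribution of the singular values: U diag(gS) V^T.
    grad = U @ np.diag(gS) @ Vh

    # Contribution of U, plus a projection term when A is tall (m > k).
    u_term = U @ (F * (U.T @ gU - gU.T @ U)) @ S_mat
    if m > k:
        u_term += (np.eye(m) - U @ U.T) @ gU @ S_inv
    grad += u_term @ Vh

    # Contribution of V, plus a projection term when A is wide (n > k).
    v_term = S_mat @ (F * (V.T @ gV - gV.T @ V)) @ Vh
    if n > k:
        v_term += S_inv @ gV.T @ (np.eye(n) - V @ V.T)
    grad += U @ v_term
    return grad


def loss(A):
    # "U+S+V" loss: invariant to the sign ambiguity of real singular vectors.
    U, S, Vh = np.linalg.svd(A, full_matrices=False)
    return np.abs(U).sum() + np.abs(S).sum() + np.abs(Vh).sum()


def analytic_grad(A):
    U, S, Vh = np.linalg.svd(A, full_matrices=False)
    # d|x|/dx = sign(x); S >= 0, so its gradient is all ones.
    return svd_backward_real(U, S, Vh, np.sign(U), np.sign(S), np.sign(Vh.T))


def numeric_grad(A, eps=1e-6):
    # Central finite differences, one entry at a time.
    G = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        step = np.zeros_like(A)
        step[idx] = eps
        G[idx] = (loss(A + step) - loss(A - step)) / (2 * eps)
    return G


rng = np.random.default_rng(0)
for shape in [(5, 3), (3, 5), (4, 4)]:
    A = rng.standard_normal(shape)
    print(shape, np.max(np.abs(analytic_grad(A) - numeric_grad(A))))
```

The `(I - U U^T)` and `(I - V V^T)` projection terms only appear for non-square inputs, which is why the check covers tall, wide and square shapes.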


paddle-bot bot commented Apr 10, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@fangfangssj (Contributor, Author) commented Apr 10, 2025

For the SVD backward pass, I ran comparison tests against torch (2.6.0), and the results are aligned with torch's implementation:

```python
import numpy as np
import paddle
import torch

# Matrix shapes to test (tall, wide, and small cases)
shapes = [(30, 20), (100, 200), (5, 400), (20, 5), (5, 2), (100, 5)]

# Run PyTorch on the GPU when available, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Loss combinations to test. Every loss goes through abs(), so it does not depend
# on the arbitrary phase of the complex singular vectors. Both frameworks return
# Vh (the conjugate transpose of V) as the third output; abs() is unaffected.
loss_functions = {
    'U': {
        'torch': lambda u, s, v: torch.sum(torch.abs(u)),
        'paddle': lambda u, s, v: paddle.sum(paddle.abs(u))
    },
    'S': {
        'torch': lambda u, s, v: torch.sum(torch.abs(s)),
        'paddle': lambda u, s, v: paddle.sum(paddle.abs(s))
    },
    'V': {
        'torch': lambda u, s, v: torch.sum(torch.abs(v)),
        'paddle': lambda u, s, v: paddle.sum(paddle.abs(v))
    },
    'U+V': {
        'torch': lambda u, s, v: torch.sum(torch.abs(u)) + torch.sum(torch.abs(v)),
        'paddle': lambda u, s, v: paddle.sum(paddle.abs(u)) + paddle.sum(paddle.abs(v))
    },
    'U+S': {
        'torch': lambda u, s, v: torch.sum(torch.abs(u)) + torch.sum(torch.abs(s)),
        'paddle': lambda u, s, v: paddle.sum(paddle.abs(u)) + paddle.sum(paddle.abs(s))
    },
    'S+V': {
        'torch': lambda u, s, v: torch.sum(torch.abs(s)) + torch.sum(torch.abs(v)),
        'paddle': lambda u, s, v: paddle.sum(paddle.abs(s)) + paddle.sum(paddle.abs(v))
    },
    'U+S+V': {
        'torch': lambda u, s, v: torch.sum(torch.abs(u)) + torch.sum(torch.abs(s)) + torch.sum(torch.abs(v)),
        'paddle': lambda u, s, v: paddle.sum(paddle.abs(u)) + paddle.sum(paddle.abs(s)) + paddle.sum(paddle.abs(v))
    }
}

for shape in shapes:
    # ===================== 1. Generate random complex data =====================
    np.random.seed(42)  # fixed random seed
    real_part = np.random.rand(*shape).astype(np.float64)
    imag_part = np.random.rand(*shape).astype(np.float64)
    np_complex_array = real_part + 1j * imag_part  # complex128 matrix

    for loss_name, funcs in loss_functions.items():
        # ===================== 2. PyTorch computation =====================
        # Convert to a PyTorch complex tensor (complex128)
        torch_complex_tensor = torch.from_numpy(np_complex_array).to(device).requires_grad_(True)

        # SVD forward pass (complex input)
        U_torch, S_torch, V_torch = torch.linalg.svd(torch_complex_tensor, full_matrices=False)

        # Compute the loss
        loss_torch = funcs['torch'](U_torch, S_torch, V_torch)

        # Backward pass
        loss_torch.backward()
        torch_grad = torch_complex_tensor.grad.cpu().numpy()

        # ===================== 3. PaddlePaddle computation =====================
        # Convert to a Paddle complex tensor (paddle.complex128, matching the NumPy dtype)
        paddle_complex_tensor = paddle.to_tensor(np_complex_array, stop_gradient=False)

        # SVD forward pass (complex input)
        U_paddle, S_paddle, V_paddle = paddle.linalg.svd(paddle_complex_tensor, full_matrices=False)

        # Compute the loss
        loss_paddle = funcs['paddle'](U_paddle, S_paddle, V_paddle)

        # Backward pass
        loss_paddle.backward()
        paddle_grad = paddle_complex_tensor.grad.numpy()

        # Compare the gradients
        are_close = np.allclose(torch_grad, paddle_grad, atol=1e-5, rtol=1e-5)
        print(f"Gradients match (shape {shape}, loss {loss_name}):", are_close)

        # If they do not match, print the maximum difference
        if not are_close:
            max_diff = np.max(np.abs(torch_grad - paddle_grad))
            print(f"Max difference: {max_diff}")
```
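All the losses above go through `abs()`. A short NumPy sketch of why: for complex input, each pair of singular vectors is only determined up to a shared phase, so the gradient is only well-defined for losses that are invariant to that phase (recent PyTorch versions check this in the SVD backward and raise an error otherwise). The random phases below are purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(42)
A = rng.random((5, 3)) + 1j * rng.random((5, 3))
U, S, Vh = np.linalg.svd(A, full_matrices=False)

# Multiply column j of U by e^{i*theta_j} and row j of Vh by the conjugate phase.
theta = rng.uniform(0, 2 * np.pi, size=S.shape)
phase = np.exp(1j * theta)
U2 = U * phase                      # scales each column of U
Vh2 = phase.conj()[:, None] * Vh    # scales each row of Vh

# Both factorizations reconstruct A, since |e^{i*theta}| = 1 ...
recon1 = (U * S) @ Vh
recon2 = (U2 * S) @ Vh2
print(np.allclose(recon1, A), np.allclose(recon2, A))

# ... an abs()-based loss gives the same value for both ...
print(np.isclose(np.abs(U).sum(), np.abs(U2).sum()))

# ... but a phase-dependent loss such as sum(real(U)) generally does not.
print(np.isclose(U.real.sum(), U2.real.sum()))
```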


paddle-ci-bot bot commented Apr 21, 2025

Sorry to inform you that the CI runs for 5a523ce passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

Comment on lines 110 to 112
U->Resize(U->dims());
S->Resize(S->dims());
VH->Resize(VH->dims());
Contributor

The sizes of U, S, and VH should already be set in InferMeta, so perhaps there is no need to resize them again here?

Contributor Author

Fixed.

@HydrogenSulfate (Contributor)

@fangfangssj It looks like the coverage check hasn't passed.

@fangfangssj (Contributor, Author)

This part concerns the complex-number path. I didn't write dedicated unit tests for it, but its results have been aligned with torch.

@HydrogenSulfate (Contributor)

okok

luotao1 previously approved these changes Apr 23, 2025

@luotao1 (Contributor) left a comment

LGTM for skipif

@HydrogenSulfate (Contributor)

@fangfangssj Please add a unit test that calls the backward pass, purely as a functional test, so that the code coverage requirement is met.
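A minimal sketch of such a functional test, assuming a `unittest`-style test file like Paddle's other operator tests. The class name is hypothetical, and the test only checks that backward runs on complex input and returns a finite complex gradient of the right shape, without comparing values. The `skipIf` guard here only handles a missing Paddle installation; the skip condition used in the actual PR is not shown in this thread.

```python
import unittest

import numpy as np

try:
    import paddle
except ImportError:  # lets the sketch load where Paddle is not installed
    paddle = None


@unittest.skipIf(paddle is None, "PaddlePaddle is not installed")
class TestSvdComplexBackwardSmoke(unittest.TestCase):  # hypothetical name
    def _run(self, shape):
        rng = np.random.default_rng(0)
        x_np = rng.random(shape) + 1j * rng.random(shape)  # complex128 input
        x = paddle.to_tensor(x_np, stop_gradient=False)
        u, s, vh = paddle.linalg.svd(x, full_matrices=False)
        # Phase-invariant loss, as in the torch comparison script.
        loss = paddle.abs(u).sum() + s.sum() + paddle.abs(vh).sum()
        loss.backward()
        grad = x.grad.numpy()
        # Functional checks only: shape, dtype kind, and finiteness.
        self.assertEqual(grad.shape, shape)
        self.assertTrue(np.iscomplexobj(grad))
        self.assertTrue(np.all(np.isfinite(grad)))

    def test_tall(self):
        self._run((6, 4))

    def test_wide(self):
        self._run((4, 6))
```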

@codecov-commenter

Codecov Report

Attention: Patch coverage is 99.25373% with 1 line in your changes missing coverage. Please review.

Please upload report for BASE (develop@441816a).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| python/paddle/tensor/linalg.py | 0.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #72169   +/-   ##
==========================================
  Coverage           ?   99.25%           
==========================================
  Files              ?        4           
  Lines              ?      134           
  Branches           ?        0           
==========================================
  Hits               ?      133           
  Misses             ?        1           
  Partials           ?        0           


@HydrogenSulfate HydrogenSulfate merged commit 5e86325 into PaddlePaddle:develop Apr 30, 2025
40 of 42 checks passed
YqGe585 pushed a commit to YqGe585/Paddle that referenced this pull request May 7, 2025
…ddle#72169)

* support complex

* fix

* fix

* fix

* fix

* fix ci

* rerun ci

* fix

* add test

* fix ci

---------

Co-authored-by: fangfangssj <fangfangssj@qq.com>
@fangfangssj fangfangssj deleted the svd branch May 15, 2025 07:18
4 participants