Skip to content

Paddle 的 to_tensor()方法和 torch 的 tensor()方法的行为不一致 #72484

@YoctoHan

Description

@YoctoHan

问题背景

在进行 Paddle 和 PyTorch 的精度对齐的过程中,为了模拟某些计算过程,需要使用随机生成的张量进行实验,为了消除随机性,采用 numpy 充当中间件,指定 seed 后生成随机 array ,再将其分别转换为两个框架的 tensor ,在此过程中发现了一个 Paddle 和 PyTorch 行为不一致的现象,现场如下:

Paddle 代码

import paddle
import numpy as np

# Function to save tensor as numpy array
def save_tensor_to_numpy(tensor_data, file_path):
    """Write a Paddle tensor to ``file_path`` as a float32 NumPy array.

    The tensor is cast to float32, detached, copied to the CPU and converted
    with ``.numpy()``; the resulting array is then stored with ``np.save``.
    """
    # Cast first so the saved file has one dtype regardless of the input dtype.
    as_float32 = tensor_data.astype(paddle.float32)
    host_array = as_float32.detach().cpu().numpy()

    np.save(file_path, host_array)


# Fix the NumPy seed so the random array is reproducible across runs.
np.random.seed(12333)
# Draw a (1024, 6144) array from a normal distribution (mean 0.49, std 0.0833).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: build the tensor without passing a dtype, move it to the GPU,
# and only then cast it to bfloat16.
A_tensor_1 = paddle.to_tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.astype(paddle.bfloat16)

# Path 2: ask to_tensor() for bfloat16 directly, then move it to the GPU.
A_tensor_2 = paddle.to_tensor(A_np, dtype=paddle.bfloat16).cuda()

# Store both results (cast to float32 by the helper) for the offline comparison.
save_tensor_to_numpy(A_tensor_1, 'A_matrix_origin_1.npy')
save_tensor_to_numpy(A_tensor_2, 'A_matrix_origin_2.npy')

Pytorch 代码

import torch
import numpy as np

# Function to save tensor as numpy array
def save_tensor_to_numpy(tensor_data, file_path):
    """Write a PyTorch tensor to ``file_path`` as a float32 NumPy array.

    The tensor is converted to float32, detached from the autograd graph,
    copied to the CPU and turned into a NumPy array, which is then stored
    with ``np.save``.
    """
    # Convert first so the saved file has one dtype regardless of the input dtype.
    as_float32 = tensor_data.to(torch.float32)
    host_array = as_float32.detach().cpu().numpy()

    np.save(file_path, host_array)


# Fix the NumPy seed so the random array is reproducible across runs.
np.random.seed(12333)
# Draw a (1024, 6144) array from a normal distribution (mean 0.49, std 0.0833).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: build the tensor without passing a dtype, move it to the GPU,
# and only then convert it to bfloat16.
A_tensor_1 = torch.tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.to(torch.bfloat16)

# Path 2: ask torch.tensor() for bfloat16 directly, then move it to the GPU.
A_tensor_2 = torch.tensor(A_np, dtype=torch.bfloat16).cuda()

# Store both results (converted to float32 by the helper) for the offline comparison.
save_tensor_to_numpy(A_tensor_1, 'A_matrix_origin_1.npy')
save_tensor_to_numpy(A_tensor_2, 'A_matrix_origin_2.npy')

对比脚本

def compare(torch_tensor: np.ndarray, paddle_tensor: np.ndarray) -> dict:
    """Summarise the difference between two same-shaped NumPy arrays.

    Args:
        torch_tensor: Reference array; relative errors are measured against it.
        paddle_tensor: Array compared with ``torch_tensor``.

    Returns:
        A dict with the keys ``torch_mean``, ``paddle_mean``,
        ``torch_variance``, ``paddle_variance``, ``mae`` (mean absolute
        error) and ``mae%`` (mean relative error in percent, taken only over
        positions where both inputs exceed 1e-6 in magnitude).

    Raises:
        AssertionError: If the two arrays differ in dtype or shape.
    """
    # Preconditions: the element-wise maths below needs matching dtype/shape.
    assert torch_tensor.dtype == paddle_tensor.dtype, \
        f"Data type mismatch: torch_tensor dtype={torch_tensor.dtype}, paddle_tensor dtype= {paddle_tensor.dtype}"
    assert torch_tensor.shape == paddle_tensor.shape, \
        f"Shape mismatch: torch_tensor shape={torch_tensor.shape}, paddle_tensor shape= {paddle_tensor.shape}"

    # Absolute error, computed on the squeezed views of both inputs.
    squeezed_gap = np.abs(torch_tensor.squeeze() - paddle_tensor.squeeze())
    mae = np.mean(squeezed_gap)

    # Relative error is only meaningful where neither value is (almost) zero.
    torch_magnitude = np.abs(torch_tensor)
    paddle_magnitude = np.abs(paddle_tensor)
    mask = (torch_magnitude > 1e-6) & (paddle_magnitude > 1e-6)
    masked_gap = np.abs(torch_tensor[mask] - paddle_tensor[mask])
    # The extra 1e-8 in the denominator is a belt-and-braces zero guard.
    relative_gap = masked_gap / (torch_magnitude[mask] + 1e-8)
    mae_percentage = np.mean(relative_gap) * 100

    summary = {}
    summary["torch_mean"] = np.mean(torch_tensor)
    summary["paddle_mean"] = np.mean(paddle_tensor)
    summary["torch_variance"] = np.var(torch_tensor)
    summary["paddle_variance"] = np.var(paddle_tensor)
    summary["mae"] = mae
    summary["mae%"] = mae_percentage
    return summary

实验结果

| Matrix | Torch Mean | Paddle Mean | Torch Variance | Paddle Variance | MAE | MAE% |
| --- | --- | --- | --- | --- | --- | --- |
| A_matrix_origin_1 | 0.49004754 | 0.49004754 | 0.0069367127 | 0.0069367127 | 0.0 | 0.0 |
| A_matrix_origin_2 | 0.49004754 | 0.4886367 | 0.0069367127 | 0.006872112 | 0.001410692 | 0.282440148293972 |

结论

上述代码中的 A_np 是 float64 类型的 numpy.array,使用两种不同的方式将其转为 tensor:第一种方式是先转为 float64 类型的 tensor,再将其转为目标数据类型 bfloat16;第二种方式是直接转为目标数据类型 bfloat16 的 tensor。PyTorch 的两种方式结果保持一致,但是 Paddle 的两种方式结果不一致,且对比分析表明应该是第二种转换方式存在精度误差。

原因分析

之前怀疑是由于 Paddle 这头的张量的 Place 不同导致的表现不一致,但是排查之后 Place 应该一直是一致的,复现场景如下:

Paddle精度分析

import paddle
import numpy as np

def compare(tensor_1: paddle.Tensor, tensor_2: paddle.Tensor) -> dict:
    """Summarise the difference between two Paddle tensors.

    Args:
        tensor_1: Reference tensor; relative errors are measured against it.
        tensor_2: Tensor compared with ``tensor_1``.

    Returns:
        A dict with the keys ``tensor_1_mean``, ``tensor_2_mean``,
        ``tensor_1_variance``, ``tensor_2_variance``, ``mae`` (mean absolute
        error) and ``mae%`` (mean relative error in percent, taken only over
        positions where both inputs exceed 1e-6 in magnitude). The values are
        the results of the corresponding Paddle reductions.

    Raises:
        AssertionError: If the tensors differ in dtype, shape or place.
    """
    # Preconditions. The messages name the operands after the parameters;
    # both are Paddle tensors, so framework-specific labels would mislead.
    assert tensor_1.dtype == tensor_2.dtype, \
        f"Data type mismatch: tensor_1 dtype={tensor_1.dtype}, tensor_2 dtype={tensor_2.dtype}"
    assert tensor_1.shape == tensor_2.shape, \
        f"Shape mismatch: tensor_1 shape={tensor_1.shape}, tensor_2 shape={tensor_2.shape}"
    assert tensor_1.place._equals(tensor_2.place), \
        f"Device mismatch: tensor_1 place={tensor_1.place}, tensor_2 place={tensor_2.place}"

    # Calculate mean and variance for both tensors
    tensor_1_mean = paddle.mean(tensor_1)
    tensor_2_mean = paddle.mean(tensor_2)
    tensor_1_variance = paddle.var(tensor_1)
    tensor_2_variance = paddle.var(tensor_2)

    # Calculate Mean Absolute Error (MAE)
    mae = paddle.mean(paddle.abs(tensor_1.squeeze() - tensor_2.squeeze()))
    # Relative error is only meaningful where neither value is (almost) zero.
    mask = (paddle.abs(tensor_1) > 1e-6) & (paddle.abs(tensor_2) > 1e-6)
    # The extra 1e-8 in the denominator is a belt-and-braces zero guard.
    mae_percentage = paddle.mean(paddle.abs(tensor_1[mask] - tensor_2[mask]) / (paddle.abs(tensor_1[mask]) + 1e-8)) * 100

    return {
        "tensor_1_mean": tensor_1_mean,
        "tensor_2_mean": tensor_2_mean,
        "tensor_1_variance": tensor_1_variance,
        "tensor_2_variance": tensor_2_variance,
        "mae": mae,
        "mae%": mae_percentage,
    }

# Fix the NumPy seed so the random array is reproducible across runs.
np.random.seed(12333)
# Draw a (1024, 6144) array from a normal distribution (mean 0.49, std 0.0833).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: build the tensor without passing a dtype, move it to the GPU,
# and only then cast it to bfloat16.
A_tensor_1 = paddle.to_tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.astype(paddle.bfloat16)

# Path 2: ask to_tensor() for bfloat16 directly, then move it to the GPU.
A_tensor_2 = paddle.to_tensor(A_np, dtype=paddle.bfloat16).cuda()

# Compare the two conversion paths inside Paddle itself; compare() also
# asserts that both tensors report the same place.
print(compare(A_tensor_1, A_tensor_2))

输出为:

{
    'tensor_1_mean': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.49023438), 
    'tensor_2_mean': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.48828125), 
    'tensor_1_variance': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.00692749), 
    'tensor_2_variance': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.00686646), 
    'mae': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.00141144), 
    'mae%': Tensor(shape=[], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True, 0.28320312)
}

PyTorch 精度分析

import torch
import numpy as np

def compare(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> dict:
    """Summarise the difference between two PyTorch tensors.

    Args:
        tensor_1: Reference tensor; relative errors are measured against it.
        tensor_2: Tensor compared with ``tensor_1``.

    Returns:
        A dict with the keys ``torch_1_mean``, ``torch_2_mean``,
        ``torch_1_variance``, ``torch_2_variance``, ``mae`` (mean absolute
        error) and ``mae%`` (mean relative error in percent, taken only over
        positions where both inputs exceed 1e-6 in magnitude). The values are
        the results of the corresponding torch reductions.

    Raises:
        AssertionError: If the tensors differ in dtype, shape or device.
    """
    # Preconditions. The messages name the operands after the parameters;
    # both are torch tensors, so framework-specific labels would mislead.
    assert tensor_1.dtype == tensor_2.dtype, \
        f"Data type mismatch: tensor_1 dtype={tensor_1.dtype}, tensor_2 dtype={tensor_2.dtype}"
    assert tensor_1.shape == tensor_2.shape, \
        f"Shape mismatch: tensor_1 shape={tensor_1.shape}, tensor_2 shape={tensor_2.shape}"
    assert tensor_1.device == tensor_2.device, \
        f"Device mismatch: tensor_1 device={tensor_1.device}, tensor_2 device={tensor_2.device}"

    # Calculate mean and variance for both tensors
    torch_1_mean = torch.mean(tensor_1)
    torch_2_mean = torch.mean(tensor_2)
    torch_1_variance = torch.var(tensor_1)
    torch_2_variance = torch.var(tensor_2)

    # Calculate Mean Absolute Error (MAE)
    mae = torch.mean(torch.abs(tensor_1.squeeze() - tensor_2.squeeze()))
    # Relative error is only meaningful where neither value is (almost) zero.
    mask = (torch.abs(tensor_1) > 1e-6) & (torch.abs(tensor_2) > 1e-6)
    # The extra 1e-8 in the denominator is a belt-and-braces zero guard.
    mae_percentage = torch.mean(torch.abs(tensor_1[mask] - tensor_2[mask]) / (torch.abs(tensor_1[mask]) + 1e-8)) * 100

    return {
        "torch_1_mean": torch_1_mean,
        "torch_2_mean": torch_2_mean,
        "torch_1_variance": torch_1_variance,
        "torch_2_variance": torch_2_variance,
        "mae": mae,
        "mae%": mae_percentage,
    }

# Fix the NumPy seed so the random array is reproducible across runs.
np.random.seed(12333)
# Draw a (1024, 6144) array from a normal distribution (mean 0.49, std 0.0833).
A_np = np.random.normal(0.49, 0.0833, (1024, 6144))

# Path 1: build the tensor without passing a dtype, move it to the GPU,
# and only then convert it to bfloat16.
A_tensor_1 = torch.tensor(A_np).cuda()
A_tensor_1 = A_tensor_1.to(torch.bfloat16)

# Path 2: ask torch.tensor() for bfloat16 directly, then move it to the GPU.
A_tensor_2 = torch.tensor(A_np, dtype=torch.bfloat16).cuda()

# Compare the two conversion paths inside PyTorch itself; compare() also
# asserts that both tensors live on the same device.
print(compare(A_tensor_1, A_tensor_2))

输出为:

{
    'torch_1_mean': tensor(0.4902, device='cuda:0', dtype=torch.bfloat16), 
    'torch_2_mean': tensor(0.4902, device='cuda:0', dtype=torch.bfloat16), 
    'torch_1_variance': tensor(0.0069, device='cuda:0', dtype=torch.bfloat16), 
    'torch_2_variance': tensor(0.0069, device='cuda:0', dtype=torch.bfloat16), 
    'mae': tensor(0., device='cuda:0', dtype=torch.bfloat16), 
    'mae%': tensor(0., device='cuda:0', dtype=torch.bfloat16)
}

感觉自己对于 Paddle 的 API 还是不够熟悉,会继续看看相关的源码分析一下,如果有哪位大佬知道问题的原因务请指出。

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions