[Intel GPU] int4 WOQ gemm XPU Support #137566
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137566
Note: Links to docs will display an error until the docs builds have been completed. ❌ 4 New Failures as of commit 8b14d42 with merge base d7f3cd0.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@zhuyuhua-v Could you please review the PR?
"oneDNN input matrixes must have the same ranks"); | ||
TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined"); | ||
|
||
at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device()); |
Please unify the code style: curDevice -> cur_device.
Thanks for your suggestions, the naming has been changed.
mb = dst.size(0);
TORCH_CHECK(
    mb == m1.size(0) && mb == m2.size(0),
    "batch size mismatch, dst mb: ",
Is mb a common term? Can users fully understand the exact meaning of mb?
Thanks for your suggestions. mb means minibatch here, but I reviewed the code and removed mb, since int4_gemm has no need to handle batch currently.
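For reference, a minimal sketch of what the batch-free shape check could look like (illustrative only; m1, m2, and dst follow the naming in the diff above, and the check messages are assumptions, not the code that landed):

// Sketch: with batch handling dropped, validate plain 2-D operands.
TORCH_CHECK(
    m1.dim() == 2 && m2.dim() == 2,
    "int4 gemm expects 2-D activation and packed weight tensors");
TORCH_CHECK(
    dst.size(0) == m1.size(0),
    "output rows must match activation rows: ", dst.size(0), " vs ", m1.size(0));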
scale_usr_md = dnnl::memory::desc(scale_dims, scale_user_dt, scale_strides);
zp_usr_md = dnnl::memory::desc(zp_usr_dims, zp_user_dt, zp_usr_strides);
dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides);
// STEP4: create dnnl::memory
Where are STEP 2 and STEP 3?
I have removed this kind of comment and added new comments in the code.
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_usr_m});

sycl::event matmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args, deps);
if (!dst.is_same(result))
When is dst not the same as result?
This issue stems from the woq matmul being ported from the regular matmul. int4_gemm currently has no need to consider dst being different from result, so I have removed the code.
sycl::event matmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args, deps);
if (!dst.is_same(result))
  result.copy_(dst);
result = resize_as_onednn_mat1(mat1_, result);
When is resize_as_onednn_mat1 required?
removed
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});

if (attr.with_binary())
  attr.construct_post_binary(matmul_pd, args);
attr constructs the post binary. However, dnnl::post_ops po = attr.extract_post_ops(dst); has already extracted the post ops, and pattr.set_post_ops(po); has assigned them to the matmul primitive attribute. Is this valid behavior?
int4 has no post-ops currently; I have removed the code, thanks.
    dnnl::memory::data_type::s8);
// Set fpmath mode with `apply_to_int=true` to apply fpmath mode behavior to
// integral primitives (in this example, matmul).
pattr.set_fpmath_mode(dnnl::fpmath_mode::f16, true);
OneDNN supports both f16 and bf16. Why do we need to constrain the dtype?
We now have a control statement to determine which dtype is used for fpmath_mode, thanks. However, bf16 has a runtime issue in the current oneDNN version; the bf16 dtype is valid in newer versions of oneDNN.
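For illustration, a minimal sketch of such a control statement (an assumption, not the exact code in the PR; it also presumes a oneDNN version where bf16 fpmath works for the int4 matmul):

// Pick the fpmath mode from the activation dtype instead of hard-coding f16.
// apply_to_int=true makes the mode apply to integral primitives such as this matmul.
dnnl::fpmath_mode fp_mode = (m1.scalar_type() == at::kBFloat16)
    ? dnnl::fpmath_mode::bf16
    : dnnl::fpmath_mode::f16;
pattr.set_fpmath_mode(fp_mode, /*apply_to_int=*/true);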
@liangan1 Could you please review the PR?
TORCH_CHECK(
    dims == mat1.dim() && dims == mat2.dim(),
    "oneDNN input matrixes must have the same ranks");
TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined");
Since you have flattened mat1 and mat2 into dims=2 and the result is also a 2-dimensional empty tensor, when will dim=3 occur, and when is result not defined? Can you show an example?
Some of this logic was too old; we have removed such code from the gemm integration now.
Attr attr,
const c10::optional<Tensor>& g_idx,
const std::vector<sycl::event>& deps,
Tensor b_raw = at::Tensor()) {
Change to bias_raw?
bias is not present in the weight_int4pack_mm API, and I have removed the bias-related code in the newest commit. Thanks for your suggestions.
    (b.size(0) == 1 && b.size(1) == 1),
    "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
if (b.size(0) == 1 && b.size(1) == 1)
  b = b.expand({1, n}).contiguous();
In the other cases (e.g., b.dim()=1/3/0), you always expand b to the same dim as m1. Does it work when m1.dim()==3 while b.dim()==2? According to the oneDNN doc: "all tensors (including bias, if it exists) must have the same number of dimensions."
I have removed this code since it is bias-related. Thanks for the reminder.
auto m2_usr_dt = get_onednn_dtype(m2);
auto scale_user_dt = get_onednn_dtype(scale_); // half <==> fp16
// auto zp_user_dt = dnnl::memory::data_type::s4; // int32, representing 8xint4
auto zp_user_dt = get_onednn_dtype(zp_);
Suggest changing xxx_user_xxx to xxx_usr_xxx to unify the style. Since oneDNN supports different data types, suggest changing the comment to "e.g., half <==> f16".
  return output.view_symint(sizes);
}

sycl::event woq_matmul_int4(
Suggest adding more function description here, e.g., the supported activation data types, data layout information for both inputs, etc.
A more detailed description has been added in the updated commits.
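As an illustration of the kind of header comment being asked for, a sketch follows; the shapes and dtypes are assumptions pieced together from this thread, not the final contract of the PR:

// woq_matmul_int4: weight-only-quantized (int4) matmul via oneDNN on XPU.
//   result : [M, N]             fp16/bf16 output, plain row-major
//   mat1   : [M, K]             fp16/bf16 activation, row-major
//   mat2   : [K/8, N]           int32, each element packing 8 x int4 weights
//   scale  : [K/group_size, N]  per-group dequantization scales
//   zp     : [K/group_size, N]  per-group zero-points
//   group_size : quantization group size along K (e.g. 32/64/128)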
m2_usr_dims = {compressed_k, n};
scale_dims = {num_groups, n};
zp_dims = {1};
The dims of zp_dims are not aligned with the original zp inputs. With this limitation, only symmetric or per-tensor quantization is supported. Please add comments about this limitation of oneDNN.
oneDNN provides a way to support asymmetric quantization. I'm currently testing it, and if it works, I will modify the code here to support both symmetric and asymmetric logic.
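For reference, a sketch of the attribute setup that asymmetric, per-group int4 weights would need (this assumes a oneDNN release that exposes the grouped scale/zero-point primitive-attribute API and u4 storage, e.g. 3.5+; it is not the code merged in this PR):

// Per-group weight scales: one value per {group_size x 1} block along K,
// per output channel along N (mask bits 0 and 1), stored as f16.
pattr.set_scales(DNNL_ARG_WEIGHTS, /*mask=*/(1 << 0) + (1 << 1),
                 {group_size, 1}, dnnl::memory::data_type::f16);
// Per-group weight zero-points with the same grouping, stored as u4,
// which is what enables the asymmetric case.
pattr.set_zero_points(DNNL_ARG_WEIGHTS, /*mask=*/(1 << 0) + (1 << 1),
                      {group_size, 1}, dnnl::memory::data_type::u4);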
There should be a prepack process since oneDNN doesn't support the most popular layout.
https://github.com/intel/torch-xpu-ops/pull/1035/files — this PR is used to do the int4 weight prepack.
test/xpu/test_gemm.py (outdated)
    b, n_bit=4, q_group_size=q_group
)
# b_int4pack [n, k//8]
b_int4pack = torch._convert_weight_to_int4pack(
This should be b_int4pack [k//8, n]
Thanks for the reminder, I have modified the description here.
sizes[sizes.size() - 1] = n;
return output.view_symint(sizes);
}
Should remove this?
Yes, this code has been removed in the newest commit, thanks for the reminder.
Tensor m1 = is_onednn_matmul_strides(mat1_) ? mat1_ : mat1_.contiguous();
// m2_ may be a 4 dims fake tensor in torchAO with shape {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}
// Tensor m2 = mat2_.flatten(0, -2); //ToDo: change to the fke shape: mat2_.flatten(0, -2); // N1
Tensor m2 = is_onednn_matmul_strides(mat2_) ? mat2_ : mat2_.contiguous();
Remove these comments.
removed
auto expected_m1_md = matmul_pd.src_desc();
auto expected_m2_md = matmul_pd.weights_desc();
auto expected_dst_md = matmul_pd.dst_desc();
Need to remove this part.
zeros = min_val + scales * (2 ** (n_bit - 1))
zeros = min_int - min_val.div(scales).round()
zeros = torch.clamp(zeros, min_int, max_int)
zeros = zeros.to(torch.int8)
assert torch.isnan(zeros).sum() == 0
This is also used in tinygemm; we should not change this one.
Thanks for pointing this out, I have moved the code to xpu/test_gemm.py.
const at::Tensor& zp, // [k/group_size, N]
int64_t group_size,
Attr attr,
const std::vector<sycl::event>& deps = {});
Why does this operation require deps?
Formerly, fs1 required us to add events at the oneDNN integration layer for profiling purposes. For me, it is just intended to keep the API consistent with conv/gemm. Do we need to remove this?
const at::Tensor& scale, // [K/group_size, N]
const at::Tensor& zp, // [k/group_size, N]
int64_t group_size,
Attr attr,
Suggested change:
- Attr attr,
+ std::optional<Attr> attr = std::nullopt,
Will remove attr, as we do not append post-ops currently.
const at::Tensor& zp, // [k/group_size, N]
int64_t group_size,
Attr attr,
const std::vector<sycl::event>& deps = {});
Suggested change:
- const std::vector<sycl::event>& deps = {});
+ const std::optional<std::vector<sycl::event>>& deps = std::nullopt);
// qscale:[K/qGroupSize, N]
// qzp:[K/qGroupSize, N]
woq_matmul_int4(C, A, B, qScale, qZeros, qGroupSize, onednn::Attr());
Is there any case where we need to fuse other operations? What's the motivation for providing attributes here?
const Tensor& A,
const Tensor& B,
int64_t qGroupSize,
const Tensor& qScale,
const Tensor& qZeros) {
@ZhiweiYan-96, the code style of Blas.cpp is snake_case; why is the style of these variables camelCase?
Hi @EikanWang, the naming style here is for aligning with other backends like CUDA (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/int4mm.cu#L1097) and CPU (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LinearAlgebra.cpp#L3461).
at::Device cur_device = at::Device(at::kXPU, at::xpu::current_device());
auto engine = GpuEngineManager::Instance().get_engine(cur_device);
auto stream = GpuStreamManager::Instance().get_stream();
@ZhiweiYan-96, may I know where the guard code is that ensures all the input tensors are on the same device?
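A minimal sketch of what such a guard could look like (hypothetical helper name; the checks and messages are assumptions, not the code in the PR):

// Ensure every operand of the int4 matmul lives on the same XPU device
// before querying the oneDNN engine/stream for that device.
static void check_same_xpu_device(const at::Tensor& mat1, const at::Tensor& mat2,
                                  const at::Tensor& scale, const at::Tensor& zp) {
  TORCH_CHECK(mat1.is_xpu() && mat2.is_xpu() && scale.is_xpu() && zp.is_xpu(),
              "weight_int4pack_mm: all inputs must be XPU tensors");
  TORCH_CHECK(mat1.device() == mat2.device() &&
              mat1.device() == scale.device() &&
              mat1.device() == zp.device(),
              "weight_int4pack_mm: all inputs must be on the same XPU device");
}

The caller could additionally install a c10::OptionalDeviceGuard for mat1's device so that at::xpu::current_device() matches the inputs.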
dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides);

std::unordered_map<int, dnnl::memory> args;
dnnl::post_ops po = attr.extract_post_ops(dst);
The po should be unused. Has this file been added to the torch linter?
- Post-ops are not required at present. We can remove the post-op and add it back when it is necessary.
- All files in xpu/detail/*.cpp are on the linter checking list. I have met this before; it seems the linter does not check for this unused-variable issue.
auto engine = GpuEngineManager::Instance().get_engine(cur_device);
auto stream = GpuStreamManager::Instance().get_stream();
Suggested change:
- auto engine = GpuEngineManager::Instance().get_engine(cur_device);
- auto stream = GpuStreamManager::Instance().get_stream();
+ auto& engine = GpuEngineManager::Instance().get_engine();
+ auto& stream = GpuStreamManager::Instance().get_stream();
Thanks for the information. I have updated the code.
Update
Rebased
Starting merge as part of PR stack under #147962
…tration (#147962) Pull Request resolved: #147962 Approved by: https://github.com/jerryzh168, https://github.com/guangyey, https://github.com/EikanWang ghstack dependencies: #137566 Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Pull Request resolved: pytorch#137566 Approved by: https://github.com/liangan1, https://github.com/guangyey, https://github.com/EikanWang Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
ghstack-source-id: c2c4f90 Pull Request resolved: pytorch/pytorch#137566
Stack from ghstack (oldest at bottom):
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @EikanWang @fengyuan14 @guangyey