# Issue description: multiply selected columns of a mask by a per-row factor.
#
# Indexing with a Python list (advanced indexing) yields a copy, not a view,
# so the scaled values must be written back through
# ``mask[:, mask_index] = ...`` (Tensor.__setitem__) for ``mask`` to change.
import torch

mask_index = [0, 2]  # columns of ``mask`` to rescale
mask = torch.ones([3, 3])
# ``torch.tensor`` is the recommended factory; float literals give the same
# float32 dtype the legacy ``torch.Tensor([...])`` constructor produced.
img_mask = torch.tensor([2.0, 2.0, 2.0])

print(mask)
# ``mask[:, mask_index]`` has shape (3, len(mask_index)). ``unsqueeze(-1)``
# turns the (3,) factors into (3, 1), so row i is scaled by img_mask[i].
mask[:, mask_index] = mask[:, mask_index] * img_mask.unsqueeze(-1)
print(mask)