diff --git a/deep_gemm/utils/math.py b/deep_gemm/utils/math.py index f1582ed560..837ddc4467 100644 --- a/deep_gemm/utils/math.py +++ b/deep_gemm/utils/math.py @@ -18,7 +18,7 @@ def ceil_to_ue8m0(x: torch.Tensor): def pack_ue8m0_to_int(x: torch.Tensor): assert x.dtype == torch.float and x.size(-1) % 4 == 0 - assert (x.view(torch.int) & ((1 << 23) - 1) == 0).all() + assert ((x.view(torch.int) & ((1 << 23) - 1)) == 0).all() return (x.view(torch.int) >> 23).to(torch.uint8).view(torch.int)