diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 274a35b81d..f54d16abe2 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -414,17 +414,20 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [1024, 256, 128, 32]) @pytest.mark.parametrize("topk", [4, 32]) -def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): +@pytest.mark.parametrize("expert_multiplier", [1, 2]) +def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk, expert_multiplier): if topk >= num_experts: pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})") + # Sequence aux loss batches independent sequences along the expert dimension. + num_cols = num_experts * expert_multiplier # Construct the special probs to avoid inf in the sigmoid function offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 - probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + probs = torch.arange(-num_cols // 2, num_cols // 2, device="cuda", dtype=dtype) * 1e-2 probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) - probs = probs.view(num_tokens, num_experts) + probs = probs.view(num_tokens, num_cols) probs.requires_grad = True - tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32) + tokens_per_expert = torch.randint(1, 1000, (num_cols,), device="cuda", dtype=torch.int32) coeff = 0.01 probs_clone = deepcopy(probs) @@ -448,7 +451,7 @@ def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): coeff=coeff, ) - atol, rtol = _get_tolerances(dtype, num_experts) + atol, rtol = _get_tolerances(dtype, num_cols) torch.testing.assert_close(aux_loss, aux_loss_fused, atol=atol, rtol=rtol) # Backward diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index 7e516af97b..cc5e5e3bcc 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -87,8 +87,11 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, int num_cols, int topk, float coeff, DataType* aux_loss, float* Coeff_buf, cudaStream_t stream) { - NVTE_CHECK(num_experts == num_cols, "Number of experts (", num_experts, - ") must be equal to number of input columns (", num_cols, ")."); + NVTE_CHECK(num_cols > 0, "num_cols must be positive, got ", num_cols); + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + // Sequence aux loss batches independent sequences along the expert dimension. + NVTE_CHECK(num_cols % num_experts == 0, "Number of input columns (", num_cols, + ") must be a multiple of number of experts (", num_experts, ")."); // Round up to a multiple of warp size for correct warp shuffles. const int block_size = ((std::min(1024, num_cols) + static_cast(kThreadsPerWarp) - 1) / @@ -98,7 +101,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, // One CompType per thread in shared memory. const size_t smem_size = block_size * sizeof(CompType); - check_shared_memory_capacity_num_experts(smem_size, num_experts); + check_shared_memory_capacity_num_experts(smem_size, num_cols); // Compute final coefficient and zero the float accumulator (Coeff_buf[1]) before launch. const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;