diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 32ea1694ee..769170d3b7 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1912,14 +1912,14 @@ def get_model(dtype, config): # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig( 2, - 4096, + 2048, 128, 192, head_dim_v=128, ), "fp8_10": ModelConfig( - 1, - 4096, + 2, + 2048, 128, 192, head_dim_v=128, @@ -1927,21 +1927,23 @@ def get_model(dtype, config): ), "fp8_11": ModelConfig( 2, - 4096, + 2048, 128, 192, head_dim_v=128, attn_mask_type="causal_bottom_right", ), - "fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), - "fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), + "fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_13": ModelConfig( + 2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), + "fp8_14": ModelConfig(2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), "fp8_16": ModelConfig( - 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + 1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), "fp8_17": ModelConfig( - 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + 2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" ), "fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), "fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),