
Commit 4eb1835

Support for KV cache quantization for MLA Attention vLLM fakequant (#714)
## What does this PR do?

**Type of change:** Feature extension

**Overview:** Added support for quantizing the KV cache in vLLM fakequant by adding quantization support for [MLAAttention](https://github.com/vllm-project/vllm/blob/v0.11.1/vllm/attention/layer.py#L641).

## Usage

Please refer to the [README](https://github.com/NVIDIA/Model-Optimizer/tree/kinjal/vllm_att_quant/examples/vllm_serve#calibrate-and-serve-fake-quant-model-in-vllm).

```shell
KV_QUANT_CFG=NVFP4_KV_CFG QUANT_CFG=NVFP4_DEFAULT_CFG python vllm_serve_fakequant.py deepseek-ai/DeepSeek-V2 --served-model-name deepseek-ai/DeepSeek-V2 --host 0.0.0.0 --port 8001 --trust-remote-code --enforce-eager --gpu-memory-utilization 0.8
```

## Testing

Locally tested KV cache quantization; the quantized MLA block prints as:

```
(rotary_emb): DeepseekScalingRotaryEmbedding()
(mla_attn): MultiHeadLatentAttentionWrapper(
  (fused_qkv_a_proj): QuantMergedColumnParallelLinear(
    in_features=5120, output_features=2112, bias=False, tp_size=1, gather_output=False
    (input_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=141.0000 calibrator=MaxCalibrator quant)
    (weight_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=1.4297 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
  )
  (q_a_layernorm): RMSNorm(hidden_size=1536, eps=1e-06)
  (q_b_proj): QuantColumnParallelLinear(
    in_features=1536, output_features=3072, bias=False, tp_size=8, gather_output=False
    (input_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=32.0000 calibrator=MaxCalibrator quant)
    (weight_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.1670 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
  )
  (kv_a_layernorm): RMSNorm(hidden_size=512, eps=1e-06)
  (kv_b_proj): QuantColumnParallelLinear(
    in_features=512, output_features=4096, bias=False, tp_size=8, gather_output=False
    (input_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=7.5312 calibrator=MaxCalibrator quant)
    (weight_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.2773 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
  )
  (rotary_emb): DeepseekScalingRotaryEmbedding()
  (o_proj): QuantRowParallelLinear(
    in_features=2048, output_features=5120, bias=False, tp_size=8, reduce_results=True
    (input_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=1.7188 calibrator=MaxCalibrator quant)
    (weight_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.4336 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
  )
  (mla_attn): QuantMLAAttention(
    (q_bmm_quantizer): TensorQuantizer(disabled)
    (kv_c_bmm_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=7.5312 calibrator=MaxCalibrator quant)
  )
)
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: NA
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: NA

## Additional Information

---

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
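For readers skimming the diffs below, the core of the change is a key remapping in the KV-cache quantization config when the model uses MLA. A minimal sketch of the intended effect (the inner quantizer settings here are illustrative placeholders, not the actual `NVFP4_KV_CFG` contents):

```python
# Illustrative KV-cache quantizer config before MLA-aware remapping; the
# "*[kv]_bmm_quantizer" pattern targets the regular k/v attention quantizers.
kv_quant_cfg = {
    "*[kv]_bmm_quantizer": {"num_bits": (2, 1), "enable": True},  # placeholder settings
}

# On a model containing vLLM MLAAttention modules, update_kv_cfg_for_mla()
# (added by this PR in fakequant_worker.py) copies that entry so the
# MLA-specific quantizers are covered too:
#   "*kv_c_bmm_quantizer" -> compressed KV latent
#   "*k_pe_bmm_quantizer" -> rotary positional key component
expected_after_remap = {
    "*[kv]_bmm_quantizer": {"num_bits": (2, 1), "enable": True},
    "*kv_c_bmm_quantizer": {"num_bits": (2, 1), "enable": True},
    "*k_pe_bmm_quantizer": {"num_bits": (2, 1), "enable": True},
}
```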
1 parent 8426c36 commit 4eb1835

3 files changed

Lines changed: 67 additions & 8 deletions


examples/vllm_serve/fakequant_worker.py

Lines changed: 38 additions & 7 deletions
```diff
@@ -149,12 +149,35 @@ def disable_compilation(model):
 quant_config: dict[str, Any] = {
     "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
     "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)),
-    "quant_cfg": os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"),
+    "quant_cfg": os.environ.get("QUANT_CFG", None),
     "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None),
     "amax_file_path": os.environ.get("AMAX_FILE_PATH", None),
 }
 
 
+def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) -> dict[str, Any]:
+    """Update KV cache quantization config for MLA models.
+
+    MLA uses `kv_c_bmm_quantizer` (compressed KV) instead of separate
+    `k_bmm_quantizer` and `v_bmm_quantizer`. This function copies the
+    config from `*[kv]_bmm_quantizer` to also cover `*kv_c_bmm_quantizer`.
+    """
+    try:
+        from vllm.attention.layer import MLAAttention
+    except ImportError:
+        return kv_quant_cfg
+
+    if not any(isinstance(m, MLAAttention) for m in model.modules()):
+        return kv_quant_cfg
+
+    if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"):
+        kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config
+        kv_quant_cfg["*k_pe_bmm_quantizer"] = kv_config
+        print("MLA detected: added *kv_c_bmm_quantizer and k_pe_bmm_quantizer config")
+
+    return kv_quant_cfg
+
+
 def _create_new_data_cls(data_cls, **kwargs):
     """vLLM's low-level API changes frequently. This function creates a class with parameters
     compatible with the different vLLM versions."""
@@ -236,16 +259,24 @@ def calibrate_loop(model: Any = None) -> None:
             if output is None:  # TODO: make this default when vllm <= 0.11 is outdated
                 self.sample_tokens(None)
 
-    quant_cfg = getattr(mtq, quant_config["quant_cfg"])
-    if quant_config["kv_quant_cfg"] is not None:
-        quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
-            quant_cfg, getattr(mtq, quant_config["kv_quant_cfg"])["quant_cfg"]
-        )
+    quant_cfg = {} if quant_config["quant_cfg"] is None else getattr(mtq, quant_config["quant_cfg"])
+    quant_kv_cfg = (
+        {} if quant_config["kv_quant_cfg"] is None else getattr(mtq, quant_config["kv_quant_cfg"])
+    )
 
     model = self.model_runner.model
     if hasattr(model, "unwrap"):
         model = model.unwrap()
 
+    # Check if model has MLA and update KV config accordingly
+    if quant_kv_cfg:
+        quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"])
+
+    if quant_kv_cfg:
+        quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
+            quant_cfg, quant_kv_cfg["quant_cfg"]
+        )
+
     with disable_compilation(model):
         print("quantizing model...")
         mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
@@ -314,6 +345,6 @@ def determine_available_memory(self) -> int:
         return super().determine_available_memory()
 
     def compile_or_warm_up_model(self) -> None:
-        if quant_config["quant_cfg"]:
+        if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]:
             _fakequant_run_prolog_worker(self)
         super().compile_or_warm_up_model()
```
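Because the prolog now also runs when only `KV_QUANT_CFG` is set, KV-cache-only fake quantization becomes possible. A rough sketch of the resulting flow, assuming `model`, `calibrate_loop`, and `update_kv_cfg_for_mla` are in scope as in the worker above:

```python
import modelopt.torch.quantization as mtq

# KV-cache-only path: QUANT_CFG is unset, so the weight/activation config stays empty.
quant_cfg = {}
quant_kv_cfg = getattr(mtq, "NVFP4_KV_CFG")

# Remap KV keys for MLA models, then merge into the (empty) base config,
# mirroring what _fakequant_run_prolog_worker does.
quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"])
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(quant_cfg, quant_kv_cfg["quant_cfg"])

mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
```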

examples/vllm_serve/vllm_serve_fakequant.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -70,7 +70,13 @@
 
 
 # Adding the envs you want to pass to the workers
-additional_env_vars = {"QUANT_DATASET", "QUANT_CALIB_SIZE", "QUANT_CFG", "AMAX_FILE_PATH"}
+additional_env_vars = {
+    "QUANT_DATASET",
+    "QUANT_CALIB_SIZE",
+    "QUANT_CFG",
+    "AMAX_FILE_PATH",
+    "KV_QUANT_CFG",
+}
 
 RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars)
 
```
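On the worker side these forwarded variables are read back via `os.environ` into the `quant_config` dict shown in `fakequant_worker.py`; a minimal sketch of why the new entry matters:

```python
import os

# If the launcher did not forward KV_QUANT_CFG to the Ray worker, this lookup
# falls back to None and KV-cache fake quantization is silently skipped.
kv_quant_cfg_name = os.environ.get("KV_QUANT_CFG", None)
```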

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -40,6 +40,11 @@
     except ImportError:
         continue
 
+try:
+    from vllm.attention.layer import MLAAttention as VllmMLAAttention
+except ImportError:
+    VllmMLAAttention = None
+
 vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
 
 
@@ -281,3 +286,20 @@ class _QuantVLLMCrossAttention(_QuantVLLMAttention):
 @QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"})
 class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention):
     pass
+
+
+if VllmMLAAttention is not None:
+
+    @QuantModuleRegistry.register({VllmMLAAttention: "vllm_MLAAttention"})
+    class _QuantVLLMMLAAttention(QuantModule):
+        def _setup(self):
+            self.q_bmm_quantizer = TensorQuantizer()
+            self.kv_c_bmm_quantizer = TensorQuantizer()
+            self.k_pe_bmm_quantizer = TensorQuantizer()
+            self.parallel_state = create_parallel_state()
+
+        def forward(self, query, kv_c, k_pe, *args, **kwargs):
+            query = self.q_bmm_quantizer(query)
+            kv_c = self.kv_c_bmm_quantizer(kv_c)
+            k_pe = self.k_pe_bmm_quantizer(k_pe)
+            return super().forward(query, kv_c, k_pe, *args, **kwargs)
```
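After `mtq.quantize(...)` swaps in the registered wrapper, the new quantizers can be spotted directly on the model; a small sanity-check sketch (the printed reprs should resemble the Testing output in the PR description):

```python
# List every module carrying the MLA-specific quantizers added by this PR.
for name, module in model.named_modules():
    if hasattr(module, "kv_c_bmm_quantizer"):
        print(name, module.q_bmm_quantizer, module.kv_c_bmm_quantizer, module.k_pe_bmm_quantizer)
```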
