
Commit 7233616

Added support for KV cache quantization for vllm fakequant (#686)
## What does this PR do?

**Type of change:** New feature

**Overview:**

- Added support to quantize the KV cache in vLLM fakequant by adding quantization support for [Attention](https://github.com/vllm-project/vllm/blob/v0.12.0/vllm/attention/layer.py#L161).
- Modified initialization of the parallel state to incorporate vLLM parallel-state groups for correct quantization-parameter syncing.

## Usage

Please refer to the [README](https://github.com/NVIDIA/Model-Optimizer/tree/kinjal/vllm_att_quant/examples/vllm_serve#calibrate-and-serve-fake-quant-model-in-vllm).

```
KV_QUANT_CFG=NVFP4_KV_CFG QUANT_CFG=NVFP4_DEFAULT_CFG python vllm_serve_fakequant.py meta-llama/Llama-3.2-1B-Instruct --served-model-name meta-llama/Llama-3.2-1B-Instruct --host 0.0.0.0 --port 8001 --trust-remote-code
```
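Once the server is up, a quick way to sanity-check the fake-quant deployment is to query vLLM's OpenAI-compatible endpoint. A minimal sketch, assuming the `openai` Python package is installed; host, port, and model name are taken from the command above:

```python
# Smoke test against the fake-quant server started above.
# vLLM serves an OpenAI-compatible API, so the standard client works.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8001/v1", api_key="EMPTY")
response = client.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",
    prompt="The capital of France is",
    max_tokens=8,
)
print(response.choices[0].text)
```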
## Testing

Locally tested KV cache quantization:

```
model.layers.0.self_attn.qkv_proj.input_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=5.0312 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.qkv_proj.weight_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.6758 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.qkv_proj.output_quantizer TensorQuantizer(disabled)
model.layers.0.self_attn.o_proj.input_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=1.3438 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.o_proj.weight_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.3145 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.o_proj.output_quantizer TensorQuantizer(disabled)
model.layers.0.self_attn.attn.q_bmm_quantizer TensorQuantizer(disabled)
model.layers.0.self_attn.attn.k_bmm_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=13.8125 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.attn.v_bmm_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=1.3438 calibrator=MaxCalibrator quant)
model.layers.0.mlp.gate_up_proj.input_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=3.2812 calibrator=MaxCalibrator quant)
model.layers.0.mlp.gate_up_proj.weight_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.5938 calibrator=MaxCalibrator quant)
model.layers.0.mlp.gate_up_proj.output_quantizer TensorQuantizer(disabled)
model.layers.0.mlp.down_proj.input_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=33.7500 calibrator=MaxCalibrator quant)
model.layers.0.mlp.down_proj.weight_quantizer TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.6211 calibrator=MaxCalibrator quant)
model.layers.0.mlp.down_proj.output_quantizer TensorQuantizer(disabled)
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: NA
- **Did you add or update any necessary documentation?**: Yes
- **Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent d8d5a29 commit 7233616

4 files changed: 48 additions & 31 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -9,6 +9,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for Transformer Engine quantization for Megatron Core models.
 - Add support for Qwen3-Next model quantization.
 - Add support for dynamically linked TensorRT plugins in the ONNX quantization workflow.
+- Add support for KV Cache Quantization for vLLM FakeQuant PTQ script. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#Calibrate-and-serve-fake-quant-model-in-vLLM>`__ for more details.
 
 **Deprecations**
 
```

examples/vllm_serve/README.md

Lines changed: 2 additions & 18 deletions
````diff
@@ -24,6 +24,7 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`,
 | QUANT_DATASET | Dataset name for calibration | cnn_dailymail |
 | QUANT_CALIB_SIZE | Number of samples used for calibration | 512 |
 | QUANT_CFG | Quantization format | NVFP4_DEFAULT_CFG |
+| KV_QUANT_CFG | Quantization format for KV Cache | None |
 | AMAX_FILE_PATH | Optional path to amax file (for loading amax) | None |
 
 Set these variables in your shell or Docker environment as needed to customize calibration.
@@ -68,25 +69,8 @@ Step 2: configure <quant_amax.pth> from exported model using AMAX_FILE_PATH envi
 AMAX_FILE_PATH=<vllm_amax.pth> QUANT_CFG=<quant_config> python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000
 ```
 
-## Important Notes
-
-**Amax Synchronization across Tensor Parallel (TP):**
-
-- **For non-per-tensor quantization**: It is **recommended** to use an amax file (via `AMAX_FILE_PATH`) because amax synchronization across TP/EP is not automatically handled. Without an amax file, the amax values can be different across different TP ranks, leading to inconsistent results compared to real-quantization.
-
-- **For per-tensor quantization**: If you are not using an amax file, you need to enable amax synchronization across TP ranks. An example implementation is provided in `fakequant_worker.py` (lines 190-198):
-
-```python
-for name, buffer in model.named_buffers():
-    if name.endswith("_amax"):
-        torch.distributed.all_reduce(
-            buffer, op=torch.distributed.ReduceOp.MAX, group=get_tp_group().device_group
-        )
-torch.distributed.barrier()
-```
-
 ## Known Problems
 
 1. AWQ is not yet supported in vLLM.
-2. PTQ/QAT checkpoint doesn't work with KV Cache quantization enabled.
+2. QAT checkpoint export doesn't have KV Cache quantization enabled. KV Cache fake quantization works for PTQ.
 3. Mixed precision checkpoint doesn't work currently.
````
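For context, here is a minimal sketch of how the `QUANT_CFG`/`KV_QUANT_CFG` variables documented above resolve to ModelOpt quantization configs. It mirrors the `fakequant_worker.py` hunks in the next file; `mtq.utils.update_quant_cfg_with_kv_cache_quant` comes from this PR, while the surrounding scaffolding is illustrative:

```python
import os

import modelopt.torch.quantization as mtq

# Config names are looked up as attributes on mtq, e.g. NVFP4_DEFAULT_CFG.
quant_cfg = getattr(mtq, os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"))

# KV_QUANT_CFG is optional, e.g. "NVFP4_KV_CFG"; leaving it unset disables KV quant.
kv_cfg_name = os.environ.get("KV_QUANT_CFG")
if kv_cfg_name is not None:
    quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
        quant_cfg, getattr(mtq, kv_cfg_name)["quant_cfg"]
    )
```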

examples/vllm_serve/fakequant_worker.py

Lines changed: 5 additions & 11 deletions
```diff
@@ -150,6 +150,7 @@ def disable_compilation(model):
         "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
         "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)),
         "quant_cfg": os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"),
+        "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None),
         "amax_file_path": os.environ.get("AMAX_FILE_PATH", None),
     }
 
@@ -236,6 +237,10 @@ def calibrate_loop(model: Any = None) -> None:
                 self.sample_tokens(None)
 
         quant_cfg = getattr(mtq, quant_config["quant_cfg"])
+        if quant_config["kv_quant_cfg"] is not None:
+            quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
+                quant_cfg, getattr(mtq, quant_config["kv_quant_cfg"])["quant_cfg"]
+            )
 
         model = self.model_runner.model
         if hasattr(model, "unwrap"):
@@ -290,17 +295,6 @@ def calibrate_loop(model: Any = None) -> None:
         model.load_state_dict(current_state_dict)
         torch.distributed.barrier()
 
-        if amax_file_path is None:
-            # Sync amax across TP can be done here if needed
-            pass
-            # for name, buffer in model.named_buffers():
-            #     if name.endswith("_amax"):
-            #         print("syncing amax across TP for", name)
-            #         torch.distributed.all_reduce(
-            #             buffer, op=torch.distributed.ReduceOp.MAX, group=get_tp_group().device_group
-            #         )
-            # torch.distributed.barrier()
-
         if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
             mtq.print_quant_summary(model)
```
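For per-tensor quantization without an amax file, the synchronization that the removed comment block sketched can be factored into a helper. A hedged standalone version, assuming vLLM's distributed groups are already initialized:

```python
import torch
from vllm.distributed.parallel_state import get_tp_group


def sync_amax_across_tp(model: torch.nn.Module) -> None:
    """Max-reduce every *_amax buffer over the tensor-parallel group."""
    for name, buffer in model.named_buffers():
        if name.endswith("_amax"):
            torch.distributed.all_reduce(
                buffer,
                op=torch.distributed.ReduceOp.MAX,
                group=get_tp_group().device_group,
            )
    torch.distributed.barrier()
```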

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 40 additions & 2 deletions
```diff
@@ -18,8 +18,12 @@
 import importlib
 
 import torch
+import vllm.attention as vllm_attention
 import vllm.model_executor.layers.fused_moe.layer as vllm_fused_moe_layer
 import vllm.model_executor.layers.linear as vllm_linear
+from vllm.attention.layers.cross_attention import CrossAttention
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
+from vllm.distributed.parallel_state import get_dp_group, get_ep_group, get_tp_group
 
 from ...utils.distributed import ParallelState
 from ..nn import QuantLinearConvBase, QuantModule, QuantModuleRegistry, TensorQuantizer
@@ -90,6 +94,14 @@ def apply(
         return output
 
 
+def create_parallel_state():
+    """Create a parallel state for vLLM."""
+    dp_group = get_dp_group().device_group
+    tp_group = get_tp_group().device_group
+    ep_group = get_ep_group().device_group
+    return ParallelState(dp_group, tp_group, ep_group)
+
+
 class _VLLMParallelLinear(QuantModule):
     def _setup(self):
         self.input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
@@ -100,7 +112,7 @@ def _setup(self):
             f"quant_method is {type(self.quant_method)}"
         )
         self.fake_quant_method = FakeQuantMethod(self.quant_method)
-        self.parallel_state = ParallelState(-1, -1)
+        self.parallel_state = create_parallel_state()
 
     def forward(self, input_):
         # This context manager will conflict with torch.compile
@@ -151,7 +163,7 @@ def _setup(self):
         assert type(self.quant_method) is vllm_fused_moe_layer.UnquantizedFusedMoEMethod, (
             f"quant_method is {type(self.quant_method)}"
        )
-        self.parallel_state = ParallelState(-1, -1)
+        self.parallel_state = create_parallel_state()
 
     def invoke_fused_moe_quantized(
         self,
@@ -243,3 +255,29 @@ class _QuantVLLMFusedMoE(_QuantFusedMoEBase):
 )
 class _QuantVLLMSharedFusedMoE(_QuantFusedMoEBase):
     pass
+
+
+@QuantModuleRegistry.register({vllm_attention.Attention: "vllm_Attention"})
+class _QuantVLLMAttention(QuantModule):
+    def _setup(self):
+        self.q_bmm_quantizer = TensorQuantizer()
+        self.k_bmm_quantizer = TensorQuantizer()
+        self.v_bmm_quantizer = TensorQuantizer()
+        self.parallel_state = create_parallel_state()
+
+    def forward(self, query, key, value, *args, **kwargs):
+        query = self.q_bmm_quantizer(query)
+        key = self.k_bmm_quantizer(key)
+        value = self.v_bmm_quantizer(value)
+
+        return super().forward(query, key, value, *args, **kwargs)
+
+
+@QuantModuleRegistry.register({CrossAttention: "vllm_CrossAttention"})
+class _QuantVLLMCrossAttention(_QuantVLLMAttention):
+    pass
+
+
+@QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"})
+class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention):
+    pass
```
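After `mtq.quantize(...)` runs with these registrations in place, each vLLM `Attention` module carries `q_bmm_quantizer`/`k_bmm_quantizer`/`v_bmm_quantizer` submodules, matching the Testing output above. A small helper (hypothetical, for illustration) to inspect them:

```python
import torch


def print_kv_quantizers(model: torch.nn.Module) -> None:
    """List the KV-cache quantizers attached to quantized attention modules."""
    for name, module in model.named_modules():
        for attr in ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"):
            if hasattr(module, attr):
                print(f"{name}.{attr}: {getattr(module, attr)}")
```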
