1515
1616"""Calibration functions for sparse attention."""
1717
18- import hashlib
19- import json
2018import warnings
2119from collections .abc import Callable
22- from pathlib import Path
2320from typing import Any
2421
2522import torch
2623import torch .nn as nn
2724from transformers import AutoTokenizer
2825
26+ from modelopt .torch .utils import get_module_device
27+
2928from ..config import CalibrationConfig
3029from ..conversion import print_sparse_attention_summary
31- from ..sparse_attention import SparseAttentionModule
30+ from ..utils import get_named_sparse_attention_modules
3231from .calibrator import DynamicThresholdCalibrator
33- from .dataset import RulerDatasetBuilder
34-
32+ from .ruler_dataset import RulerDatasetBuilder
3533
36- def _get_cache_path (
37- tokenizer_path : str , samples : int , max_seqlen : int , cache_dir : str | None = None
38- ) -> Path :
39- """Generate cache file path based on calibration parameters.
4034
41- Args:
42- tokenizer_path: Path to tokenizer (used in hash)
43- samples: Number of calibration samples
44- max_seqlen: Maximum sequence length
45- cache_dir: Optional cache directory. If None, uses ~/.cache/modelopt/sparse_attention/
46- """
47- # Create a hash of the parameters for the cache filename
48- key = f"{ tokenizer_path } _{ samples } _{ max_seqlen } "
49- hash_str = hashlib .md5 (key .encode (), usedforsecurity = False ).hexdigest ()[:12 ]
50- filename = f"ruler_cache_{ samples } s_{ max_seqlen } l_{ hash_str } .json"
51-
52- if cache_dir :
53- base_dir = Path (cache_dir )
54- else :
55- base_dir = Path .home () / ".cache" / "modelopt" / "sparse_attention"
56-
57- return base_dir / filename
58-
59-
60- def _load_cached_data (cache_path : Path ) -> list [dict [str , Any ]] | None :
61- """Load calibration data from cache if it exists."""
62- if cache_path .exists ():
63- try :
64- with open (cache_path ) as f :
65- data = json .load (f )
66- print (f"Loaded { len (data )} cached calibration samples from { cache_path } " )
67- return data
68- except Exception as e :
69- print (f"Warning: Failed to load cache: { e } " )
70- return None
71-
72-
73- def _save_cached_data (cache_path : Path , data : list [dict [str , Any ]]) -> None :
74- """Save calibration data to cache."""
75- try :
76- cache_path .parent .mkdir (parents = True , exist_ok = True )
77- with open (cache_path , "w" ) as f :
78- json .dump (data , f )
79- print (f"Saved calibration samples to cache: { cache_path } " )
80- except Exception as e :
81- print (f"Warning: Failed to save cache: { e } " )
def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer":
    """Load a HF tokenizer, defaulting pad_token to eos_token when unset."""
    tok = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    if not tok.pad_token:
        tok.pad_token = tok.eos_token
    return tok
8241
8342
8443def _extract_tokenizer_from_model (model : nn .Module ) -> str :
@@ -147,12 +106,10 @@ def create_calibration_forward_loop(
147106 Returns:
148107 Forward loop function that takes model as argument
149108 """
150- tokenizer = AutoTokenizer .from_pretrained (tokenizer_name_or_path )
151- if not tokenizer .pad_token :
152- tokenizer .pad_token = tokenizer .eos_token
109+ tokenizer = _load_tokenizer (tokenizer_name_or_path )
153110
154111 def forward_loop (model : nn .Module ) -> None :
155- device = next (model . parameters ()). device
112+ device = get_module_device (model )
156113
157114 for sample in calibration_data :
158115 inputs = tokenizer (
@@ -205,12 +162,10 @@ def create_decode_calibration_forward_loop(
205162 Returns:
206163 Forward loop function that takes model as argument
207164 """
208- tokenizer = AutoTokenizer .from_pretrained (tokenizer_name_or_path )
209- if not tokenizer .pad_token :
210- tokenizer .pad_token = tokenizer .eos_token
165+ tokenizer = _load_tokenizer (tokenizer_name_or_path )
211166
212167 def forward_loop (model : nn .Module ) -> None :
213- device = next (model . parameters ()). device
168+ device = get_module_device (model )
214169
215170 for sample in calibration_data :
216171 inputs = tokenizer (
@@ -291,9 +246,7 @@ def calibrate_sparse_attention(
291246 return {}
292247
293248 # Get sparse attention modules
294- sparse_modules = [
295- (name , m ) for name , m in model .named_modules () if isinstance (m , SparseAttentionModule )
296- ]
249+ sparse_modules = get_named_sparse_attention_modules (model )
297250
298251 if not sparse_modules :
299252 print ("No sparse attention modules found for calibration" )
@@ -306,29 +259,16 @@ def calibrate_sparse_attention(
306259 calibration_data = None
307260
308261 if calibrate_prefill or calibrate_decode :
309- # Try to load from cache first
310- cache_path = _get_cache_path (
311- tokenizer ,
312- calib_config .samples ,
313- calib_config .max_seqlen ,
262+ builder = RulerDatasetBuilder (
263+ samples = calib_config .samples ,
264+ max_seqlen = calib_config .max_seqlen ,
265+ tokenizer_name_or_path = tokenizer ,
266+ num_length_bins = calib_config .num_length_bins ,
267+ max_length_filter = int (calib_config .max_seqlen * 1.5 ),
314268 cache_dir = calib_config .cache_dir ,
269+ data_dir = calib_config .data_dir ,
315270 )
316- calibration_data = _load_cached_data (cache_path )
317-
318- # Generate if not cached
319- if calibration_data is None :
320- builder = RulerDatasetBuilder (
321- samples = calib_config .samples ,
322- max_seqlen = calib_config .max_seqlen ,
323- tokenizer_name_or_path = tokenizer ,
324- num_length_bins = calib_config .num_length_bins ,
325- max_length_filter = int (calib_config .max_seqlen * 1.5 ),
326- )
327- calibration_data = builder .build_calibration_dataset ()
328- print (f"Generated { len (calibration_data )} calibration samples" )
329-
330- # Save to cache for future runs
331- _save_cached_data (cache_path , calibration_data )
271+ calibration_data = builder .build_calibration_dataset ()
332272
333273 # Initialize results
334274 calibration_results : dict [str , Any ] = {}
0 commit comments