
"""Calibration functions for sparse attention."""

+import hashlib
+import json
import warnings
from collections.abc import Callable
+from pathlib import Path
from typing import Any

import torch

from .dataset import RulerDatasetBuilder


+def _get_cache_path(
+    tokenizer_path: str, samples: int, max_seqlen: int, cache_dir: str | None = None
+) -> Path:
+    """Generate cache file path based on calibration parameters.
+
+    Args:
+        tokenizer_path: Path to tokenizer (used in hash)
+        samples: Number of calibration samples
+        max_seqlen: Maximum sequence length
+        cache_dir: Optional cache directory. If None, uses ~/.cache/modelopt/sparse_attention/
+    """
+    # Create a hash of the parameters for the cache filename
+    key = f"{tokenizer_path}_{samples}_{max_seqlen}"
+    hash_str = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12]
+    filename = f"ruler_cache_{samples}s_{max_seqlen}l_{hash_str}.json"
+
+    if cache_dir:
+        base_dir = Path(cache_dir)
+    else:
+        base_dir = Path.home() / ".cache" / "modelopt" / "sparse_attention"
+
+    return base_dir / filename
+
+
+def _load_cached_data(cache_path: Path) -> list[dict[str, Any]] | None:
+    """Load calibration data from cache if it exists."""
+    if cache_path.exists():
+        try:
+            with open(cache_path) as f:
+                data = json.load(f)
+            print(f"Loaded {len(data)} cached calibration samples from {cache_path}")
+            return data
+        except Exception as e:
+            print(f"Warning: Failed to load cache: {e}")
+    return None
+
+
+def _save_cached_data(cache_path: Path, data: list[dict[str, Any]]) -> None:
+    """Save calibration data to cache."""
+    try:
+        cache_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(cache_path, "w") as f:
+            json.dump(data, f)
+        print(f"Saved calibration samples to cache: {cache_path}")
+    except Exception as e:
+        print(f"Warning: Failed to save cache: {e}")
+
+
def _extract_tokenizer_from_model(model: nn.Module) -> str:
    """Extract tokenizer name/path from model config.

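For context on the new caching helpers above, here is a minimal sketch of how they compose; the tokenizer path and the sample payload are illustrative placeholders, not values from this repository:

# Illustrative round-trip through the helpers introduced in this commit.
path = _get_cache_path("some-org/some-tokenizer", samples=64, max_seqlen=32768)
data = _load_cached_data(path)  # None on a cold cache
if data is None:
    data = [{"input": "..."}]  # placeholder; real entries come from RulerDatasetBuilder
    _save_cached_data(path, data)  # subsequent runs with the same parameters hit the cache
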
@@ -255,18 +306,31 @@ def calibrate_sparse_attention(
    calibration_data = None

    if calibrate_prefill or calibrate_decode:
-        builder = RulerDatasetBuilder(
-            samples=calib_config.samples,
-            max_seqlen=calib_config.max_seqlen,
-            tokenizer_name_or_path=tokenizer,
-            num_length_bins=calib_config.num_length_bins,
-            max_length_filter=int(calib_config.max_seqlen * 1.5),
+        # Try to load from cache first
+        cache_path = _get_cache_path(
+            tokenizer,
+            calib_config.samples,
+            calib_config.max_seqlen,
+            cache_dir=calib_config.cache_dir,
        )
-        calibration_data = builder.build_calibration_dataset()
-        print(f"Generated {len(calibration_data)} calibration samples")
+        calibration_data = _load_cached_data(cache_path)
+
+        # Generate if not cached
+        if calibration_data is None:
+            builder = RulerDatasetBuilder(
+                samples=calib_config.samples,
+                max_seqlen=calib_config.max_seqlen,
+                tokenizer_name_or_path=tokenizer,
+                num_length_bins=calib_config.num_length_bins,
+                max_length_filter=int(calib_config.max_seqlen * 1.5),
+            )
+            calibration_data = builder.build_calibration_dataset()
+            print(f"Generated {len(calibration_data)} calibration samples")
+
+            # Save to cache for future runs
+            _save_cached_data(cache_path, calibration_data)

    # Initialize results
-    threshold_scale_factor: dict[str, float] = {}
    calibration_results: dict[str, Any] = {}

    # Run prefill calibration if enabled
@@ -282,13 +346,11 @@ def calibrate_sparse_attention(
        )

        prefill_calibrator = DynamicThresholdCalibrator(
-            target_sparse_ratio=target_dict,
            threshold_trials=calib_config.threshold_trials,
        )
        prefill_result = prefill_calibrator.calibrate(model, prefill_forward_loop, phase="prefill")

-        if "scale_factor" in prefill_result:
-            threshold_scale_factor["prefill"] = prefill_result["scale_factor"]
+        if "k" in prefill_result and "p" in prefill_result:
            calibration_results["prefill"] = prefill_result
        else:
            warnings.warn("Prefill calibration did not produce valid results")
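The acceptance check now keys off the fitted parameters "k" and "p" instead of a single "scale_factor". Assuming a calibrator result shaped like the keys this commit reads later (the numbers below are made up for illustration), the contract looks like:

# Hypothetical result accepted by the new check (values are illustrative only).
prefill_result = {
    "k": 0.021,         # numerator of the fitted threshold model
    "p": 1.42,          # exponent applied to (1 - sparsity)
    "r_squared": 0.98,  # fit quality, printed when results are applied
}
assert "k" in prefill_result and "p" in prefill_result
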
@@ -306,38 +368,57 @@ def calibrate_sparse_attention(
        )

        decode_calibrator = DynamicThresholdCalibrator(
-            target_sparse_ratio=target_dict,
            threshold_trials=calib_config.threshold_trials,
        )
        decode_result = decode_calibrator.calibrate(model, decode_forward_loop, phase="decode")

-        if "scale_factor" in decode_result:
-            threshold_scale_factor["decode"] = decode_result["scale_factor"]
+        if "k" in decode_result and "p" in decode_result:
            calibration_results["decode"] = decode_result
        else:
            warnings.warn("Decode calibration did not produce valid results")

    # Check if any calibration succeeded
-    if not threshold_scale_factor:
+    if not calibration_results:
        warnings.warn("No calibration produced valid results")
        return {}

-    # Apply combined threshold_scale_factor dict to all modules
+    # Extract k and p for each phase
+    calibration_params: dict[str, dict[str, float]] = {}
+    for phase in ["prefill", "decode"]:
+        if phase in calibration_results:
+            result = calibration_results[phase]
+            calibration_params[phase] = {
+                "k": result["k"],
+                "p": result["p"],
+            }
+
+    # Apply calibration params to all modules
    print("\n" + "=" * 60)
    print("APPLYING CALIBRATION RESULTS")
    print("=" * 60)
-    print(f"Applying threshold_scale_factor to {len(sparse_modules)} modules:")
-    for phase, scale_factor in threshold_scale_factor.items():
-        print(f"  {phase}: {scale_factor:.6f}")
+    print(f"Applying calibration to {len(sparse_modules)} modules:")
+    for phase, params in calibration_params.items():
+        result = calibration_results[phase]
+        print(f"  {phase}:")
+        print(f"    Model: scale_factor = {params['k']:.4f} / (1 - sparsity)^{params['p']:.4f}")
+        print(f"    R-squared: {result['r_squared']:.6f}")

    for module_name, module in sparse_modules:
-        module._sparse_method_instance.threshold_scale_factor = threshold_scale_factor
+        module._sparse_method_instance.calibration_params = calibration_params
+        module._sparse_method_instance.target_sparse_ratio = target_dict

    # Print final summary
    print("\nCalibration complete!")
+    print(
+        f"Target sparsity: prefill={target_dict.get('prefill', 0):.0%}, "
+        f"decode={target_dict.get('decode', 0):.0%}"
+    )
+    print("\nTo change target sparsity at inference time, update:")
+    print("  module._sparse_method_instance.target_sparse_ratio = {'prefill': X, 'decode': Y}")
    print_sparse_attention_summary(model)

    return {
-        "threshold_scale_factor": threshold_scale_factor,
+        "calibration_params": calibration_params,
+        "target_sparse_ratio": target_dict,
        "calibration_results": calibration_results,
    }
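
As a rough illustration of the model printed under "APPLYING CALIBRATION RESULTS", the effective threshold scale follows scale_factor = k / (1 - sparsity)^p, so changing the target sparsity only changes this multiplier; the k and p values below are invented for the sketch, not real calibration output:

# Worked example of the fitted model, with made-up calibration parameters.
k, p = 0.02, 1.5
for sparsity in (0.5, 0.8, 0.9):
    scale_factor = k / (1 - sparsity) ** p
    print(f"target sparsity {sparsity:.0%} -> scale_factor {scale_factor:.4f}")
# This is why target_sparse_ratio can be updated per phase at inference time
# without re-running calibration: k and p stay fixed, only the target changes.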