diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index 277d947..852be93 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -143,6 +143,11 @@ class GraniteSwitchPreTrainedModel(GraniteMoeHybridPreTrainedModel): config_class = GraniteSwitchConfig base_model_prefix = "model" _no_split_modules = ["GraniteSwitchAttentionDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [ + r"model\.adapter_token_ids", + r"model\.token_to_group_mask", + r"model\.adapter_hiding_matrix", + ] _is_stateful = True @@ -167,8 +172,9 @@ def __init__(self, config: GraniteSwitchConfig): # --- Control token buffers --- # All values come from config (serialized in config.json). - # Stored as buffers (not nn.Parameter) so they follow .to(device) - # without appearing as trainable parameters. + # Stored as non-persistent buffers so they follow .to(device) + # without appearing as trainable parameters or in state_dict() + # (avoids accelerate device_map placement errors on multi-GPU). # # adapter_token_ids: Hidden-flavor control tokens, one per adapter. # The switch layer detects these in the input sequence to determine @@ -180,12 +186,14 @@ def __init__(self, config: GraniteSwitchConfig): self.register_buffer( "adapter_token_ids", torch.tensor(token_ids, dtype=torch.long), + persistent=False, ) else: # Build script hasn't populated yet — zeros placeholder self.register_buffer( "adapter_token_ids", torch.zeros(config.num_adapters, dtype=torch.long), + persistent=False, ) # --- Hiding group buffers --- @@ -208,7 +216,7 @@ def __init__(self, config: GraniteSwitchConfig): for g, tids in group_token_ids.items(): for tid in tids: token_to_group_mask[tid, g] = True - self.register_buffer("token_to_group_mask", token_to_group_mask) + self.register_buffer("token_to_group_mask", token_to_group_mask, persistent=False) # adapter_hiding_matrix: [num_adapter_slots, num_groups] boolean. # Index 0 = base, 1+ = adapters. True if adapter hides group g. @@ -216,6 +224,7 @@ def __init__(self, config: GraniteSwitchConfig): self.register_buffer( "adapter_hiding_matrix", torch.tensor(policy_matrix, dtype=torch.bool), + persistent=False, ) else: self.token_to_group_mask = None @@ -257,6 +266,34 @@ def __init__(self, config: GraniteSwitchConfig): # Initialize weights self.post_init() + def _rebuild_hiding_buffers(self, device: torch.device): + """Rebuild hiding group buffers from config on the given device. + + Called on first forward when accelerate's init_empty_weights() has + zeroed out the non-persistent buffers during from_pretrained. + """ + config = self.config + num_groups = config.num_hiding_groups + if num_groups > 0: + group_token_ids = config.get_hiding_group_token_ids() + all_known_ids = [tid for tids in group_token_ids.values() for tid in tids] + if config.adapter_token_ids: + all_known_ids.extend(config.adapter_token_ids) + max_tid = max(all_known_ids) if all_known_ids else -1 + table_size = max(config.vocab_size, max_tid + 1) + token_to_group_mask = torch.zeros( + table_size, num_groups, dtype=torch.bool, device=device + ) + for g, tids in group_token_ids.items(): + for tid in tids: + token_to_group_mask[tid, g] = True + self.token_to_group_mask = token_to_group_mask + + policy_matrix = config.get_adapter_hiding_policy_matrix() + self.adapter_hiding_matrix = torch.tensor( + policy_matrix, dtype=torch.bool, device=device + ) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -318,6 +355,21 @@ def forward( # Compute adapter_indices using switch (BEFORE RoPE for position correction) hidden_count = None if self.switch is not None: + # Non-persistent buffers are zeroed by accelerate's init_empty_weights() + # during from_pretrained with device_map. Rebuild from config on first forward. + device = input_ids.device if input_ids is not None else inputs_embeds.device + if self.adapter_token_ids.sum() == 0 and self.config.adapter_token_ids: + self.adapter_token_ids = torch.tensor( + self.config.adapter_token_ids, dtype=torch.long, device=device + ) + self._rebuild_hiding_buffers(device) + elif self.adapter_token_ids.device != device: + self.adapter_token_ids = self.adapter_token_ids.to(device) + if self.token_to_group_mask is not None: + self.token_to_group_mask = self.token_to_group_mask.to(device) + if self.adapter_hiding_matrix is not None: + self.adapter_hiding_matrix = self.adapter_hiding_matrix.to(device) + adapter_indices = self.switch( input_ids=input_ids, adapter_token_ids=self.adapter_token_ids,