Skip to content

Commit 0a1ca5d

Browse files
Fix unified_export_megatron for transformers 5.6 (#1335)
### What does this PR do? Type of change: Bug fix Broaden the exception handler around `AutoTokenizer.from_pretrained` in `GPTModelExporter.save_pretrained` to also catch `ValueError` and `ImportError`. In `transformers` 4.x, attempting to load a tokenizer from a directory that contains only `config.json` (no `tokenizer.json` / `tokenizer.model` / `tokenizer_config.json`) raised `OSError`, which was already handled. In `transformers` 5.x the resolution path now reaches `PreTrainedTokenizerFast.__init__` and raises a `ValueError` ("Couldn't instantiate the backend tokenizer from one of: ...") when none of the three backend sources are available. This caused `export_mcore_gpt_to_hf` to hard-fail for checkpoint directories that don't carry tokenizer files — including `tests/gpu_megatron/torch/export/test_unified_export_megatron.py`, which writes only a minimal `config.json`. The broadened `except (OSError, TypeError, ValueError, ImportError)` mirrors the pattern already used just below for `AutoProcessor.from_pretrained` and keeps tokenizer export best-effort, as originally intended. ### Usage No API change. Existing call sites continue to work: ```python from modelopt.torch.export import export_mcore_gpt_to_hf export_mcore_gpt_to_hf( model, pretrained_model_name_or_path, # may or may not contain tokenizer files dtype=torch.bfloat16, export_dir=export_dir, ) ``` ### Testing - Reproduced the failure on `transformers==5.6` with: ``` pytest tests/gpu_megatron/torch/export/test_unified_export_megatron.py::test_unified_export_megatron[llama-LlamaForCausalLM-medusa-None-None] ``` which failed with `ValueError: Couldn't instantiate the backend tokenizer...` raised from `unified_export_megatron.py:299`. - After the fix, the same parametrization passes, and the other `llama` / `nemotron` / `eagle` / `medusa` parametrizations in the same test file remain green. - No behavioral change on `transformers` 4.x: the `OSError` path is still caught. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: N/A <!-- Existing test `test_unified_export_megatron` already covers this path; the fix makes it pass on transformers 5.x. --> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A ### Additional Information Triggered by the upgrade to `transformers` 5.6 (used in `nvcr.io/nvidia/nemo:26.04`, which is the container for the `gpu_megatron` nox session). The error message from `transformers` — *"You need to have sentencepiece or tiktoken installed..."* — is a misleading generic fallback; `sentencepiece` and `tiktoken` are already pulled in via the `[hf]` extras, and the real cause is the missing tokenizer files in the export source directory. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Improved error handling in the export process to gracefully manage additional exception types. * **Tests** * Enhanced test validation for Megatron export by using actual tokenizer artifacts, ensuring model vocabulary size matches test tokenizer configuration. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent fda0899 commit 0a1ca5d

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

modelopt/torch/export/unified_export_megatron.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,7 @@ def save_pretrained(
301301
trust_remote_code=self.trust_remote_code,
302302
)
303303
tokenizer.save_pretrained(save_directory)
304-
except OSError:
305-
pass
306-
except TypeError:
304+
except (OSError, TypeError, ValueError, ImportError):
307305
pass
308306
try:
309307
# Load and save preprocessor config from the original model

tests/gpu_megatron/torch/export/test_unified_export_megatron.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import transformers
2424
from _test_utils.torch.megatron.models import get_mcore_gpt_model
2525
from _test_utils.torch.megatron.utils import get_forward
26-
from _test_utils.torch.transformers_models import create_tiny_llama_dir
26+
from _test_utils.torch.transformers_models import create_tiny_llama_dir, get_tiny_tokenizer
2727
from safetensors import safe_open
2828
from safetensors.torch import save_file
2929

@@ -74,13 +74,16 @@ def _verify_model_quant_config(
7474
def _test_unified_export_megatron(
7575
tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg, rank, size
7676
):
77+
tokenizer = get_tiny_tokenizer()
78+
tokenizer.save_pretrained(tmp_path)
79+
7780
num_layers = 2
7881
hidden_size = 64
7982
num_attention_heads = 8
8083
num_query_groups = size
8184
ffn_hidden_size = 128
8285
max_sequence_length = 32
83-
vocab_size = 64
86+
vocab_size = tokenizer.vocab_size
8487

8588
arch = "NemotronForCausalLM" if model_type == "nemotron" else "LlamaForCausalLM"
8689
activation_func = "squared_relu" if model_type == "nemotron" else "swiglu"

0 commit comments

Comments
 (0)