From 105ebf08663ba8cd5a6d5c8fa085a022f042815b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 12 Jun 2026 15:09:25 +0100 Subject: [PATCH 1/2] [Refactor] Re-export History from tensordict.llm History, ContentBase, add_chat_template and the base chat templates moved to tensordict.llm (pytorch/tensordict#1724), which is now their canonical home. torchrl/data/llm/history.py becomes a backward-compatibility shim re-exporting the public symbols and the private template registries, so all existing torchrl.data.llm.history import paths keep working and class identity is shared with tensordict.llm. No warning is emitted on import for now, as the path is used widely internally; new code should import from tensordict.llm directly. Co-Authored-By: Claude Fable 5 --- torchrl/data/llm/history.py | 1511 +---------------------------------- 1 file changed, 18 insertions(+), 1493 deletions(-) diff --git a/torchrl/data/llm/history.py b/torchrl/data/llm/history.py index 4fa76eafa03..3c4bd316948 100644 --- a/torchrl/data/llm/history.py +++ b/torchrl/data/llm/history.py @@ -2,1501 +2,26 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import dataclasses +"""Backward-compatibility re-exports for conversation containers. -import re -from typing import Literal, TYPE_CHECKING +:class:`~tensordict.llm.History`, :class:`~tensordict.llm.ContentBase` and +:func:`~tensordict.llm.add_chat_template` now live in ``tensordict.llm``, +which is their canonical home. This module re-exports them so that existing +``torchrl.data.llm.history`` import paths keep working. -import torch +New code should import from :mod:`tensordict.llm` directly. +""" +from __future__ import annotations -from tensordict import ( - lazy_stack, - LazyStackedTensorDict, - list_to_stack, - TensorClass, - TensorDict, +from tensordict.llm.history import ( # noqa: F401 + _assistant_content_spans, + _CHAT_TEMPLATES, + _CUSTOM_INVERSE_PARSERS, + _CUSTOM_MODEL_FAMILY_KEYWORDS, + _fallback_assistant_tokens_mask, + add_chat_template, + ContentBase, + History, ) -from tensordict.utils import _maybe_correct_neg_dim -from torchrl._utils import logger as torchrl_logger - -if TYPE_CHECKING: - import transformers - - -# Global storage for custom templates and their metadata -_CHAT_TEMPLATES = { - "chatml_format": """{%- for message in messages %} -{%- if message['role'] == 'assistant' -%} -{{ '<|im_start|>' + message['role'] + '\n' }}{% generation %}{{ message['content'] }}{% endgeneration %}{{ '<|im_end|>\n' }} -{%- else -%} -{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }} -{%- endif %} -{%- endfor %} -{%- if add_generation_prompt -%} -{{ '<|im_start|>assistant\n' }} -{%- endif %} -""", - "qwen": """ -{%- if tools %} - {{- '<|im_start|>system\\n' }} - {%- if messages[0]['role'] == 'system' %} - {{- messages[0]['content'] }} - {%- else %} - {{- 'You are a helpful assistant.' }} - {%- endif %} - {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n" }} - {%- for tool in tools %} - {{- "\\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n" }} -{%- else %} - {%- if messages[0]['role'] == 'system' %} - {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }} - {%- else %} - {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }} - {%- endif %} -{%- endif %} -{%- for message in messages %} - {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} - {%- elif (message.role == "assistant" and not message.tool_calls) %} - {{- '<|im_start|>' + message.role + '\\n' }}{% generation %}{{- message.content }}{% endgeneration %}{{- '<|im_end|>' + '\\n' }} - {%- elif message.role == "assistant" %} - {{- '<|im_start|>' + message.role }} - {%- if message.content %} - {{- '\\n' }}{% generation %}{{- message.content }}{% endgeneration %} - {%- endif %} - {%- for tool_call in message.tool_calls %} - {% generation %} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '\\n\\n{\\\"name\\\": \\\"' }} - {{- tool_call.name }} - {{- '\\\", \\\"arguments\\\": ' }} - {{- tool_call.arguments | tojson }} - {{- '}\\n' }} - {%- endgeneration %} - {%- endfor %} - {{- '<|im_end|>\\n' }} - {%- elif message.role == "tool" %} - {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} - {{- '<|im_start|>tool' }} - {%- endif %} - {{- '\\n\\n' }} - {%- if message.tool_responses %} - {{- message.tool_responses }} - {%- else %} - {{- message.content }} - {%- endif %} - {{- '\\n' }} - {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} - {{- '<|im_end|>\\n' }} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\\n' }} -{%- endif %} -""", - "dialogpt": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'user' %}{{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' ' }}{% endif %}""", - "falcon": """{% for message in messages %}{% if message['role'] == 'assistant' %}{{ 'Assistant:' }}{% generation %}{{ ' ' + message['content'] }}{% endgeneration %}\n\n{% elif message['role'] == 'user' %}{{ 'User: ' + message['content'] }}\n\n{% elif message['role'] == 'system' %}{{ message['content'] }}\n\n{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}""", - "deepseek": """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant:' }}{% generation %}{{ ' ' + message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}""", - "llama": """{{- bos_token }} -{%- if messages[0]['role'] == 'system' %} - {%- set system_message = messages[0]['content']|trim %} - {%- set messages = messages[1:] %} -{%- else %} - {%- set system_message = "" %} -{%- endif %} -{%- if system_message %} - {{- "<|header_start|>system<|header_end|>\n\n" }} - {{- system_message }} - {{- "<|eot|>" }} -{%- endif %} -{%- for message in messages %} - {%- if message['role'] == 'assistant' %} - {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}{% generation %} - {%- if message['content'] is string %} - {{- message['content'] }} - {%- else %} - {%- for content in message['content'] %} - {%- if content['type'] == 'text' %} - {{- content['text'] | trim }} - {%- endif %} - {%- endfor %} - {%- endif %} - {%- endgeneration %}{{- "<|eot|>" }} - {%- else %} - {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} - {%- if message['content'] is string %} - {{- message['content'] }} - {%- else %} - {%- for content in message['content'] %} - {%- if content['type'] == 'text' %} - {{- content['text'] | trim }} - {%- endif %} - {%- endfor %} - {%- endif %} - {{- "<|eot|>" }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|header_start|>assistant<|header_end|>\n\n' }} -{%- endif %}""", -} - -# Global storage for custom template metadata -_CUSTOM_INVERSE_PARSERS = {} -_CUSTOM_MODEL_FAMILY_KEYWORDS = {} - - -def _assistant_content_spans( - rendered: str, conversation: list[dict] -) -> list[tuple[int, int]]: - spans = [] - cursor = 0 - for message in conversation: - if message.get("role") != "assistant": - continue - content = message.get("content") - if not isinstance(content, str) or not content: - continue - start = rendered.find(content, cursor) - if start < 0: - start = rendered.find(content) - if start < 0: - continue - end = start + len(content) - spans.append((start, end)) - cursor = end - return spans - - -def _fallback_assistant_tokens_mask( - *, - tokenizer, - rendered: str, - conversation: list[dict], - input_ids: torch.Tensor, - current_mask: torch.Tensor, -) -> torch.Tensor: - if current_mask.any(): - return current_mask - spans = _assistant_content_spans(rendered, conversation) - if not spans: - return current_mask - - try: - encoded = tokenizer( - rendered, - add_special_tokens=False, - return_offsets_mapping=True, - ) - offsets = encoded.get("offset_mapping", None) - except NotImplementedError: - offsets = None - if input_ids.ndim == 2: - if input_ids.shape[0] != 1: - target_mask = None - fallback_mask = None - elif offsets is None or len(offsets) != input_ids.shape[-1]: - target_mask = current_mask[0] - fallback_mask = None - else: - target_mask = current_mask[0] - fallback_mask = torch.zeros_like(target_mask) - elif input_ids.ndim == 1: - if offsets is None or len(offsets) != input_ids.shape[-1]: - target_mask = current_mask - fallback_mask = None - else: - target_mask = current_mask - fallback_mask = torch.zeros_like(target_mask) - else: - return current_mask - - if fallback_mask is not None: - for idx, (token_start, token_end) in enumerate(offsets): - if token_start == token_end: - continue - for span_start, span_end in spans: - if token_start < span_end and token_end > span_start: - fallback_mask[idx] = 1 - break - - if fallback_mask is None or not fallback_mask.any(): - if target_mask is None: - return current_mask - fallback_mask = torch.zeros_like(target_mask) - for span_start, span_end in spans: - prefix_ids = tokenizer(rendered[:span_start], add_special_tokens=False).get( - "input_ids" - ) - prefix_and_content_ids = tokenizer( - rendered[:span_end], add_special_tokens=False - ).get("input_ids") - start_idx = len(prefix_ids) - end_idx = len(prefix_and_content_ids) - if start_idx < end_idx <= fallback_mask.shape[-1]: - fallback_mask[start_idx:end_idx] = 1 - - if input_ids.ndim == 2: - current_mask = current_mask.clone() - current_mask[0] = fallback_mask - return current_mask - return fallback_mask - - -def add_chat_template( - template_name: str, - template: str, - inverse_parser: callable | None = None, - model_family_keywords: list[str] | None = None, -) -> None: - r"""Add a custom chat template to the global template dictionary. - - This function allows you to add custom chat templates for new model families - that support assistant token masking via the `{% generation %}` keyword. - - Args: - template_name (str): The name of the template (e.g., "llama", "mistral"). - This name will be used in the `chat_template_name` parameter of - `History.apply_chat_template()` and `History.from_text()`. - template (str): The Jinja2 template string. Must include `{% generation %}` - blocks around assistant message content to enable token masking. - inverse_parser (callable, optional): A function that parses formatted text back - into a History object. Should have signature `(text: str) -> History`. - If None, a basic parser will be used. - model_family_keywords (list[str], optional): Keywords to detect this model family - in the auto-detection logic. For example, ["llama", "meta-llama"] for Llama models. - If provided, the template will be automatically selected for models containing - these keywords in their name. - - Example: - >>> from torchrl.data.llm.chat import add_chat_template, History - >>> from transformers import AutoTokenizer - >>> - >>> # Add a custom template for Llama models - >>> llama_template = ''' - ... {% for message in messages %} - ... {%- if message['role'] == 'user' %} - ... {{ '[INST] ' + message['content'] + ' [/INST]' }} - ... {%- elif message['role'] == 'assistant' %} - ... {% generation %}{{ message['content'] + '' }}{% endgeneration %} - ... {%- endif %} - ... {% endfor %} - ... {%- if add_generation_prompt %} - ... {% generation %}{{ ' ' }}{% endgeneration %} - ... {%- endif %} - ... ''' - >>> - >>> def parse_llama_text(text: str) -> History: - ... # Custom parser for Llama format - ... import re - ... pattern = r'\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)' - ... matches = re.findall(pattern, text, re.DOTALL) - ... messages = [] - ... for user_content, assistant_content in matches: - ... messages.append(History(role="user", content=user_content.strip())) - ... messages.append(History(role="assistant", content=assistant_content.strip())) - ... return lazy_stack(messages) - >>> - >>> # Add the template with auto-detection - >>> add_chat_template( - ... template_name="llama", - ... template=llama_template, - ... inverse_parser=parse_llama_text, - ... model_family_keywords=["llama", "meta-llama"] - ... ) - >>> - >>> # Now you can use it with auto-detection - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"} - ... ]]) - >>> - >>> # Auto-detection will use the llama template - >>> result = history.apply_chat_template( - ... tokenizer=tokenizer, - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - >>> - >>> # Or use it explicitly - >>> result = history.apply_chat_template( - ... tokenizer=tokenizer, - ... chat_template_name="llama", - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - - .. note: - - The template must include `{% generation %}` blocks around assistant message - content to enable assistant token masking. - - The inverse parser should handle the specific format of your template. - - Model family keywords are case-insensitive and matched against the tokenizer's - `name_or_path` attribute. - - Templates are stored globally and persist for the duration of the Python session. - """ - global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS # noqa: F824 - - # Validate template contains generation blocks - if "{% generation %}" not in template: - raise ValueError( - f"Template '{template_name}' must include '{{% generation %}}' blocks " - "around assistant message content to enable token masking." - ) - - # Add template to dictionary - _CHAT_TEMPLATES[template_name] = template - - # Store inverse parser if provided - if inverse_parser is not None: - _CUSTOM_INVERSE_PARSERS[template_name] = inverse_parser - - # Store model family keywords if provided - if model_family_keywords is not None: - _CUSTOM_MODEL_FAMILY_KEYWORDS[template_name] = model_family_keywords - - torchrl_logger.info( - f"Added custom chat template '{template_name}' with assistant token masking support" - ) - - -# We need the 'shadow' flag to avoid having tensordict complaining about 'type'/'size' etc. fields -class ContentBase(TensorClass["nocast", "shadow"]): - """Base class for all message content types. - - Attributes: - type (str): The type of the content. - text (str, optional): The text content. - url (str, optional): The URL content. - data (str, optional): The data content. - mime_type (str, optional): The MIME type of the content. - name (str, optional): The name of the content. - size (int, optional): The size of the content. - function_name (str, optional): The name of the function. - function_args (dict, optional): The arguments of the function. - - Examples: - >>> from tensordict import lazy_stack - >>> content1 = ContentBase(type="text", text="Hello, world!") - >>> print(content1) - ContentBase( - text=NonTensorData(data=Hello, world!, batch_size=torch.Size([]), device=None), - type=NonTensorData(data=text, batch_size=torch.Size([]), device=None), - url=None, - data=None, - mime_type=None, - name=None, - size=None, - function_name=None, - function_args=None, - batch_size=torch.Size([]), - device=None, - is_shared=False) - >>> content2 = ContentBase(type="image", url="https://example.com/image.jpg") - >>> print(content2) - ContentBase( - type=NonTensorData(data=image, batch_size=torch.Size([]), device=None), - url=NonTensorData(data=https://example.com/image.jpg, batch_size=torch.Size([]), device=None), - text=None, - data=None, - mime_type=None, - name=None, - size=None, - function_name=None, - function_args=None, - batch_size=torch.Size([]), - device=None, - is_shared=False) - >>> content = lazy_stack([content1, content2]) - >>> print(content) - ContentBase( - type=NonTensorStack( - ['text', 'image'], - batch_size=torch.Size([2]), - device=None), - url=None, - data=None, - mime_type=None, - name=None, - size=None, - function_name=None, - function_args=None, - text=None, - batch_size=torch.Size([2]), - device=None, - is_shared=False) - >>> # A content is typically used in a History object. Usually, its batch dimension is - >>> # one dimension greater than the History object. - >>> history = History(role="user", content=content) - - """ - - type: Literal[ - "text", "image", "audio", "video", "file", "function_call" - ] # Required: "text", "image", "audio", "video", "file", "function_call" - - # Text content - text: str | None = None - - # Media/file content (either URL or data) - url: str | None = None # HTTP URL to content - data: str | None = None # Base64 encoded content - - # Metadata - mime_type: str | None = None # "image/jpeg", "audio/mp3", "application/pdf" - name: str | None = None # Original filename or description - size: int | None = None # File size in bytes - - # Function calling (for AI agents) - function_name: str | None = None - function_args: dict | None = None - - -class History(TensorClass["nocast"]): - """A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models. - - The `History` class provides a centralized API for managing conversational data, offering several advantages over - traditional list-based approaches: - - - Centralized API for conversion to and from string formats, facilitating seamless integration with language models. - - Efficient methods to append, extend, and reshape history elements, enabling dynamic construction of conversation - trajectories, especially useful in reinforcement learning environments. - - Interoperability with the `transformers` API, allowing for easy tokenization and preparation of input data. - - **Assistant token masking support** across multiple model families for reinforcement learning applications. - - **Recent Changes:** - - **ChatHistory Integration**: History objects are now used within :class:`~torchrl.modules.llm.policies.ChatHistory` - containers for structured conversation management in LLM environments. - - **Modular Wrapper Support**: Both vLLMWrapper and TransformersWrapper now use History objects when `input_mode="history"` - is specified, providing consistent conversation state management. - - **Environment Integration**: ChatEnv and related environments use History objects for state management and conversation tracking. - - .. note:: The `""` role is used to indicate that the element is a placeholder, - for example when the tool call was not executed but a stack requires a certain number of elements - per batch to have congruent shapes. The :meth:`~torchrl.data.llm.chat.History.apply_chat_template` - method will remove the `` role from the history. - - **Assistant Token Masking Support:** - - The class supports assistant token masking across multiple model families, allowing you to identify which tokens - in a conversation were generated by the assistant. This is crucial for reinforcement learning applications. - - **Supported Model Families:** - - - **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support - - **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format - - **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format - - **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format - - **Other models** (OPT, GPT, MPT, BLOOM, Pythia, Phi, etc.): Default `chatml_format` template - - **Example with Assistant Token Masking:** - - .. code-block:: python - - >>> from torchrl.data.llm.chat import History - >>> from torchrl.modules.llm.policies import ChatHistory - >>> from transformers import AutoTokenizer - >>> - >>> # Create a conversation history - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"}, - ... {"role": "user", "content": "How are you?"}, - ... {"role": "assistant", "content": "I'm doing well, thanks!"} - ... ]]) - >>> - >>> # Create ChatHistory container for LLM wrapper - >>> chat_history = ChatHistory(prompt=history) - >>> - >>> # Load any supported tokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") - >>> - >>> # Apply chat template with assistant token masking - >>> result = history.apply_chat_template( - ... tokenizer=tokenizer, - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - >>> - >>> # The result contains an assistant_masks tensor - >>> assistant_masks = result["assistant_masks"] - >>> print(f"Assistant tokens: {assistant_masks.sum().item()}") - - **Integration with LLM Wrappers:** - - History objects work seamlessly with the new modular wrapper design: - - .. code-block:: python - - >>> from torchrl.modules.llm import TransformersWrapper - >>> from torchrl.modules.llm.policies import ChatHistory - >>> - >>> # Create wrapper with history input mode - >>> wrapper = TransformersWrapper( - ... model, tokenizer=tokenizer, - ... input_mode="history", - ... generate=True, - ... return_log_probs=True - ... ) - >>> - >>> # Use History with ChatHistory container - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"} - ... ]]) - >>> chat_history = ChatHistory(prompt=history) - >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,))) - >>> print(result["history"].response) # New response from LLM - - Attributes: - role (str): The role of the message sender. - content (str): The content of the message. - is_complete (bool): Whether the message was properly terminated with an end token. Defaults to `True`. - tool_calls (list[dict] | None): Optional list of tool calls in the message. - tool_responses (list[str] | None): Optional list of tool responses. - - Methods: - apply_chat_template: converts the `History` object to str / tokens. - append: append one element to the list of items along a given dimension. - extend: extend the list of items along a given dimension. - - Examples: - >>> # With tensordict < 0.10, we need to tell the lib that lists constitute batches - >>> import tensordict - >>> tensordict.set_list_to_stack(True).set() - >>> import transformers - >>> history0 = History( - ... role='system', - ... content='''CONTENT - ... This is the setup''', - ... ) - >>> history1 = History( - ... role='user', - ... content='''CONTENT - ... This is the first user prompt''', - ... ) - >>> history2 = History( - ... role='assistant', - ... content='''CONTENT - ... This is the second prompt, the first for the assistant.''', - ... ) - >>> history = torch.stack([history0, history1, history2]) - >>> assert history.role == ['system', 'user', 'assistant'] - >>> tokenizer = transformers.AutoTokenizer.from_pretrained("GPT2") - >>> # Apply a template to pass the history to an LLM. Note that the output has - >>> # an additional prompt to elict an answer from the LLM thanks to the 'add_generation_prompt' argument. - >>> parsed_string = history.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True) - >>> parsed_string - <|im_start|>system - CONTENT - This is the setup<|im_end|> - - <|im_start|>user - CONTENT - This is the first user prompt<|im_end|> - - <|im_start|>assistant - CONTENT - This is the second prompt, the first for the assistant.<|im_end|> - - <|im_start|>assistant - - .. seealso:: - :class:`~torchrl.modules.llm.policies.ChatHistory`: Container for managing conversation data in LLM environments. - :class:`~torchrl.modules.llm.policies.Text`: Container for text data. - :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data. - """ - - role: str | list[str] | list[list[str]] - content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[ - list[ContentBase] - ] - is_complete: bool = True - tool_calls: list[dict] | None = None - tool_responses: list[str] | None = None - - def __post_init__(self): - if not list_to_stack(): - raise RuntimeError( - "Please set the list_to_stack to True using tensordict.set_list_to_stack(True).set() at the beginning of your script, " - "or the LIST_TO_STACK=1 environment variable." - ) - - def apply_chat_template( - self, - *, - tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa - add_generation_prompt: bool = True, - chat_template: str | None = None, - chat_template_name: str | None = None, - continue_final_message: bool = False, - tokenize: bool | None = None, - padding: bool | str = False, - truncation: bool | str = False, - return_tensors: str | None = None, - return_dict: bool | None = None, - return_assistant_tokens_mask: bool = False, - **kwargs, - ) -> str | list[str] | TensorDict: - """Applies a chat template to the history. - - Keyword Args: - tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use. - add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`. - chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template. - chat_template_name (str, optional): The name of the chat template to use. - Prevalent over `tokenizer.chat_template`. If `None`, the method will automatically detect the model family and use the appropriate template. - Defaults to `None`. - continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`. - tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`. - padding (bool | str, optional): The padding strategy to use. Defaults to `False`. - truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`. - return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt". - return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`. - return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens. - If `True`, the mask will be written to the `assistant_masks` key. - For tokens generated by the assistant, the mask will contain `1`. - For user and system tokens, the mask will contain `0`. - This functionality is only available for chat templates that support it via the `{% generation %}` keyword. - Defaults to `False`. - - .. note:: Assistant token masking is supported across multiple model families: - - **Qwen family**: Uses custom template with full tool calling support - - **DialoGPT family**: Uses custom template for conversation format - - **Falcon family**: Uses custom template for instruction format - - **DeepSeek family**: Uses custom template with native format - - **Other models**: Use the default `chatml_format` template - - The method automatically detects the model family and selects the appropriate template. - - **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method. - - Returns: - The formatted history. - """ - if chat_template is None: - if chat_template_name is not None: - chat_template = _CHAT_TEMPLATES[chat_template_name] - chat_template_name = None - elif tokenizer is None: - raise RuntimeError( - "You must specify a tokenizer to use when chat_template is not specified." - ) - else: - # Auto-detect model family and use appropriate template - model_name = getattr(tokenizer, "name_or_path", "").lower() - - # First check for custom model family keywords - custom_template_found = False - for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items(): - if any(keyword.lower() in model_name for keyword in keywords): - chat_template = _CHAT_TEMPLATES[template_name] - chat_template_name = None - custom_template_found = True - break - - if not custom_template_found: - # Fall back to built-in model family detection - if "qwen" in model_name: - # We prefer our implementation of the Qwen template, - # since it accounts for the assistant's masking. - chat_template = _CHAT_TEMPLATES["qwen"] - chat_template_name = None - elif "dialogpt" in model_name or "microsoft/dialo" in model_name: - # DialoGPT family - use our custom template - chat_template = _CHAT_TEMPLATES["dialogpt"] - chat_template_name = None - elif "falcon" in model_name or "tiiuae/falcon" in model_name: - # Falcon family - use our custom template - chat_template = _CHAT_TEMPLATES["falcon"] - chat_template_name = None - elif "deepseek" in model_name: - # DeepSeek family - use our custom template with generation keyword - chat_template = _CHAT_TEMPLATES["deepseek"] - chat_template_name = None - elif "llama" in model_name: - # Llama family - use our custom template - chat_template = _CHAT_TEMPLATES["llama"] - chat_template_name = None - else: - # For other models, check if their default template supports generation - if ( - hasattr(tokenizer, "chat_template") - and tokenizer.chat_template - and "{% generation %}" in tokenizer.chat_template - ): - # Use the model's own template if it supports generation - chat_template = tokenizer.chat_template - else: - # Use our default chatml_format template - chat_template = _CHAT_TEMPLATES["chatml_format"] - if chat_template is None: - chat_template = _CHAT_TEMPLATES["chatml_format"] - if tokenize is None: - if return_assistant_tokens_mask or return_tensors is not None: - tokenize = True - else: - tokenize = False - if tokenize: - if return_tensors is None: - return_tensors = "pt" - if return_dict is None and return_assistant_tokens_mask: - return_dict = True - elif return_dict is None: - return_dict = False - - if self.ndim > 1: - result = [ - self[i].apply_chat_template( - tokenizer=tokenizer, - add_generation_prompt=add_generation_prompt, - chat_template=chat_template, - chat_template_name=chat_template_name, - tokenize=tokenize, - padding=padding, - truncation=truncation, - return_tensors=return_tensors, - continue_final_message=continue_final_message, - return_dict=return_dict, - return_assistant_tokens_mask=return_assistant_tokens_mask, - **kwargs, - ) - for i in range(self.batch_size[0]) - ] - if return_dict: - return lazy_stack(result) - else: - return result - self_flat = self.view(-1) - # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts - self_flat = self_flat.tolist(tolist_first=True) - # Remove the "" role - self_flat = [item for item in self_flat if item["role"] != ""] - result = tokenizer.apply_chat_template( - conversation=self_flat, - add_generation_prompt=add_generation_prompt, - chat_template=chat_template, - tokenize=tokenize, - padding=padding, - truncation=truncation, - return_tensors=return_tensors, - continue_final_message=continue_final_message, - return_dict=return_dict, - return_assistant_tokens_mask=return_assistant_tokens_mask, - **kwargs, - ) - if ( - return_assistant_tokens_mask - and not isinstance(result, (torch.Tensor, list, str)) - and "assistant_masks" in result - and "input_ids" in result - ): - rendered = tokenizer.apply_chat_template( - conversation=self_flat, - add_generation_prompt=add_generation_prompt, - chat_template=chat_template, - tokenize=False, - continue_final_message=continue_final_message, - ) - if ( - isinstance(rendered, str) - and isinstance(result["assistant_masks"], torch.Tensor) - and isinstance(result["input_ids"], torch.Tensor) - ): - result["assistant_masks"] = _fallback_assistant_tokens_mask( - tokenizer=tokenizer, - rendered=rendered, - conversation=self_flat, - input_ids=result["input_ids"], - current_mask=result["assistant_masks"], - ) - if not isinstance(result, (torch.Tensor, list, str)): - result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1) - # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result - if self.batch_dims == 1: - if result.batch_size[0] != 1: - raise RuntimeError( - f"Expected a batch size of 1, got {result.batch_size[0]}." - ) - result = result.squeeze(0) - return result - - @classmethod - def from_text( - cls, - text: str | list[str], - chat_template_name: str | None = None, - # currently without effect - chat_template: str | None = None, - tokenizer: transformers.AutoTokenizer # noqa: F821 - | transformers.AutoProcessor # noqa: F821 - | None = None, - ) -> History: - r"""Inverts a chat template into a History object. - - Args: - text (str | list[str]): The chat template to invert. - chat_template_name (str, optional): The name of the chat template to use. - tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use. - - Returns: - History: The inverted History object. - - Examples: - >>> from torchrl.data.llm.history import History - >>> from transformers import AutoTokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct") - >>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris, the capital of Germany is Berlin.\n\n" - >>> history = History.from_text(text, tokenizer=tokenizer) - >>> print(history) - History( - content=NonTensorStack( - ['You are a helpful assistant.', 'Write a python s..., - batch_size=torch.Size([3]), - device=None), - is_complete=NonTensorStack( - [True, True, False], - batch_size=torch.Size([3]), - device=None), - role=NonTensorStack( - ['system', 'user', 'assistant'], - batch_size=torch.Size([3]), - device=None), - tool_calls=None, - tool_responses=None, - batch_size=torch.Size([3]), - device=None, - is_shared=False) - """ - if chat_template_name is None: - if chat_template is not None: - # TODO: find best match given template - pass - - model_name = getattr(tokenizer, "name_or_path", "").lower() - # First check for custom model family keywords - custom_template_found = False - for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items(): - if any(keyword.lower() in model_name for keyword in keywords): - chat_template_name = template_name - custom_template_found = True - break - - if not custom_template_found: - # Fall back to built-in model family detection - if "qwen" in model_name: - # We can automatically detect the template name from the tokenizer - # and use the precoded parser. - chat_template_name = "qwen" - elif "dialogpt" in model_name or "microsoft/dialo" in model_name: - chat_template_name = "dialogpt" - elif "falcon" in model_name or "tiiuae/falcon" in model_name: - chat_template_name = "falcon" - elif "deepseek" in model_name: - chat_template_name = "deepseek" - elif "llama" in model_name: - chat_template_name = "llama" - else: - chat_template_name = "chatml_format" - - # Get the appropriate inverse parser function - if chat_template_name in ("chatml_format",): - func = cls._inv_chatml - elif chat_template_name in ("qwen",): - func = cls._inv_qwen - elif chat_template_name in ("dialogpt",): - func = cls._inv_dialogpt - elif chat_template_name in ("falcon",): - func = cls._inv_falcon - elif chat_template_name in ("deepseek",): - func = cls._inv_deepseek - elif chat_template_name in ("llama",): - func = cls._inv_llama - elif chat_template_name in _CUSTOM_INVERSE_PARSERS: - # Use custom inverse parser - func = _CUSTOM_INVERSE_PARSERS[chat_template_name] - else: - raise NotImplementedError( - f"chat_template_name '{chat_template_name}' is not supported. " - "Supported templates: 'chatml_format', 'qwen', 'dialogpt', 'falcon', 'deepseek'. " - "Use add_chat_template() to add custom templates." - ) - if isinstance(text, list): - list_of_histories = [func(t) for t in text] - try: - return lazy_stack(list_of_histories) - except RuntimeError as e: - raise RuntimeError( - f"Failed to stack histories: {list_of_histories=}" - ) from e - return func(text) - - @classmethod - def _inv_chatml(cls, text: str) -> History: - """Inverts a chatml string into a History object. - - Args: - text (str): The chatml string to invert. - - Returns: - History: The inverted History object. - """ - import json - - torchrl_logger.debug(f"Inverting chatml:\n{text}") - # Find all complete blocks (ending with im_end or endoftext) - complete_pattern = r"<\|im_start\|>(.*?)\n(.*?)<\|(im_end|endoftext)\|>" - complete_matches = re.findall(complete_pattern, text, flags=re.DOTALL) - - # Find any incomplete block at the end - incomplete_pattern = r"<\|im_start\|>(.*?)\n(.*?)$" - incomplete_matches = [] - if complete_matches: - # Look for incomplete block after the last complete one - last_complete = complete_matches[-1] - last_complete_text = f"<|im_start|>{last_complete[0]}\n{last_complete[1]}<|{last_complete[2]}|>" - remaining_text = text[ - text.rindex(last_complete_text) + len(last_complete_text) : - ] - if remaining_text.strip(): - incomplete_match = re.search( - incomplete_pattern, remaining_text, flags=re.DOTALL - ) - if incomplete_match: - incomplete_matches = [ - (incomplete_match.group(1), incomplete_match.group(2), None) - ] - else: - # No complete blocks, check entire text for incomplete block - incomplete_match = re.search(incomplete_pattern, text, flags=re.DOTALL) - if incomplete_match: - incomplete_matches = [ - (incomplete_match.group(1), incomplete_match.group(2), None) - ] - - # Combine complete and incomplete matches - matches = complete_matches + incomplete_matches - - # Define tool patterns - same as Qwen for consistency - tool_call_pattern = re.compile(r"\n(.*?)\n", re.DOTALL) - tool_response_pattern = re.compile( - r"\n(.*?)\n", re.DOTALL - ) - - parsed_messages = [] - for match in matches: - role = match[0].strip() - content = match[1].strip() - is_complete = match[2] is not None # None indicates incomplete - - # Initialize message dict - message_dict = { - "role": role, - "content": content, - "is_complete": is_complete, - "tool_calls": None, - "tool_responses": None, - } - - # Find tool calls within the message - tool_calls = tool_call_pattern.findall(content) - if tool_calls: - tool_calls_list = [] - for tool_call in tool_calls: - try: - tool_call_dict = json.loads(tool_call) - tool_calls_list.append(tool_call_dict) - except json.JSONDecodeError: - continue - if tool_calls_list: - message_dict["tool_calls"] = tool_calls_list - - # Check for tool responses - tool_responses = tool_response_pattern.findall(content) - if tool_responses: - message_dict["tool_responses"] = tool_responses - - parsed_messages.append(cls(**message_dict)) - - if not parsed_messages: - raise RuntimeError( - f"Couldn't get a single item out of text {text}. A common cause " - f"if that special tokens should not be omitted, did you set include_stop_str_in_output/skip_special_tokens=False?" - ) - - return lazy_stack(parsed_messages) - - @classmethod - def _inv_qwen(cls, template): - import json - - # Define regex patterns for different parts of the template - message_pattern = re.compile( - r"<\|im_start\|>(.*?)(?:<\|(im_end|endoftext)\|>|$)", re.DOTALL - ) - tool_call_pattern = re.compile(r"\n(.*?)\n", re.DOTALL) - tool_response_pattern = re.compile( - r"\n(.*?)\n", re.DOTALL - ) - - # Find all messages and track if they end with a proper token - messages = [] - is_complete_list = [] - for match in message_pattern.finditer(template): - full_match = match.group(0) - messages.append(match.group(1)) - # Check if the message ends with a proper token - is_complete_list.append( - full_match.endswith("<|im_end|>") - or full_match.endswith("<|endoftext|>") - ) - - parsed_messages = [] - for message, is_complete in zip(messages, is_complete_list): - # Split the message into role and content - parts = message.split("\n", 1) - if len(parts) < 2: - continue - role, content = parts[0], parts[1] - - # Initialize message dict - message_dict = { - "role": role.strip(), - "content": content.strip(), - "is_complete": is_complete, - "tool_calls": None, - "tool_responses": None, - } - - # Find tool calls within the message - tool_calls = tool_call_pattern.findall(content) - if tool_calls: - tool_calls_list = [] - for tool_call in tool_calls: - try: - tool_call_dict = json.loads(tool_call) - tool_calls_list.append(tool_call_dict) - except json.JSONDecodeError: - continue - if tool_calls_list: - message_dict["tool_calls"] = tool_calls_list - - # Check for tool responses - tool_responses = tool_response_pattern.findall(content) - if tool_responses: - message_dict["tool_responses"] = tool_responses - - parsed_messages.append(cls(**message_dict)) - - if not parsed_messages: - raise RuntimeError( - f"Couldn't get a single item out of text {template}. A common cause " - f"if that special tokens should not be omitted, did you set include_stop_str_in_output/skip_special_tokens=False?" - ) - - return lazy_stack(parsed_messages) - - @classmethod - def _inv_dialogpt(cls, text: str) -> History: - """Inverts a DialogPT string into a History object. - - Args: - text (str): The DialogPT string to invert. - - Returns: - History: The inverted History object. - """ - torchrl_logger.debug(f"Inverting DialogPT:\n{text}") - - # DialogPT format is simple: alternating user/assistant messages - # Split by lines and parse - lines = text.strip().split("\n") - parsed_messages = [] - - for line in lines: - line = line.strip() - if not line: - continue - - # Determine role based on content - if line.startswith("Assistant:"): - role = "assistant" - content = line[len("Assistant:") :].strip() - elif line.startswith("User:"): - role = "user" - content = line[len("User:") :].strip() - else: - # Default to user if no prefix - role = "user" - content = line - - message_dict = { - "role": role, - "content": content, - "is_complete": True, # DialogPT doesn't have explicit end tokens - "tool_calls": None, - "tool_responses": None, - } - - parsed_messages.append(cls(**message_dict)) - - if not parsed_messages: - raise RuntimeError(f"Couldn't get a single item out of text {text}.") - - return lazy_stack(parsed_messages) - - @classmethod - def _inv_falcon(cls, text: str) -> History: - """Inverts a Falcon string into a History object. - - Args: - text (str): The Falcon string to invert. - - Returns: - History: The inverted History object. - """ - torchrl_logger.debug(f"Inverting Falcon:\n{text}") - - # Falcon format: "User: ... Assistant: ..." - # Split by "User:" and "Assistant:" prefixes - import re - - # Pattern to match User: and Assistant: messages - pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)" - matches = re.findall(pattern, text, re.DOTALL) - - parsed_messages = [] - for match in matches: - if len(match) != 2: - continue - prefix, content = match - content = content.strip() - if not content: - continue - - if prefix == "User:": - role = "user" - elif prefix == "Assistant:": - role = "assistant" - else: - continue - - message_dict = { - "role": role, - "content": content, - "is_complete": True, # Falcon doesn't have explicit end tokens - "tool_calls": None, - "tool_responses": None, - } - - parsed_messages.append(cls(**message_dict)) - - if not parsed_messages: - raise RuntimeError(f"Couldn't get a single item out of text {text}.") - - return lazy_stack(parsed_messages) - - @classmethod - def _inv_deepseek(cls, text: str) -> History: - """Inverts a DeepSeek string into a History object. - - Args: - text (str): The DeepSeek string to invert. - - Returns: - History: The inverted History object. - """ - torchrl_logger.debug(f"Inverting DeepSeek:\n{text}") - import re - - # Remove leading/trailing special tokens (e.g. - text = re.sub(r"^<[^>]+>", "", text) # Remove leading <...> - text = re.sub(r"<[^>]+>$", "", text) # Remove trailing <...> - # Remove any REDACTED_SPECIAL_TOKEN if present - text = re.sub(r"REDACTED_SPECIAL_TOKEN", "", text) - # Pattern to match User: and Assistant: messages - pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)" - matches = re.findall(pattern, text, re.DOTALL) - parsed_messages = [] - for match in matches: - if len(match) < 2: - continue - prefix, content = match[0], match[1] - content = content.strip() - if not content: - continue - if prefix == "User:": - role = "user" - elif prefix == "Assistant:": - role = "assistant" - else: - continue - message_dict = { - "role": role, - "content": content, - "is_complete": True, # DeepSeek doesn't have explicit end tokens - "tool_calls": None, - "tool_responses": None, - } - parsed_messages.append(cls(**message_dict)) - if not parsed_messages: - raise RuntimeError(f"Couldn't get a single item out of text {text}.") - return lazy_stack(parsed_messages) - - @classmethod - def _inv_llama(cls, text: str) -> History: - import re - - messages = [] - - # Remove BOS token if present - if text.startswith("<|begin_of_text|>"): - text = text[len("<|begin_of_text|>") :] - - # Pattern to match complete message blocks: <|header_start|>role<|header_end|>\n\ncontent<|eot|> - complete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)<\|eot\|>" - complete_matches = re.findall(complete_pattern, text, re.DOTALL) - - # Pattern to match incomplete message blocks: <|header_start|>role<|header_end|>\n\ncontent (without <|eot|>) - incomplete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)$" - - # Find any incomplete message at the end - incomplete_matches = [] - if complete_matches: - # Look for incomplete message after the last complete one - last_complete_end = text.rfind("<|eot|>") - if last_complete_end != -1: - remaining_text = text[last_complete_end + len("<|eot|>") :] - if remaining_text.strip(): - incomplete_match = re.search( - incomplete_pattern, remaining_text, re.DOTALL - ) - if incomplete_match: - incomplete_matches = [ - ( - incomplete_match.group(1), - incomplete_match.group(2), - False, - ) - ] - else: - # No complete messages, check entire text for incomplete message - incomplete_match = re.search(incomplete_pattern, text, re.DOTALL) - if incomplete_match: - incomplete_matches = [ - (incomplete_match.group(1), incomplete_match.group(2), False) - ] - - # Process complete messages - for role, content in complete_matches: - if content.strip(): - messages.append( - cls(role=role, content=content.strip(), is_complete=True) - ) - - # Process incomplete messages - for role, content, is_complete in incomplete_matches: - if content.strip(): - messages.append( - cls(role=role, content=content.strip(), is_complete=is_complete) - ) - - if not messages: - raise RuntimeError(f"Couldn't parse Llama format from text: {text}") - - from tensordict import lazy_stack - - return lazy_stack(messages) - - def append( - self, history: History, *, inplace: bool = True, dim: int = -1 - ) -> History: - """Appends a new history to the current one. - - Args: - history (History): The new history to append. - inplace (bool, optional): Whether to perform the operation in-place. Defaults to `True`. - dim (int, optional): The dimension to append along. Defaults to -1. - - Returns: - History: The appended History object. - """ - # TODO: we should remove the role from the history before appending / extending - # It works when keeping them, but it may lead to a lot of useless padding in between valid messages - if not self.batch_dims: - raise RuntimeError( - "Cannot append an element to a batchless History. Call unsqueeze(dim=0) first on self." - ) - if self.batch_dims != history.batch_dims + 1: - raise RuntimeError( - f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}." - ) - dim = _maybe_correct_neg_dim(dim, self.batch_size) - if inplace: - if ( - isinstance(self._tensordict, LazyStackedTensorDict) - and self._tensordict.stack_dim == dim - ): - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self._tensordict.append(td) - return self - else: - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - td = lazy_stack(list(self._tensordict.unbind(dim)) + [td], dim=dim) - self.__dict__["_tensordict"] = td - return self - if history.device != self.device: - if self.device is None: - history = history.copy().clear_device_() - else: - history = history.to(self.device) - return lazy_stack(list(self.unbind(dim)) + [history], dim=dim) - - def extend( - self, history: History, *, inplace: bool = True, dim: int = 0 - ) -> History: - if not self.batch_dims: - raise RuntimeError( - "Cannot add an element to a batchless History. Call unsqueeze(dim=0) first on self." - ) - if self.batch_dims != history.batch_dims: - raise RuntimeError( - f"The new history to extend must have as many dimensions as self. Got self.ndim={self.ndim} and history.ndim={self.ndim}." - ) - dim = _maybe_correct_neg_dim(dim, self.batch_size) - # if self.ndim > 1 and dim >= self.ndim - 1: - # # then we need to append each element independently - # result = [] - # for hist, new_hist in zip(self.unbind(0), history.unbind(0)): - # hist_c = hist.extend(new_hist, inplace=inplace, dim=dim - 1) - # result.append(hist_c) - # if inplace: - # return self - # return lazy_stack(result) - if inplace: - if ( - isinstance(self._tensordict, LazyStackedTensorDict) - and self._tensordict.stack_dim == dim - ): - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self._tensordict.extend(td) - return self - else: - td = lazy_stack( - list(self._tensordict.unbind(dim)) - + list(history._tensordict.unbind(dim)), - dim=dim, - ) - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self.__dict__["_tensordict"] = td - return self - if history.device != self.device: - if self.device is None: - history = history.copy().clear_device_() - else: - history = history.to(self.device) - return torch.stack(list(self.unbind(dim)) + list(history.unbind(dim)), dim=dim) - - @classmethod - def default_spec(cls, shape=(-1,)): - """A default spec to use in transforms / envs that return History objects. - - Args: - shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1)` (variable length - along the time dimension). - - Example: - >>> import tensordict - >>> from torchrl.data import History - >>> tensordict.set_list_to_stack(True).set() - >>> - >>> history = History(role=["system", "user"], content=["a message", "another message"], batch_size=(2,)) - >>> spec = history.default_spec() - >>> print(spec) - Composite( - role: NonTensor( - shape=torch.Size([-1]), - space=None, - device=None, - dtype=None, - domain=None, - example_data=foo), - content: NonTensor( - shape=torch.Size([-1]), - space=None, - device=None, - dtype=None, - domain=None, - example_data=foo), - device=None, - shape=torch.Size([-1])) - >>> print(spec.zero()) - History( - content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), - role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), - batch_size=torch.Size([1]), - device=None, - is_shared=False) - - """ - from torchrl.data import Composite, NonTensor - - def get_default_value(field): - if field.default is not dataclasses.MISSING: - return field.default - elif field.type in (str, "str"): - return "foo" - else: - return None - - defaults = { - k: NonTensor( - example_data=get_default_value(cls.__dataclass_fields__[k]), - shape=shape, - ) - for k in cls.__dataclass_fields__ - } - - return Composite(defaults, shape=shape[:-1], data_cls=cls) - - @classmethod - def from_chats(cls, chats: list[list[dict]]) -> History: - """Create a History object from a list of chats. - Args: - chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries. - """ - if isinstance(chats[0], dict): - return lazy_stack([cls(**chat) for chat in chats]) - else: - return lazy_stack([cls.from_chats(chat) for chat in chats]) +__all__ = ["add_chat_template", "ContentBase", "History"] From 5021c1c36844de10e68250517b6a5249add9dc4a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 12 Jun 2026 16:30:45 +0100 Subject: [PATCH 2/2] [Refactor] Keep History spec machinery in torchrl tensordict cannot depend on torchrl's spec classes, so the spec side of the History move stays here: - history_default_spec() free function in the re-export shim is the canonical way to build a Composite spec for History objects; exported from torchrl.data.llm. - History.default_spec is attached to the class in the shim so the established classmethod API (mirroring ChatHistory/Text/Tokens) keeps working unchanged for existing callers. - Internal call sites (mlgym, policies.common) migrate to the free function; test_history_spec covers both paths. Co-Authored-By: Claude Fable 5 --- test/llm/test_data.py | 6 ++ torchrl/data/llm/__init__.py | 3 +- torchrl/data/llm/history.py | 92 +++++++++++++++++++++++++- torchrl/envs/llm/libs/mlgym.py | 6 +- torchrl/modules/llm/policies/common.py | 4 +- 5 files changed, 104 insertions(+), 7 deletions(-) diff --git a/test/llm/test_data.py b/test/llm/test_data.py index a7824511622..f5ecfc51d2e 100644 --- a/test/llm/test_data.py +++ b/test/llm/test_data.py @@ -29,6 +29,7 @@ _CUSTOM_MODEL_FAMILY_KEYWORDS, add_chat_template, ContentBase, + history_default_spec, ) from torchrl.data.llm.topk import TopKRewardSelector from torchrl.modules.llm.policies.common import _extract_responses_from_full_histories @@ -300,6 +301,11 @@ def test_history_spec(self): assert isinstance(r, History) assert spec.is_in(r) assert spec.is_in(history) + # The free function is the canonical API now that History lives in + # tensordict; the classmethod above is kept for backward compatibility. + spec_fn = history_default_spec() + assert spec_fn.is_in(history) + assert spec_fn.is_in(r) def test_content_base(self): from transformers import AutoProcessor diff --git a/torchrl/data/llm/__init__.py b/torchrl/data/llm/__init__.py index 77d15a3e01f..b446a4e0314 100644 --- a/torchrl/data/llm/__init__.py +++ b/torchrl/data/llm/__init__.py @@ -9,7 +9,7 @@ TensorDictTokenizer, TokenizedDatasetLoader, ) -from .history import add_chat_template, ContentBase, History +from .history import add_chat_template, ContentBase, History, history_default_spec from .prompt import PromptData, PromptTensorDictTokenizer from .reward import PairwiseDataset, RewardData from .topk import TopKRewardSelector @@ -30,5 +30,6 @@ "TokenizedDatasetLoader", "create_infinite_iterator", "get_dataloader", + "history_default_spec", "TopKRewardSelector", ] diff --git a/torchrl/data/llm/history.py b/torchrl/data/llm/history.py index 3c4bd316948..ea22b17afc9 100644 --- a/torchrl/data/llm/history.py +++ b/torchrl/data/llm/history.py @@ -10,9 +10,16 @@ ``torchrl.data.llm.history`` import paths keep working. New code should import from :mod:`tensordict.llm` directly. + +The spec-related functionality stays here: tensordict cannot depend on +torchrl's spec classes, so :func:`history_default_spec` (and the +backward-compatible ``History.default_spec`` classmethod it powers) are +defined in this module. """ from __future__ import annotations +import dataclasses + from tensordict.llm.history import ( # noqa: F401 _assistant_content_spans, _CHAT_TEMPLATES, @@ -24,4 +31,87 @@ History, ) -__all__ = ["add_chat_template", "ContentBase", "History"] +__all__ = ["add_chat_template", "ContentBase", "History", "history_default_spec"] + + +def _history_default_spec(cls, shape=(-1,)): + """A default spec to use in transforms / envs that return History objects. + + Args: + cls (type): The :class:`~tensordict.llm.History` class (or subclass) to build the spec for. + shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1)` (variable length + along the time dimension). + + .. seealso:: :func:`~torchrl.data.llm.history_default_spec`. + """ + # Composite/NonTensor cannot be imported at module level: torchrl.data is + # still initializing when this module is first imported. + from torchrl.data import Composite, NonTensor + + def get_default_value(field): + if field.default is not dataclasses.MISSING: + return field.default + elif field.type in (str, "str"): + return "foo" + else: + return None + + defaults = { + k: NonTensor( + example_data=get_default_value(cls.__dataclass_fields__[k]), + shape=shape, + ) + for k in cls.__dataclass_fields__ + } + + return Composite(defaults, shape=shape[:-1], data_cls=cls) + + +def history_default_spec(shape=(-1,)): + """A default Composite spec for :class:`~tensordict.llm.History` objects, to use in transforms / envs. + + Args: + shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1)` (variable length + along the time dimension). + + Example: + >>> import tensordict + >>> from torchrl.data.llm import history_default_spec + >>> tensordict.set_list_to_stack(True).set() + >>> + >>> spec = history_default_spec() + >>> print(spec) + Composite( + role: NonTensor( + shape=torch.Size([-1]), + space=None, + device=None, + dtype=None, + domain=None, + example_data=foo), + content: NonTensor( + shape=torch.Size([-1]), + space=None, + device=None, + dtype=None, + domain=None, + example_data=foo), + device=None, + shape=torch.Size([-1])) + >>> print(spec.zero()) + History( + content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), + role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), + batch_size=torch.Size([1]), + device=None, + is_shared=False) + + """ + return _history_default_spec(History, shape) + + +# `History` lives in tensordict, which cannot depend on torchrl's specs. +# Attach the `default_spec` classmethod here so that the established +# `History.default_spec()` API (mirroring ChatHistory/Text/Tokens) keeps +# working once torchrl is imported. +History.default_spec = classmethod(_history_default_spec) diff --git a/torchrl/envs/llm/libs/mlgym.py b/torchrl/envs/llm/libs/mlgym.py index 8527f069ccf..eec46702e89 100644 --- a/torchrl/envs/llm/libs/mlgym.py +++ b/torchrl/envs/llm/libs/mlgym.py @@ -23,7 +23,7 @@ from torchrl._utils import logger as torchrl_logger from torchrl.data import Choice, Composite, NonTensor -from torchrl.data.llm import History +from torchrl.data.llm import History, history_default_spec from torchrl.envs import ConditionalSkip, GymWrapper, Transform, TransformedEnv if TYPE_CHECKING: @@ -229,7 +229,7 @@ def set_environment_vars( env.add_commands(command_files) def transform_observation_spec(self, observation_spec: Composite) -> Composite: - observation_spec["history"] = History.default_spec() + observation_spec["history"] = history_default_spec() return observation_spec def transform_action_spec(self, action_spec: Composite) -> Composite: @@ -247,7 +247,7 @@ def transform_action_spec(self, action_spec: Composite) -> Composite: ) def transform_state_spec(self, state_spec: Composite) -> Composite: - state_spec["history"] = History.default_spec() + state_spec["history"] = history_default_spec() return state_spec diff --git a/torchrl/modules/llm/policies/common.py b/torchrl/modules/llm/policies/common.py index 461310a3363..4174f30e8d3 100644 --- a/torchrl/modules/llm/policies/common.py +++ b/torchrl/modules/llm/policies/common.py @@ -21,7 +21,7 @@ from torch import distributions as D from torch.distributions import Categorical from torch.nn.utils.rnn import pad_sequence -from torchrl.data.llm import History +from torchrl.data.llm import History, history_default_spec from torchrl.data.tensor_specs import Unbounded from torchrl.modules.distributions.discrete import LLMMaskedCategorical @@ -276,7 +276,7 @@ def default_spec( if keys is None: keys = ["prompt", "response", "full"] return Composite( - {k: History.default_spec(shape=shape + (-1,)) for k in keys}, + {k: history_default_spec(shape + (-1,)) for k in keys}, shape=shape[:-1], data_cls=cls, step_mdp_static=True,