diff --git a/test/llm/test_data.py b/test/llm/test_data.py index a7824511622..f5ecfc51d2e 100644 --- a/test/llm/test_data.py +++ b/test/llm/test_data.py @@ -29,6 +29,7 @@ _CUSTOM_MODEL_FAMILY_KEYWORDS, add_chat_template, ContentBase, + history_default_spec, ) from torchrl.data.llm.topk import TopKRewardSelector from torchrl.modules.llm.policies.common import _extract_responses_from_full_histories @@ -300,6 +301,11 @@ def test_history_spec(self): assert isinstance(r, History) assert spec.is_in(r) assert spec.is_in(history) + # The free function is the canonical API now that History lives in + # tensordict; the classmethod above is kept for backward compatibility. + spec_fn = history_default_spec() + assert spec_fn.is_in(history) + assert spec_fn.is_in(r) def test_content_base(self): from transformers import AutoProcessor diff --git a/torchrl/data/llm/__init__.py b/torchrl/data/llm/__init__.py index 77d15a3e01f..b446a4e0314 100644 --- a/torchrl/data/llm/__init__.py +++ b/torchrl/data/llm/__init__.py @@ -9,7 +9,7 @@ TensorDictTokenizer, TokenizedDatasetLoader, ) -from .history import add_chat_template, ContentBase, History +from .history import add_chat_template, ContentBase, History, history_default_spec from .prompt import PromptData, PromptTensorDictTokenizer from .reward import PairwiseDataset, RewardData from .topk import TopKRewardSelector @@ -30,5 +30,6 @@ "TokenizedDatasetLoader", "create_infinite_iterator", "get_dataloader", + "history_default_spec", "TopKRewardSelector", ] diff --git a/torchrl/data/llm/history.py b/torchrl/data/llm/history.py index 4fa76eafa03..ea22b17afc9 100644 --- a/torchrl/data/llm/history.py +++ b/torchrl/data/llm/history.py @@ -2,1501 +2,116 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import dataclasses +"""Backward-compatibility re-exports for conversation containers. -import re -from typing import Literal, TYPE_CHECKING - -import torch - -from tensordict import ( - lazy_stack, - LazyStackedTensorDict, - list_to_stack, - TensorClass, - TensorDict, -) -from tensordict.utils import _maybe_correct_neg_dim -from torchrl._utils import logger as torchrl_logger +:class:`~tensordict.llm.History`, :class:`~tensordict.llm.ContentBase` and +:func:`~tensordict.llm.add_chat_template` now live in ``tensordict.llm``, +which is their canonical home. This module re-exports them so that existing +``torchrl.data.llm.history`` import paths keep working. -if TYPE_CHECKING: - import transformers +New code should import from :mod:`tensordict.llm` directly. +The spec-related functionality stays here: tensordict cannot depend on +torchrl's spec classes, so :func:`history_default_spec` (and the +backward-compatible ``History.default_spec`` classmethod it powers) are +defined in this module. +""" +from __future__ import annotations -# Global storage for custom templates and their metadata -_CHAT_TEMPLATES = { - "chatml_format": """{%- for message in messages %} -{%- if message['role'] == 'assistant' -%} -{{ '<|im_start|>' + message['role'] + '\n' }}{% generation %}{{ message['content'] }}{% endgeneration %}{{ '<|im_end|>\n' }} -{%- else -%} -{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }} -{%- endif %} -{%- endfor %} -{%- if add_generation_prompt -%} -{{ '<|im_start|>assistant\n' }} -{%- endif %} -""", - "qwen": """ -{%- if tools %} - {{- '<|im_start|>system\\n' }} - {%- if messages[0]['role'] == 'system' %} - {{- messages[0]['content'] }} - {%- else %} - {{- 'You are a helpful assistant.' }} - {%- endif %} - {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n" }} - {%- for tool in tools %} - {{- "\\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n" }} -{%- else %} - {%- if messages[0]['role'] == 'system' %} - {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }} - {%- else %} - {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }} - {%- endif %} -{%- endif %} -{%- for message in messages %} - {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} - {%- elif (message.role == "assistant" and not message.tool_calls) %} - {{- '<|im_start|>' + message.role + '\\n' }}{% generation %}{{- message.content }}{% endgeneration %}{{- '<|im_end|>' + '\\n' }} - {%- elif message.role == "assistant" %} - {{- '<|im_start|>' + message.role }} - {%- if message.content %} - {{- '\\n' }}{% generation %}{{- message.content }}{% endgeneration %} - {%- endif %} - {%- for tool_call in message.tool_calls %} - {% generation %} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '\\n\\n{\\\"name\\\": \\\"' }} - {{- tool_call.name }} - {{- '\\\", \\\"arguments\\\": ' }} - {{- tool_call.arguments | tojson }} - {{- '}\\n' }} - {%- endgeneration %} - {%- endfor %} - {{- '<|im_end|>\\n' }} - {%- elif message.role == "tool" %} - {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} - {{- '<|im_start|>tool' }} - {%- endif %} - {{- '\\n\\n' }} - {%- if message.tool_responses %} - {{- message.tool_responses }} - {%- else %} - {{- message.content }} - {%- endif %} - {{- '\\n' }} - {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} - {{- '<|im_end|>\\n' }} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\\n' }} -{%- endif %} -""", - "dialogpt": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'user' %}{{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' ' }}{% endif %}""", - "falcon": """{% for message in messages %}{% if message['role'] == 'assistant' %}{{ 'Assistant:' }}{% generation %}{{ ' ' + message['content'] }}{% endgeneration %}\n\n{% elif message['role'] == 'user' %}{{ 'User: ' + message['content'] }}\n\n{% elif message['role'] == 'system' %}{{ message['content'] }}\n\n{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}""", - "deepseek": """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant:' }}{% generation %}{{ ' ' + message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}""", - "llama": """{{- bos_token }} -{%- if messages[0]['role'] == 'system' %} - {%- set system_message = messages[0]['content']|trim %} - {%- set messages = messages[1:] %} -{%- else %} - {%- set system_message = "" %} -{%- endif %} -{%- if system_message %} - {{- "<|header_start|>system<|header_end|>\n\n" }} - {{- system_message }} - {{- "<|eot|>" }} -{%- endif %} -{%- for message in messages %} - {%- if message['role'] == 'assistant' %} - {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}{% generation %} - {%- if message['content'] is string %} - {{- message['content'] }} - {%- else %} - {%- for content in message['content'] %} - {%- if content['type'] == 'text' %} - {{- content['text'] | trim }} - {%- endif %} - {%- endfor %} - {%- endif %} - {%- endgeneration %}{{- "<|eot|>" }} - {%- else %} - {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} - {%- if message['content'] is string %} - {{- message['content'] }} - {%- else %} - {%- for content in message['content'] %} - {%- if content['type'] == 'text' %} - {{- content['text'] | trim }} - {%- endif %} - {%- endfor %} - {%- endif %} - {{- "<|eot|>" }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|header_start|>assistant<|header_end|>\n\n' }} -{%- endif %}""", -} +import dataclasses -# Global storage for custom template metadata -_CUSTOM_INVERSE_PARSERS = {} -_CUSTOM_MODEL_FAMILY_KEYWORDS = {} +from tensordict.llm.history import ( # noqa: F401 + _assistant_content_spans, + _CHAT_TEMPLATES, + _CUSTOM_INVERSE_PARSERS, + _CUSTOM_MODEL_FAMILY_KEYWORDS, + _fallback_assistant_tokens_mask, + add_chat_template, + ContentBase, + History, +) +__all__ = ["add_chat_template", "ContentBase", "History", "history_default_spec"] -def _assistant_content_spans( - rendered: str, conversation: list[dict] -) -> list[tuple[int, int]]: - spans = [] - cursor = 0 - for message in conversation: - if message.get("role") != "assistant": - continue - content = message.get("content") - if not isinstance(content, str) or not content: - continue - start = rendered.find(content, cursor) - if start < 0: - start = rendered.find(content) - if start < 0: - continue - end = start + len(content) - spans.append((start, end)) - cursor = end - return spans +def _history_default_spec(cls, shape=(-1,)): + """A default spec to use in transforms / envs that return History objects. -def _fallback_assistant_tokens_mask( - *, - tokenizer, - rendered: str, - conversation: list[dict], - input_ids: torch.Tensor, - current_mask: torch.Tensor, -) -> torch.Tensor: - if current_mask.any(): - return current_mask - spans = _assistant_content_spans(rendered, conversation) - if not spans: - return current_mask + Args: + cls (type): The :class:`~tensordict.llm.History` class (or subclass) to build the spec for. + shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1)` (variable length + along the time dimension). - try: - encoded = tokenizer( - rendered, - add_special_tokens=False, - return_offsets_mapping=True, - ) - offsets = encoded.get("offset_mapping", None) - except NotImplementedError: - offsets = None - if input_ids.ndim == 2: - if input_ids.shape[0] != 1: - target_mask = None - fallback_mask = None - elif offsets is None or len(offsets) != input_ids.shape[-1]: - target_mask = current_mask[0] - fallback_mask = None - else: - target_mask = current_mask[0] - fallback_mask = torch.zeros_like(target_mask) - elif input_ids.ndim == 1: - if offsets is None or len(offsets) != input_ids.shape[-1]: - target_mask = current_mask - fallback_mask = None + .. seealso:: :func:`~torchrl.data.llm.history_default_spec`. + """ + # Composite/NonTensor cannot be imported at module level: torchrl.data is + # still initializing when this module is first imported. + from torchrl.data import Composite, NonTensor + + def get_default_value(field): + if field.default is not dataclasses.MISSING: + return field.default + elif field.type in (str, "str"): + return "foo" else: - target_mask = current_mask - fallback_mask = torch.zeros_like(target_mask) - else: - return current_mask - - if fallback_mask is not None: - for idx, (token_start, token_end) in enumerate(offsets): - if token_start == token_end: - continue - for span_start, span_end in spans: - if token_start < span_end and token_end > span_start: - fallback_mask[idx] = 1 - break - - if fallback_mask is None or not fallback_mask.any(): - if target_mask is None: - return current_mask - fallback_mask = torch.zeros_like(target_mask) - for span_start, span_end in spans: - prefix_ids = tokenizer(rendered[:span_start], add_special_tokens=False).get( - "input_ids" - ) - prefix_and_content_ids = tokenizer( - rendered[:span_end], add_special_tokens=False - ).get("input_ids") - start_idx = len(prefix_ids) - end_idx = len(prefix_and_content_ids) - if start_idx < end_idx <= fallback_mask.shape[-1]: - fallback_mask[start_idx:end_idx] = 1 + return None - if input_ids.ndim == 2: - current_mask = current_mask.clone() - current_mask[0] = fallback_mask - return current_mask - return fallback_mask + defaults = { + k: NonTensor( + example_data=get_default_value(cls.__dataclass_fields__[k]), + shape=shape, + ) + for k in cls.__dataclass_fields__ + } + return Composite(defaults, shape=shape[:-1], data_cls=cls) -def add_chat_template( - template_name: str, - template: str, - inverse_parser: callable | None = None, - model_family_keywords: list[str] | None = None, -) -> None: - r"""Add a custom chat template to the global template dictionary. - This function allows you to add custom chat templates for new model families - that support assistant token masking via the `{% generation %}` keyword. +def history_default_spec(shape=(-1,)): + """A default Composite spec for :class:`~tensordict.llm.History` objects, to use in transforms / envs. Args: - template_name (str): The name of the template (e.g., "llama", "mistral"). - This name will be used in the `chat_template_name` parameter of - `History.apply_chat_template()` and `History.from_text()`. - template (str): The Jinja2 template string. Must include `{% generation %}` - blocks around assistant message content to enable token masking. - inverse_parser (callable, optional): A function that parses formatted text back - into a History object. Should have signature `(text: str) -> History`. - If None, a basic parser will be used. - model_family_keywords (list[str], optional): Keywords to detect this model family - in the auto-detection logic. For example, ["llama", "meta-llama"] for Llama models. - If provided, the template will be automatically selected for models containing - these keywords in their name. + shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1)` (variable length + along the time dimension). Example: - >>> from torchrl.data.llm.chat import add_chat_template, History - >>> from transformers import AutoTokenizer - >>> - >>> # Add a custom template for Llama models - >>> llama_template = ''' - ... {% for message in messages %} - ... {%- if message['role'] == 'user' %} - ... {{ '[INST] ' + message['content'] + ' [/INST]' }} - ... {%- elif message['role'] == 'assistant' %} - ... {% generation %}{{ message['content'] + '' }}{% endgeneration %} - ... {%- endif %} - ... {% endfor %} - ... {%- if add_generation_prompt %} - ... {% generation %}{{ ' ' }}{% endgeneration %} - ... {%- endif %} - ... ''' - >>> - >>> def parse_llama_text(text: str) -> History: - ... # Custom parser for Llama format - ... import re - ... pattern = r'\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)' - ... matches = re.findall(pattern, text, re.DOTALL) - ... messages = [] - ... for user_content, assistant_content in matches: - ... messages.append(History(role="user", content=user_content.strip())) - ... messages.append(History(role="assistant", content=assistant_content.strip())) - ... return lazy_stack(messages) - >>> - >>> # Add the template with auto-detection - >>> add_chat_template( - ... template_name="llama", - ... template=llama_template, - ... inverse_parser=parse_llama_text, - ... model_family_keywords=["llama", "meta-llama"] - ... ) - >>> - >>> # Now you can use it with auto-detection - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"} - ... ]]) - >>> - >>> # Auto-detection will use the llama template - >>> result = history.apply_chat_template( - ... tokenizer=tokenizer, - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) + >>> import tensordict + >>> from torchrl.data.llm import history_default_spec + >>> tensordict.set_list_to_stack(True).set() >>> - >>> # Or use it explicitly - >>> result = history.apply_chat_template( - ... tokenizer=tokenizer, - ... chat_template_name="llama", - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - - .. note: - - The template must include `{% generation %}` blocks around assistant message - content to enable assistant token masking. - - The inverse parser should handle the specific format of your template. - - Model family keywords are case-insensitive and matched against the tokenizer's - `name_or_path` attribute. - - Templates are stored globally and persist for the duration of the Python session. - """ - global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS # noqa: F824 - - # Validate template contains generation blocks - if "{% generation %}" not in template: - raise ValueError( - f"Template '{template_name}' must include '{{% generation %}}' blocks " - "around assistant message content to enable token masking." - ) - - # Add template to dictionary - _CHAT_TEMPLATES[template_name] = template - - # Store inverse parser if provided - if inverse_parser is not None: - _CUSTOM_INVERSE_PARSERS[template_name] = inverse_parser - - # Store model family keywords if provided - if model_family_keywords is not None: - _CUSTOM_MODEL_FAMILY_KEYWORDS[template_name] = model_family_keywords - - torchrl_logger.info( - f"Added custom chat template '{template_name}' with assistant token masking support" - ) - - -# We need the 'shadow' flag to avoid having tensordict complaining about 'type'/'size' etc. fields -class ContentBase(TensorClass["nocast", "shadow"]): - """Base class for all message content types. - - Attributes: - type (str): The type of the content. - text (str, optional): The text content. - url (str, optional): The URL content. - data (str, optional): The data content. - mime_type (str, optional): The MIME type of the content. - name (str, optional): The name of the content. - size (int, optional): The size of the content. - function_name (str, optional): The name of the function. - function_args (dict, optional): The arguments of the function. - - Examples: - >>> from tensordict import lazy_stack - >>> content1 = ContentBase(type="text", text="Hello, world!") - >>> print(content1) - ContentBase( - text=NonTensorData(data=Hello, world!, batch_size=torch.Size([]), device=None), - type=NonTensorData(data=text, batch_size=torch.Size([]), device=None), - url=None, - data=None, - mime_type=None, - name=None, - size=None, - function_name=None, - function_args=None, - batch_size=torch.Size([]), - device=None, - is_shared=False) - >>> content2 = ContentBase(type="image", url="https://example.com/image.jpg") - >>> print(content2) - ContentBase( - type=NonTensorData(data=image, batch_size=torch.Size([]), device=None), - url=NonTensorData(data=https://example.com/image.jpg, batch_size=torch.Size([]), device=None), - text=None, - data=None, - mime_type=None, - name=None, - size=None, - function_name=None, - function_args=None, - batch_size=torch.Size([]), + >>> spec = history_default_spec() + >>> print(spec) + Composite( + role: NonTensor( + shape=torch.Size([-1]), + space=None, + device=None, + dtype=None, + domain=None, + example_data=foo), + content: NonTensor( + shape=torch.Size([-1]), + space=None, + device=None, + dtype=None, + domain=None, + example_data=foo), device=None, - is_shared=False) - >>> content = lazy_stack([content1, content2]) - >>> print(content) - ContentBase( - type=NonTensorStack( - ['text', 'image'], - batch_size=torch.Size([2]), - device=None), - url=None, - data=None, - mime_type=None, - name=None, - size=None, - function_name=None, - function_args=None, - text=None, - batch_size=torch.Size([2]), + shape=torch.Size([-1])) + >>> print(spec.zero()) + History( + content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), + role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), + batch_size=torch.Size([1]), device=None, is_shared=False) - >>> # A content is typically used in a History object. Usually, its batch dimension is - >>> # one dimension greater than the History object. - >>> history = History(role="user", content=content) - - """ - - type: Literal[ - "text", "image", "audio", "video", "file", "function_call" - ] # Required: "text", "image", "audio", "video", "file", "function_call" - - # Text content - text: str | None = None - - # Media/file content (either URL or data) - url: str | None = None # HTTP URL to content - data: str | None = None # Base64 encoded content - # Metadata - mime_type: str | None = None # "image/jpeg", "audio/mp3", "application/pdf" - name: str | None = None # Original filename or description - size: int | None = None # File size in bytes - - # Function calling (for AI agents) - function_name: str | None = None - function_args: dict | None = None - - -class History(TensorClass["nocast"]): - """A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models. - - The `History` class provides a centralized API for managing conversational data, offering several advantages over - traditional list-based approaches: - - - Centralized API for conversion to and from string formats, facilitating seamless integration with language models. - - Efficient methods to append, extend, and reshape history elements, enabling dynamic construction of conversation - trajectories, especially useful in reinforcement learning environments. - - Interoperability with the `transformers` API, allowing for easy tokenization and preparation of input data. - - **Assistant token masking support** across multiple model families for reinforcement learning applications. - - **Recent Changes:** - - **ChatHistory Integration**: History objects are now used within :class:`~torchrl.modules.llm.policies.ChatHistory` - containers for structured conversation management in LLM environments. - - **Modular Wrapper Support**: Both vLLMWrapper and TransformersWrapper now use History objects when `input_mode="history"` - is specified, providing consistent conversation state management. - - **Environment Integration**: ChatEnv and related environments use History objects for state management and conversation tracking. - - .. note:: The `""` role is used to indicate that the element is a placeholder, - for example when the tool call was not executed but a stack requires a certain number of elements - per batch to have congruent shapes. The :meth:`~torchrl.data.llm.chat.History.apply_chat_template` - method will remove the `` role from the history. - - **Assistant Token Masking Support:** - - The class supports assistant token masking across multiple model families, allowing you to identify which tokens - in a conversation were generated by the assistant. This is crucial for reinforcement learning applications. - - **Supported Model Families:** - - - **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support - - **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format - - **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format - - **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format - - **Other models** (OPT, GPT, MPT, BLOOM, Pythia, Phi, etc.): Default `chatml_format` template - - **Example with Assistant Token Masking:** - - .. code-block:: python - - >>> from torchrl.data.llm.chat import History - >>> from torchrl.modules.llm.policies import ChatHistory - >>> from transformers import AutoTokenizer - >>> - >>> # Create a conversation history - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"}, - ... {"role": "user", "content": "How are you?"}, - ... {"role": "assistant", "content": "I'm doing well, thanks!"} - ... ]]) - >>> - >>> # Create ChatHistory container for LLM wrapper - >>> chat_history = ChatHistory(prompt=history) - >>> - >>> # Load any supported tokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") - >>> - >>> # Apply chat template with assistant token masking - >>> result = history.apply_chat_template( - ... tokenizer=tokenizer, - ... add_generation_prompt=False, - ... return_dict=True, - ... return_assistant_tokens_mask=True, - ... ) - >>> - >>> # The result contains an assistant_masks tensor - >>> assistant_masks = result["assistant_masks"] - >>> print(f"Assistant tokens: {assistant_masks.sum().item()}") - - **Integration with LLM Wrappers:** - - History objects work seamlessly with the new modular wrapper design: - - .. code-block:: python - - >>> from torchrl.modules.llm import TransformersWrapper - >>> from torchrl.modules.llm.policies import ChatHistory - >>> - >>> # Create wrapper with history input mode - >>> wrapper = TransformersWrapper( - ... model, tokenizer=tokenizer, - ... input_mode="history", - ... generate=True, - ... return_log_probs=True - ... ) - >>> - >>> # Use History with ChatHistory container - >>> history = History.from_chats([[ - ... {"role": "user", "content": "Hello"}, - ... {"role": "assistant", "content": "Hi there!"} - ... ]]) - >>> chat_history = ChatHistory(prompt=history) - >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,))) - >>> print(result["history"].response) # New response from LLM - - Attributes: - role (str): The role of the message sender. - content (str): The content of the message. - is_complete (bool): Whether the message was properly terminated with an end token. Defaults to `True`. - tool_calls (list[dict] | None): Optional list of tool calls in the message. - tool_responses (list[str] | None): Optional list of tool responses. - - Methods: - apply_chat_template: converts the `History` object to str / tokens. - append: append one element to the list of items along a given dimension. - extend: extend the list of items along a given dimension. - - Examples: - >>> # With tensordict < 0.10, we need to tell the lib that lists constitute batches - >>> import tensordict - >>> tensordict.set_list_to_stack(True).set() - >>> import transformers - >>> history0 = History( - ... role='system', - ... content='''CONTENT - ... This is the setup''', - ... ) - >>> history1 = History( - ... role='user', - ... content='''CONTENT - ... This is the first user prompt''', - ... ) - >>> history2 = History( - ... role='assistant', - ... content='''CONTENT - ... This is the second prompt, the first for the assistant.''', - ... ) - >>> history = torch.stack([history0, history1, history2]) - >>> assert history.role == ['system', 'user', 'assistant'] - >>> tokenizer = transformers.AutoTokenizer.from_pretrained("GPT2") - >>> # Apply a template to pass the history to an LLM. Note that the output has - >>> # an additional prompt to elict an answer from the LLM thanks to the 'add_generation_prompt' argument. - >>> parsed_string = history.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True) - >>> parsed_string - <|im_start|>system - CONTENT - This is the setup<|im_end|> - - <|im_start|>user - CONTENT - This is the first user prompt<|im_end|> - - <|im_start|>assistant - CONTENT - This is the second prompt, the first for the assistant.<|im_end|> - - <|im_start|>assistant - - .. seealso:: - :class:`~torchrl.modules.llm.policies.ChatHistory`: Container for managing conversation data in LLM environments. - :class:`~torchrl.modules.llm.policies.Text`: Container for text data. - :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data. """ + return _history_default_spec(History, shape) - role: str | list[str] | list[list[str]] - content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[ - list[ContentBase] - ] - is_complete: bool = True - tool_calls: list[dict] | None = None - tool_responses: list[str] | None = None - - def __post_init__(self): - if not list_to_stack(): - raise RuntimeError( - "Please set the list_to_stack to True using tensordict.set_list_to_stack(True).set() at the beginning of your script, " - "or the LIST_TO_STACK=1 environment variable." - ) - - def apply_chat_template( - self, - *, - tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa - add_generation_prompt: bool = True, - chat_template: str | None = None, - chat_template_name: str | None = None, - continue_final_message: bool = False, - tokenize: bool | None = None, - padding: bool | str = False, - truncation: bool | str = False, - return_tensors: str | None = None, - return_dict: bool | None = None, - return_assistant_tokens_mask: bool = False, - **kwargs, - ) -> str | list[str] | TensorDict: - """Applies a chat template to the history. - - Keyword Args: - tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use. - add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`. - chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template. - chat_template_name (str, optional): The name of the chat template to use. - Prevalent over `tokenizer.chat_template`. If `None`, the method will automatically detect the model family and use the appropriate template. - Defaults to `None`. - continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`. - tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`. - padding (bool | str, optional): The padding strategy to use. Defaults to `False`. - truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`. - return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt". - return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`. - return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens. - If `True`, the mask will be written to the `assistant_masks` key. - For tokens generated by the assistant, the mask will contain `1`. - For user and system tokens, the mask will contain `0`. - This functionality is only available for chat templates that support it via the `{% generation %}` keyword. - Defaults to `False`. - - .. note:: Assistant token masking is supported across multiple model families: - - **Qwen family**: Uses custom template with full tool calling support - - **DialoGPT family**: Uses custom template for conversation format - - **Falcon family**: Uses custom template for instruction format - - **DeepSeek family**: Uses custom template with native format - - **Other models**: Use the default `chatml_format` template - - The method automatically detects the model family and selects the appropriate template. - - **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method. - - Returns: - The formatted history. - """ - if chat_template is None: - if chat_template_name is not None: - chat_template = _CHAT_TEMPLATES[chat_template_name] - chat_template_name = None - elif tokenizer is None: - raise RuntimeError( - "You must specify a tokenizer to use when chat_template is not specified." - ) - else: - # Auto-detect model family and use appropriate template - model_name = getattr(tokenizer, "name_or_path", "").lower() - - # First check for custom model family keywords - custom_template_found = False - for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items(): - if any(keyword.lower() in model_name for keyword in keywords): - chat_template = _CHAT_TEMPLATES[template_name] - chat_template_name = None - custom_template_found = True - break - - if not custom_template_found: - # Fall back to built-in model family detection - if "qwen" in model_name: - # We prefer our implementation of the Qwen template, - # since it accounts for the assistant's masking. - chat_template = _CHAT_TEMPLATES["qwen"] - chat_template_name = None - elif "dialogpt" in model_name or "microsoft/dialo" in model_name: - # DialoGPT family - use our custom template - chat_template = _CHAT_TEMPLATES["dialogpt"] - chat_template_name = None - elif "falcon" in model_name or "tiiuae/falcon" in model_name: - # Falcon family - use our custom template - chat_template = _CHAT_TEMPLATES["falcon"] - chat_template_name = None - elif "deepseek" in model_name: - # DeepSeek family - use our custom template with generation keyword - chat_template = _CHAT_TEMPLATES["deepseek"] - chat_template_name = None - elif "llama" in model_name: - # Llama family - use our custom template - chat_template = _CHAT_TEMPLATES["llama"] - chat_template_name = None - else: - # For other models, check if their default template supports generation - if ( - hasattr(tokenizer, "chat_template") - and tokenizer.chat_template - and "{% generation %}" in tokenizer.chat_template - ): - # Use the model's own template if it supports generation - chat_template = tokenizer.chat_template - else: - # Use our default chatml_format template - chat_template = _CHAT_TEMPLATES["chatml_format"] - if chat_template is None: - chat_template = _CHAT_TEMPLATES["chatml_format"] - if tokenize is None: - if return_assistant_tokens_mask or return_tensors is not None: - tokenize = True - else: - tokenize = False - if tokenize: - if return_tensors is None: - return_tensors = "pt" - if return_dict is None and return_assistant_tokens_mask: - return_dict = True - elif return_dict is None: - return_dict = False - - if self.ndim > 1: - result = [ - self[i].apply_chat_template( - tokenizer=tokenizer, - add_generation_prompt=add_generation_prompt, - chat_template=chat_template, - chat_template_name=chat_template_name, - tokenize=tokenize, - padding=padding, - truncation=truncation, - return_tensors=return_tensors, - continue_final_message=continue_final_message, - return_dict=return_dict, - return_assistant_tokens_mask=return_assistant_tokens_mask, - **kwargs, - ) - for i in range(self.batch_size[0]) - ] - if return_dict: - return lazy_stack(result) - else: - return result - self_flat = self.view(-1) - # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts - self_flat = self_flat.tolist(tolist_first=True) - # Remove the "" role - self_flat = [item for item in self_flat if item["role"] != ""] - result = tokenizer.apply_chat_template( - conversation=self_flat, - add_generation_prompt=add_generation_prompt, - chat_template=chat_template, - tokenize=tokenize, - padding=padding, - truncation=truncation, - return_tensors=return_tensors, - continue_final_message=continue_final_message, - return_dict=return_dict, - return_assistant_tokens_mask=return_assistant_tokens_mask, - **kwargs, - ) - if ( - return_assistant_tokens_mask - and not isinstance(result, (torch.Tensor, list, str)) - and "assistant_masks" in result - and "input_ids" in result - ): - rendered = tokenizer.apply_chat_template( - conversation=self_flat, - add_generation_prompt=add_generation_prompt, - chat_template=chat_template, - tokenize=False, - continue_final_message=continue_final_message, - ) - if ( - isinstance(rendered, str) - and isinstance(result["assistant_masks"], torch.Tensor) - and isinstance(result["input_ids"], torch.Tensor) - ): - result["assistant_masks"] = _fallback_assistant_tokens_mask( - tokenizer=tokenizer, - rendered=rendered, - conversation=self_flat, - input_ids=result["input_ids"], - current_mask=result["assistant_masks"], - ) - if not isinstance(result, (torch.Tensor, list, str)): - result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1) - # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result - if self.batch_dims == 1: - if result.batch_size[0] != 1: - raise RuntimeError( - f"Expected a batch size of 1, got {result.batch_size[0]}." - ) - result = result.squeeze(0) - return result - @classmethod - def from_text( - cls, - text: str | list[str], - chat_template_name: str | None = None, - # currently without effect - chat_template: str | None = None, - tokenizer: transformers.AutoTokenizer # noqa: F821 - | transformers.AutoProcessor # noqa: F821 - | None = None, - ) -> History: - r"""Inverts a chat template into a History object. - - Args: - text (str | list[str]): The chat template to invert. - chat_template_name (str, optional): The name of the chat template to use. - tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use. - - Returns: - History: The inverted History object. - - Examples: - >>> from torchrl.data.llm.history import History - >>> from transformers import AutoTokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct") - >>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris, the capital of Germany is Berlin.\n\n" - >>> history = History.from_text(text, tokenizer=tokenizer) - >>> print(history) - History( - content=NonTensorStack( - ['You are a helpful assistant.', 'Write a python s..., - batch_size=torch.Size([3]), - device=None), - is_complete=NonTensorStack( - [True, True, False], - batch_size=torch.Size([3]), - device=None), - role=NonTensorStack( - ['system', 'user', 'assistant'], - batch_size=torch.Size([3]), - device=None), - tool_calls=None, - tool_responses=None, - batch_size=torch.Size([3]), - device=None, - is_shared=False) - """ - if chat_template_name is None: - if chat_template is not None: - # TODO: find best match given template - pass - - model_name = getattr(tokenizer, "name_or_path", "").lower() - # First check for custom model family keywords - custom_template_found = False - for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items(): - if any(keyword.lower() in model_name for keyword in keywords): - chat_template_name = template_name - custom_template_found = True - break - - if not custom_template_found: - # Fall back to built-in model family detection - if "qwen" in model_name: - # We can automatically detect the template name from the tokenizer - # and use the precoded parser. - chat_template_name = "qwen" - elif "dialogpt" in model_name or "microsoft/dialo" in model_name: - chat_template_name = "dialogpt" - elif "falcon" in model_name or "tiiuae/falcon" in model_name: - chat_template_name = "falcon" - elif "deepseek" in model_name: - chat_template_name = "deepseek" - elif "llama" in model_name: - chat_template_name = "llama" - else: - chat_template_name = "chatml_format" - - # Get the appropriate inverse parser function - if chat_template_name in ("chatml_format",): - func = cls._inv_chatml - elif chat_template_name in ("qwen",): - func = cls._inv_qwen - elif chat_template_name in ("dialogpt",): - func = cls._inv_dialogpt - elif chat_template_name in ("falcon",): - func = cls._inv_falcon - elif chat_template_name in ("deepseek",): - func = cls._inv_deepseek - elif chat_template_name in ("llama",): - func = cls._inv_llama - elif chat_template_name in _CUSTOM_INVERSE_PARSERS: - # Use custom inverse parser - func = _CUSTOM_INVERSE_PARSERS[chat_template_name] - else: - raise NotImplementedError( - f"chat_template_name '{chat_template_name}' is not supported. " - "Supported templates: 'chatml_format', 'qwen', 'dialogpt', 'falcon', 'deepseek'. " - "Use add_chat_template() to add custom templates." - ) - if isinstance(text, list): - list_of_histories = [func(t) for t in text] - try: - return lazy_stack(list_of_histories) - except RuntimeError as e: - raise RuntimeError( - f"Failed to stack histories: {list_of_histories=}" - ) from e - return func(text) - - @classmethod - def _inv_chatml(cls, text: str) -> History: - """Inverts a chatml string into a History object. - - Args: - text (str): The chatml string to invert. - - Returns: - History: The inverted History object. - """ - import json - - torchrl_logger.debug(f"Inverting chatml:\n{text}") - # Find all complete blocks (ending with im_end or endoftext) - complete_pattern = r"<\|im_start\|>(.*?)\n(.*?)<\|(im_end|endoftext)\|>" - complete_matches = re.findall(complete_pattern, text, flags=re.DOTALL) - - # Find any incomplete block at the end - incomplete_pattern = r"<\|im_start\|>(.*?)\n(.*?)$" - incomplete_matches = [] - if complete_matches: - # Look for incomplete block after the last complete one - last_complete = complete_matches[-1] - last_complete_text = f"<|im_start|>{last_complete[0]}\n{last_complete[1]}<|{last_complete[2]}|>" - remaining_text = text[ - text.rindex(last_complete_text) + len(last_complete_text) : - ] - if remaining_text.strip(): - incomplete_match = re.search( - incomplete_pattern, remaining_text, flags=re.DOTALL - ) - if incomplete_match: - incomplete_matches = [ - (incomplete_match.group(1), incomplete_match.group(2), None) - ] - else: - # No complete blocks, check entire text for incomplete block - incomplete_match = re.search(incomplete_pattern, text, flags=re.DOTALL) - if incomplete_match: - incomplete_matches = [ - (incomplete_match.group(1), incomplete_match.group(2), None) - ] - - # Combine complete and incomplete matches - matches = complete_matches + incomplete_matches - - # Define tool patterns - same as Qwen for consistency - tool_call_pattern = re.compile(r"\n(.*?)\n", re.DOTALL) - tool_response_pattern = re.compile( - r"\n(.*?)\n", re.DOTALL - ) - - parsed_messages = [] - for match in matches: - role = match[0].strip() - content = match[1].strip() - is_complete = match[2] is not None # None indicates incomplete - - # Initialize message dict - message_dict = { - "role": role, - "content": content, - "is_complete": is_complete, - "tool_calls": None, - "tool_responses": None, - } - - # Find tool calls within the message - tool_calls = tool_call_pattern.findall(content) - if tool_calls: - tool_calls_list = [] - for tool_call in tool_calls: - try: - tool_call_dict = json.loads(tool_call) - tool_calls_list.append(tool_call_dict) - except json.JSONDecodeError: - continue - if tool_calls_list: - message_dict["tool_calls"] = tool_calls_list - - # Check for tool responses - tool_responses = tool_response_pattern.findall(content) - if tool_responses: - message_dict["tool_responses"] = tool_responses - - parsed_messages.append(cls(**message_dict)) - - if not parsed_messages: - raise RuntimeError( - f"Couldn't get a single item out of text {text}. A common cause " - f"if that special tokens should not be omitted, did you set include_stop_str_in_output/skip_special_tokens=False?" - ) - - return lazy_stack(parsed_messages) - - @classmethod - def _inv_qwen(cls, template): - import json - - # Define regex patterns for different parts of the template - message_pattern = re.compile( - r"<\|im_start\|>(.*?)(?:<\|(im_end|endoftext)\|>|$)", re.DOTALL - ) - tool_call_pattern = re.compile(r"\n(.*?)\n", re.DOTALL) - tool_response_pattern = re.compile( - r"\n(.*?)\n", re.DOTALL - ) - - # Find all messages and track if they end with a proper token - messages = [] - is_complete_list = [] - for match in message_pattern.finditer(template): - full_match = match.group(0) - messages.append(match.group(1)) - # Check if the message ends with a proper token - is_complete_list.append( - full_match.endswith("<|im_end|>") - or full_match.endswith("<|endoftext|>") - ) - - parsed_messages = [] - for message, is_complete in zip(messages, is_complete_list): - # Split the message into role and content - parts = message.split("\n", 1) - if len(parts) < 2: - continue - role, content = parts[0], parts[1] - - # Initialize message dict - message_dict = { - "role": role.strip(), - "content": content.strip(), - "is_complete": is_complete, - "tool_calls": None, - "tool_responses": None, - } - - # Find tool calls within the message - tool_calls = tool_call_pattern.findall(content) - if tool_calls: - tool_calls_list = [] - for tool_call in tool_calls: - try: - tool_call_dict = json.loads(tool_call) - tool_calls_list.append(tool_call_dict) - except json.JSONDecodeError: - continue - if tool_calls_list: - message_dict["tool_calls"] = tool_calls_list - - # Check for tool responses - tool_responses = tool_response_pattern.findall(content) - if tool_responses: - message_dict["tool_responses"] = tool_responses - - parsed_messages.append(cls(**message_dict)) - - if not parsed_messages: - raise RuntimeError( - f"Couldn't get a single item out of text {template}. A common cause " - f"if that special tokens should not be omitted, did you set include_stop_str_in_output/skip_special_tokens=False?" - ) - - return lazy_stack(parsed_messages) - - @classmethod - def _inv_dialogpt(cls, text: str) -> History: - """Inverts a DialogPT string into a History object. - - Args: - text (str): The DialogPT string to invert. - - Returns: - History: The inverted History object. - """ - torchrl_logger.debug(f"Inverting DialogPT:\n{text}") - - # DialogPT format is simple: alternating user/assistant messages - # Split by lines and parse - lines = text.strip().split("\n") - parsed_messages = [] - - for line in lines: - line = line.strip() - if not line: - continue - - # Determine role based on content - if line.startswith("Assistant:"): - role = "assistant" - content = line[len("Assistant:") :].strip() - elif line.startswith("User:"): - role = "user" - content = line[len("User:") :].strip() - else: - # Default to user if no prefix - role = "user" - content = line - - message_dict = { - "role": role, - "content": content, - "is_complete": True, # DialogPT doesn't have explicit end tokens - "tool_calls": None, - "tool_responses": None, - } - - parsed_messages.append(cls(**message_dict)) - - if not parsed_messages: - raise RuntimeError(f"Couldn't get a single item out of text {text}.") - - return lazy_stack(parsed_messages) - - @classmethod - def _inv_falcon(cls, text: str) -> History: - """Inverts a Falcon string into a History object. - - Args: - text (str): The Falcon string to invert. - - Returns: - History: The inverted History object. - """ - torchrl_logger.debug(f"Inverting Falcon:\n{text}") - - # Falcon format: "User: ... Assistant: ..." - # Split by "User:" and "Assistant:" prefixes - import re - - # Pattern to match User: and Assistant: messages - pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)" - matches = re.findall(pattern, text, re.DOTALL) - - parsed_messages = [] - for match in matches: - if len(match) != 2: - continue - prefix, content = match - content = content.strip() - if not content: - continue - - if prefix == "User:": - role = "user" - elif prefix == "Assistant:": - role = "assistant" - else: - continue - - message_dict = { - "role": role, - "content": content, - "is_complete": True, # Falcon doesn't have explicit end tokens - "tool_calls": None, - "tool_responses": None, - } - - parsed_messages.append(cls(**message_dict)) - - if not parsed_messages: - raise RuntimeError(f"Couldn't get a single item out of text {text}.") - - return lazy_stack(parsed_messages) - - @classmethod - def _inv_deepseek(cls, text: str) -> History: - """Inverts a DeepSeek string into a History object. - - Args: - text (str): The DeepSeek string to invert. - - Returns: - History: The inverted History object. - """ - torchrl_logger.debug(f"Inverting DeepSeek:\n{text}") - import re - - # Remove leading/trailing special tokens (e.g. - text = re.sub(r"^<[^>]+>", "", text) # Remove leading <...> - text = re.sub(r"<[^>]+>$", "", text) # Remove trailing <...> - # Remove any REDACTED_SPECIAL_TOKEN if present - text = re.sub(r"REDACTED_SPECIAL_TOKEN", "", text) - # Pattern to match User: and Assistant: messages - pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)" - matches = re.findall(pattern, text, re.DOTALL) - parsed_messages = [] - for match in matches: - if len(match) < 2: - continue - prefix, content = match[0], match[1] - content = content.strip() - if not content: - continue - if prefix == "User:": - role = "user" - elif prefix == "Assistant:": - role = "assistant" - else: - continue - message_dict = { - "role": role, - "content": content, - "is_complete": True, # DeepSeek doesn't have explicit end tokens - "tool_calls": None, - "tool_responses": None, - } - parsed_messages.append(cls(**message_dict)) - if not parsed_messages: - raise RuntimeError(f"Couldn't get a single item out of text {text}.") - return lazy_stack(parsed_messages) - - @classmethod - def _inv_llama(cls, text: str) -> History: - import re - - messages = [] - - # Remove BOS token if present - if text.startswith("<|begin_of_text|>"): - text = text[len("<|begin_of_text|>") :] - - # Pattern to match complete message blocks: <|header_start|>role<|header_end|>\n\ncontent<|eot|> - complete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)<\|eot\|>" - complete_matches = re.findall(complete_pattern, text, re.DOTALL) - - # Pattern to match incomplete message blocks: <|header_start|>role<|header_end|>\n\ncontent (without <|eot|>) - incomplete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)$" - - # Find any incomplete message at the end - incomplete_matches = [] - if complete_matches: - # Look for incomplete message after the last complete one - last_complete_end = text.rfind("<|eot|>") - if last_complete_end != -1: - remaining_text = text[last_complete_end + len("<|eot|>") :] - if remaining_text.strip(): - incomplete_match = re.search( - incomplete_pattern, remaining_text, re.DOTALL - ) - if incomplete_match: - incomplete_matches = [ - ( - incomplete_match.group(1), - incomplete_match.group(2), - False, - ) - ] - else: - # No complete messages, check entire text for incomplete message - incomplete_match = re.search(incomplete_pattern, text, re.DOTALL) - if incomplete_match: - incomplete_matches = [ - (incomplete_match.group(1), incomplete_match.group(2), False) - ] - - # Process complete messages - for role, content in complete_matches: - if content.strip(): - messages.append( - cls(role=role, content=content.strip(), is_complete=True) - ) - - # Process incomplete messages - for role, content, is_complete in incomplete_matches: - if content.strip(): - messages.append( - cls(role=role, content=content.strip(), is_complete=is_complete) - ) - - if not messages: - raise RuntimeError(f"Couldn't parse Llama format from text: {text}") - - from tensordict import lazy_stack - - return lazy_stack(messages) - - def append( - self, history: History, *, inplace: bool = True, dim: int = -1 - ) -> History: - """Appends a new history to the current one. - - Args: - history (History): The new history to append. - inplace (bool, optional): Whether to perform the operation in-place. Defaults to `True`. - dim (int, optional): The dimension to append along. Defaults to -1. - - Returns: - History: The appended History object. - """ - # TODO: we should remove the role from the history before appending / extending - # It works when keeping them, but it may lead to a lot of useless padding in between valid messages - if not self.batch_dims: - raise RuntimeError( - "Cannot append an element to a batchless History. Call unsqueeze(dim=0) first on self." - ) - if self.batch_dims != history.batch_dims + 1: - raise RuntimeError( - f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}." - ) - dim = _maybe_correct_neg_dim(dim, self.batch_size) - if inplace: - if ( - isinstance(self._tensordict, LazyStackedTensorDict) - and self._tensordict.stack_dim == dim - ): - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self._tensordict.append(td) - return self - else: - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - td = lazy_stack(list(self._tensordict.unbind(dim)) + [td], dim=dim) - self.__dict__["_tensordict"] = td - return self - if history.device != self.device: - if self.device is None: - history = history.copy().clear_device_() - else: - history = history.to(self.device) - return lazy_stack(list(self.unbind(dim)) + [history], dim=dim) - - def extend( - self, history: History, *, inplace: bool = True, dim: int = 0 - ) -> History: - if not self.batch_dims: - raise RuntimeError( - "Cannot add an element to a batchless History. Call unsqueeze(dim=0) first on self." - ) - if self.batch_dims != history.batch_dims: - raise RuntimeError( - f"The new history to extend must have as many dimensions as self. Got self.ndim={self.ndim} and history.ndim={self.ndim}." - ) - dim = _maybe_correct_neg_dim(dim, self.batch_size) - # if self.ndim > 1 and dim >= self.ndim - 1: - # # then we need to append each element independently - # result = [] - # for hist, new_hist in zip(self.unbind(0), history.unbind(0)): - # hist_c = hist.extend(new_hist, inplace=inplace, dim=dim - 1) - # result.append(hist_c) - # if inplace: - # return self - # return lazy_stack(result) - if inplace: - if ( - isinstance(self._tensordict, LazyStackedTensorDict) - and self._tensordict.stack_dim == dim - ): - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self._tensordict.extend(td) - return self - else: - td = lazy_stack( - list(self._tensordict.unbind(dim)) - + list(history._tensordict.unbind(dim)), - dim=dim, - ) - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self.__dict__["_tensordict"] = td - return self - if history.device != self.device: - if self.device is None: - history = history.copy().clear_device_() - else: - history = history.to(self.device) - return torch.stack(list(self.unbind(dim)) + list(history.unbind(dim)), dim=dim) - - @classmethod - def default_spec(cls, shape=(-1,)): - """A default spec to use in transforms / envs that return History objects. - - Args: - shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1)` (variable length - along the time dimension). - - Example: - >>> import tensordict - >>> from torchrl.data import History - >>> tensordict.set_list_to_stack(True).set() - >>> - >>> history = History(role=["system", "user"], content=["a message", "another message"], batch_size=(2,)) - >>> spec = history.default_spec() - >>> print(spec) - Composite( - role: NonTensor( - shape=torch.Size([-1]), - space=None, - device=None, - dtype=None, - domain=None, - example_data=foo), - content: NonTensor( - shape=torch.Size([-1]), - space=None, - device=None, - dtype=None, - domain=None, - example_data=foo), - device=None, - shape=torch.Size([-1])) - >>> print(spec.zero()) - History( - content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), - role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), - batch_size=torch.Size([1]), - device=None, - is_shared=False) - - """ - from torchrl.data import Composite, NonTensor - - def get_default_value(field): - if field.default is not dataclasses.MISSING: - return field.default - elif field.type in (str, "str"): - return "foo" - else: - return None - - defaults = { - k: NonTensor( - example_data=get_default_value(cls.__dataclass_fields__[k]), - shape=shape, - ) - for k in cls.__dataclass_fields__ - } - - return Composite(defaults, shape=shape[:-1], data_cls=cls) - - @classmethod - def from_chats(cls, chats: list[list[dict]]) -> History: - """Create a History object from a list of chats. - - Args: - chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries. - """ - if isinstance(chats[0], dict): - return lazy_stack([cls(**chat) for chat in chats]) - else: - return lazy_stack([cls.from_chats(chat) for chat in chats]) +# `History` lives in tensordict, which cannot depend on torchrl's specs. +# Attach the `default_spec` classmethod here so that the established +# `History.default_spec()` API (mirroring ChatHistory/Text/Tokens) keeps +# working once torchrl is imported. +History.default_spec = classmethod(_history_default_spec) diff --git a/torchrl/envs/llm/libs/mlgym.py b/torchrl/envs/llm/libs/mlgym.py index 8527f069ccf..eec46702e89 100644 --- a/torchrl/envs/llm/libs/mlgym.py +++ b/torchrl/envs/llm/libs/mlgym.py @@ -23,7 +23,7 @@ from torchrl._utils import logger as torchrl_logger from torchrl.data import Choice, Composite, NonTensor -from torchrl.data.llm import History +from torchrl.data.llm import History, history_default_spec from torchrl.envs import ConditionalSkip, GymWrapper, Transform, TransformedEnv if TYPE_CHECKING: @@ -229,7 +229,7 @@ def set_environment_vars( env.add_commands(command_files) def transform_observation_spec(self, observation_spec: Composite) -> Composite: - observation_spec["history"] = History.default_spec() + observation_spec["history"] = history_default_spec() return observation_spec def transform_action_spec(self, action_spec: Composite) -> Composite: @@ -247,7 +247,7 @@ def transform_action_spec(self, action_spec: Composite) -> Composite: ) def transform_state_spec(self, state_spec: Composite) -> Composite: - state_spec["history"] = History.default_spec() + state_spec["history"] = history_default_spec() return state_spec diff --git a/torchrl/modules/llm/policies/common.py b/torchrl/modules/llm/policies/common.py index 461310a3363..4174f30e8d3 100644 --- a/torchrl/modules/llm/policies/common.py +++ b/torchrl/modules/llm/policies/common.py @@ -21,7 +21,7 @@ from torch import distributions as D from torch.distributions import Categorical from torch.nn.utils.rnn import pad_sequence -from torchrl.data.llm import History +from torchrl.data.llm import History, history_default_spec from torchrl.data.tensor_specs import Unbounded from torchrl.modules.distributions.discrete import LLMMaskedCategorical @@ -276,7 +276,7 @@ def default_spec( if keys is None: keys = ["prompt", "response", "full"] return Composite( - {k: History.default_spec(shape=shape + (-1,)) for k in keys}, + {k: history_default_spec(shape + (-1,)) for k in keys}, shape=shape[:-1], data_cls=cls, step_mdp_static=True,