diff --git a/test/llm/test_data.py b/test/llm/test_data.py
index a7824511622..f5ecfc51d2e 100644
--- a/test/llm/test_data.py
+++ b/test/llm/test_data.py
@@ -29,6 +29,7 @@
_CUSTOM_MODEL_FAMILY_KEYWORDS,
add_chat_template,
ContentBase,
+ history_default_spec,
)
from torchrl.data.llm.topk import TopKRewardSelector
from torchrl.modules.llm.policies.common import _extract_responses_from_full_histories
@@ -300,6 +301,11 @@ def test_history_spec(self):
assert isinstance(r, History)
assert spec.is_in(r)
assert spec.is_in(history)
+ # The free function is the canonical API now that History lives in
+ # tensordict; the classmethod above is kept for backward compatibility.
+ spec_fn = history_default_spec()
+ assert spec_fn.is_in(history)
+ assert spec_fn.is_in(r)
def test_content_base(self):
from transformers import AutoProcessor
diff --git a/torchrl/data/llm/__init__.py b/torchrl/data/llm/__init__.py
index 77d15a3e01f..b446a4e0314 100644
--- a/torchrl/data/llm/__init__.py
+++ b/torchrl/data/llm/__init__.py
@@ -9,7 +9,7 @@
TensorDictTokenizer,
TokenizedDatasetLoader,
)
-from .history import add_chat_template, ContentBase, History
+from .history import add_chat_template, ContentBase, History, history_default_spec
from .prompt import PromptData, PromptTensorDictTokenizer
from .reward import PairwiseDataset, RewardData
from .topk import TopKRewardSelector
@@ -30,5 +30,6 @@
"TokenizedDatasetLoader",
"create_infinite_iterator",
"get_dataloader",
+ "history_default_spec",
"TopKRewardSelector",
]
diff --git a/torchrl/data/llm/history.py b/torchrl/data/llm/history.py
index 4fa76eafa03..ea22b17afc9 100644
--- a/torchrl/data/llm/history.py
+++ b/torchrl/data/llm/history.py
@@ -2,1501 +2,116 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from __future__ import annotations
-
-import dataclasses
+"""Backward-compatibility re-exports for conversation containers.
-import re
-from typing import Literal, TYPE_CHECKING
-
-import torch
-
-from tensordict import (
- lazy_stack,
- LazyStackedTensorDict,
- list_to_stack,
- TensorClass,
- TensorDict,
-)
-from tensordict.utils import _maybe_correct_neg_dim
-from torchrl._utils import logger as torchrl_logger
+:class:`~tensordict.llm.History`, :class:`~tensordict.llm.ContentBase` and
+:func:`~tensordict.llm.add_chat_template` now live in ``tensordict.llm``,
+which is their canonical home. This module re-exports them so that existing
+``torchrl.data.llm.history`` import paths keep working.
-if TYPE_CHECKING:
- import transformers
+New code should import from :mod:`tensordict.llm` directly.
+The spec-related functionality stays here: tensordict cannot depend on
+torchrl's spec classes, so :func:`history_default_spec` (and the
+backward-compatible ``History.default_spec`` classmethod it powers) are
+defined in this module.
+"""
+from __future__ import annotations
-# Global storage for custom templates and their metadata
-_CHAT_TEMPLATES = {
- "chatml_format": """{%- for message in messages %}
-{%- if message['role'] == 'assistant' -%}
-{{ '<|im_start|>' + message['role'] + '\n' }}{% generation %}{{ message['content'] }}{% endgeneration %}{{ '<|im_end|>\n' }}
-{%- else -%}
-{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
-{%- endif %}
-{%- endfor %}
-{%- if add_generation_prompt -%}
-{{ '<|im_start|>assistant\n' }}
-{%- endif %}
-""",
- "qwen": """
-{%- if tools %}
- {{- '<|im_start|>system\\n' }}
- {%- if messages[0]['role'] == 'system' %}
- {{- messages[0]['content'] }}
- {%- else %}
- {{- 'You are a helpful assistant.' }}
- {%- endif %}
- {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n" }}
- {%- for tool in tools %}
- {{- "\\n" }}
- {{- tool | tojson }}
- {%- endfor %}
- {{- "\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n" }}
-{%- else %}
- {%- if messages[0]['role'] == 'system' %}
- {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}
- {%- else %}
- {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}
- {%- endif %}
-{%- endif %}
-{%- for message in messages %}
- {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
- {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}
- {%- elif (message.role == "assistant" and not message.tool_calls) %}
- {{- '<|im_start|>' + message.role + '\\n' }}{% generation %}{{- message.content }}{% endgeneration %}{{- '<|im_end|>' + '\\n' }}
- {%- elif message.role == "assistant" %}
- {{- '<|im_start|>' + message.role }}
- {%- if message.content %}
- {{- '\\n' }}{% generation %}{{- message.content }}{% endgeneration %}
- {%- endif %}
- {%- for tool_call in message.tool_calls %}
- {% generation %}
- {%- if tool_call.function is defined %}
- {%- set tool_call = tool_call.function %}
- {%- endif %}
- {{- '\\n\\n{\\\"name\\\": \\\"' }}
- {{- tool_call.name }}
- {{- '\\\", \\\"arguments\\\": ' }}
- {{- tool_call.arguments | tojson }}
- {{- '}\\n' }}
- {%- endgeneration %}
- {%- endfor %}
- {{- '<|im_end|>\\n' }}
- {%- elif message.role == "tool" %}
- {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
- {{- '<|im_start|>tool' }}
- {%- endif %}
- {{- '\\n\\n' }}
- {%- if message.tool_responses %}
- {{- message.tool_responses }}
- {%- else %}
- {{- message.content }}
- {%- endif %}
- {{- '\\n' }}
- {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
- {{- '<|im_end|>\\n' }}
- {%- endif %}
- {%- endif %}
-{%- endfor %}
-{%- if add_generation_prompt %}
- {{- '<|im_start|>assistant\\n' }}
-{%- endif %}
-""",
- "dialogpt": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'user' %}{{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' ' }}{% endif %}""",
- "falcon": """{% for message in messages %}{% if message['role'] == 'assistant' %}{{ 'Assistant:' }}{% generation %}{{ ' ' + message['content'] }}{% endgeneration %}\n\n{% elif message['role'] == 'user' %}{{ 'User: ' + message['content'] }}\n\n{% elif message['role'] == 'system' %}{{ message['content'] }}\n\n{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}""",
- "deepseek": """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant:' }}{% generation %}{{ ' ' + message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}""",
- "llama": """{{- bos_token }}
-{%- if messages[0]['role'] == 'system' %}
- {%- set system_message = messages[0]['content']|trim %}
- {%- set messages = messages[1:] %}
-{%- else %}
- {%- set system_message = "" %}
-{%- endif %}
-{%- if system_message %}
- {{- "<|header_start|>system<|header_end|>\n\n" }}
- {{- system_message }}
- {{- "<|eot|>" }}
-{%- endif %}
-{%- for message in messages %}
- {%- if message['role'] == 'assistant' %}
- {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}{% generation %}
- {%- if message['content'] is string %}
- {{- message['content'] }}
- {%- else %}
- {%- for content in message['content'] %}
- {%- if content['type'] == 'text' %}
- {{- content['text'] | trim }}
- {%- endif %}
- {%- endfor %}
- {%- endif %}
- {%- endgeneration %}{{- "<|eot|>" }}
- {%- else %}
- {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
- {%- if message['content'] is string %}
- {{- message['content'] }}
- {%- else %}
- {%- for content in message['content'] %}
- {%- if content['type'] == 'text' %}
- {{- content['text'] | trim }}
- {%- endif %}
- {%- endfor %}
- {%- endif %}
- {{- "<|eot|>" }}
- {%- endif %}
-{%- endfor %}
-{%- if add_generation_prompt %}
- {{- '<|header_start|>assistant<|header_end|>\n\n' }}
-{%- endif %}""",
-}
+import dataclasses
-# Global storage for custom template metadata
-_CUSTOM_INVERSE_PARSERS = {}
-_CUSTOM_MODEL_FAMILY_KEYWORDS = {}
+from tensordict.llm.history import ( # noqa: F401
+ _assistant_content_spans,
+ _CHAT_TEMPLATES,
+ _CUSTOM_INVERSE_PARSERS,
+ _CUSTOM_MODEL_FAMILY_KEYWORDS,
+ _fallback_assistant_tokens_mask,
+ add_chat_template,
+ ContentBase,
+ History,
+)
+__all__ = ["add_chat_template", "ContentBase", "History", "history_default_spec"]
-def _assistant_content_spans(
- rendered: str, conversation: list[dict]
-) -> list[tuple[int, int]]:
- spans = []
- cursor = 0
- for message in conversation:
- if message.get("role") != "assistant":
- continue
- content = message.get("content")
- if not isinstance(content, str) or not content:
- continue
- start = rendered.find(content, cursor)
- if start < 0:
- start = rendered.find(content)
- if start < 0:
- continue
- end = start + len(content)
- spans.append((start, end))
- cursor = end
- return spans
+def _history_default_spec(cls, shape=(-1,)):
+ """A default spec to use in transforms / envs that return History objects.
-def _fallback_assistant_tokens_mask(
- *,
- tokenizer,
- rendered: str,
- conversation: list[dict],
- input_ids: torch.Tensor,
- current_mask: torch.Tensor,
-) -> torch.Tensor:
- if current_mask.any():
- return current_mask
- spans = _assistant_content_spans(rendered, conversation)
- if not spans:
- return current_mask
+ Args:
+ cls (type): The :class:`~tensordict.llm.History` class (or subclass) to build the spec for.
+ shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1,)` (variable length
+ along the time dimension).
- try:
- encoded = tokenizer(
- rendered,
- add_special_tokens=False,
- return_offsets_mapping=True,
- )
- offsets = encoded.get("offset_mapping", None)
- except NotImplementedError:
- offsets = None
- if input_ids.ndim == 2:
- if input_ids.shape[0] != 1:
- target_mask = None
- fallback_mask = None
- elif offsets is None or len(offsets) != input_ids.shape[-1]:
- target_mask = current_mask[0]
- fallback_mask = None
- else:
- target_mask = current_mask[0]
- fallback_mask = torch.zeros_like(target_mask)
- elif input_ids.ndim == 1:
- if offsets is None or len(offsets) != input_ids.shape[-1]:
- target_mask = current_mask
- fallback_mask = None
+ .. seealso:: :func:`~torchrl.data.llm.history_default_spec`.
+ """
+ # Composite/NonTensor cannot be imported at module level: torchrl.data is
+ # still initializing when this module is first imported.
+ from torchrl.data import Composite, NonTensor
+
+ def get_default_value(field):
+ if field.default is not dataclasses.MISSING:
+ return field.default
+ elif field.type in (str, "str"):
+ return "foo"
else:
- target_mask = current_mask
- fallback_mask = torch.zeros_like(target_mask)
- else:
- return current_mask
-
- if fallback_mask is not None:
- for idx, (token_start, token_end) in enumerate(offsets):
- if token_start == token_end:
- continue
- for span_start, span_end in spans:
- if token_start < span_end and token_end > span_start:
- fallback_mask[idx] = 1
- break
-
- if fallback_mask is None or not fallback_mask.any():
- if target_mask is None:
- return current_mask
- fallback_mask = torch.zeros_like(target_mask)
- for span_start, span_end in spans:
- prefix_ids = tokenizer(rendered[:span_start], add_special_tokens=False).get(
- "input_ids"
- )
- prefix_and_content_ids = tokenizer(
- rendered[:span_end], add_special_tokens=False
- ).get("input_ids")
- start_idx = len(prefix_ids)
- end_idx = len(prefix_and_content_ids)
- if start_idx < end_idx <= fallback_mask.shape[-1]:
- fallback_mask[start_idx:end_idx] = 1
+ return None
- if input_ids.ndim == 2:
- current_mask = current_mask.clone()
- current_mask[0] = fallback_mask
- return current_mask
- return fallback_mask
+ defaults = {
+ k: NonTensor(
+ example_data=get_default_value(cls.__dataclass_fields__[k]),
+ shape=shape,
+ )
+ for k in cls.__dataclass_fields__
+ }
+ return Composite(defaults, shape=shape[:-1], data_cls=cls)
-def add_chat_template(
- template_name: str,
- template: str,
- inverse_parser: callable | None = None,
- model_family_keywords: list[str] | None = None,
-) -> None:
- r"""Add a custom chat template to the global template dictionary.
- This function allows you to add custom chat templates for new model families
- that support assistant token masking via the `{% generation %}` keyword.
+def history_default_spec(shape=(-1,)):
+ """A default Composite spec for :class:`~tensordict.llm.History` objects, to use in transforms / envs.
Args:
- template_name (str): The name of the template (e.g., "llama", "mistral").
- This name will be used in the `chat_template_name` parameter of
- `History.apply_chat_template()` and `History.from_text()`.
- template (str): The Jinja2 template string. Must include `{% generation %}`
- blocks around assistant message content to enable token masking.
- inverse_parser (callable, optional): A function that parses formatted text back
- into a History object. Should have signature `(text: str) -> History`.
- If None, a basic parser will be used.
- model_family_keywords (list[str], optional): Keywords to detect this model family
- in the auto-detection logic. For example, ["llama", "meta-llama"] for Llama models.
- If provided, the template will be automatically selected for models containing
- these keywords in their name.
+ shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1,)` (variable length
+ along the time dimension).
Example:
- >>> from torchrl.data.llm.chat import add_chat_template, History
- >>> from transformers import AutoTokenizer
- >>>
- >>> # Add a custom template for Llama models
- >>> llama_template = '''
- ... {% for message in messages %}
- ... {%- if message['role'] == 'user' %}
- ... {{ '[INST] ' + message['content'] + ' [/INST]' }}
- ... {%- elif message['role'] == 'assistant' %}
- ... {% generation %}{{ message['content'] + '' }}{% endgeneration %}
- ... {%- endif %}
- ... {% endfor %}
- ... {%- if add_generation_prompt %}
- ... {% generation %}{{ ' ' }}{% endgeneration %}
- ... {%- endif %}
- ... '''
- >>>
- >>> def parse_llama_text(text: str) -> History:
- ... # Custom parser for Llama format
- ... import re
- ... pattern = r'\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)'
- ... matches = re.findall(pattern, text, re.DOTALL)
- ... messages = []
- ... for user_content, assistant_content in matches:
- ... messages.append(History(role="user", content=user_content.strip()))
- ... messages.append(History(role="assistant", content=assistant_content.strip()))
- ... return lazy_stack(messages)
- >>>
- >>> # Add the template with auto-detection
- >>> add_chat_template(
- ... template_name="llama",
- ... template=llama_template,
- ... inverse_parser=parse_llama_text,
- ... model_family_keywords=["llama", "meta-llama"]
- ... )
- >>>
- >>> # Now you can use it with auto-detection
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
- >>> history = History.from_chats([[
- ... {"role": "user", "content": "Hello"},
- ... {"role": "assistant", "content": "Hi there!"}
- ... ]])
- >>>
- >>> # Auto-detection will use the llama template
- >>> result = history.apply_chat_template(
- ... tokenizer=tokenizer,
- ... add_generation_prompt=False,
- ... return_dict=True,
- ... return_assistant_tokens_mask=True,
- ... )
+ >>> import tensordict
+ >>> from torchrl.data.llm import history_default_spec
+ >>> tensordict.set_list_to_stack(True).set()
>>>
- >>> # Or use it explicitly
- >>> result = history.apply_chat_template(
- ... tokenizer=tokenizer,
- ... chat_template_name="llama",
- ... add_generation_prompt=False,
- ... return_dict=True,
- ... return_assistant_tokens_mask=True,
- ... )
-
- .. note:
- - The template must include `{% generation %}` blocks around assistant message
- content to enable assistant token masking.
- - The inverse parser should handle the specific format of your template.
- - Model family keywords are case-insensitive and matched against the tokenizer's
- `name_or_path` attribute.
- - Templates are stored globally and persist for the duration of the Python session.
- """
- global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS # noqa: F824
-
- # Validate template contains generation blocks
- if "{% generation %}" not in template:
- raise ValueError(
- f"Template '{template_name}' must include '{{% generation %}}' blocks "
- "around assistant message content to enable token masking."
- )
-
- # Add template to dictionary
- _CHAT_TEMPLATES[template_name] = template
-
- # Store inverse parser if provided
- if inverse_parser is not None:
- _CUSTOM_INVERSE_PARSERS[template_name] = inverse_parser
-
- # Store model family keywords if provided
- if model_family_keywords is not None:
- _CUSTOM_MODEL_FAMILY_KEYWORDS[template_name] = model_family_keywords
-
- torchrl_logger.info(
- f"Added custom chat template '{template_name}' with assistant token masking support"
- )
-
-
-# We need the 'shadow' flag to avoid having tensordict complaining about 'type'/'size' etc. fields
-class ContentBase(TensorClass["nocast", "shadow"]):
- """Base class for all message content types.
-
- Attributes:
- type (str): The type of the content.
- text (str, optional): The text content.
- url (str, optional): The URL content.
- data (str, optional): The data content.
- mime_type (str, optional): The MIME type of the content.
- name (str, optional): The name of the content.
- size (int, optional): The size of the content.
- function_name (str, optional): The name of the function.
- function_args (dict, optional): The arguments of the function.
-
- Examples:
- >>> from tensordict import lazy_stack
- >>> content1 = ContentBase(type="text", text="Hello, world!")
- >>> print(content1)
- ContentBase(
- text=NonTensorData(data=Hello, world!, batch_size=torch.Size([]), device=None),
- type=NonTensorData(data=text, batch_size=torch.Size([]), device=None),
- url=None,
- data=None,
- mime_type=None,
- name=None,
- size=None,
- function_name=None,
- function_args=None,
- batch_size=torch.Size([]),
- device=None,
- is_shared=False)
- >>> content2 = ContentBase(type="image", url="https://example.com/image.jpg")
- >>> print(content2)
- ContentBase(
- type=NonTensorData(data=image, batch_size=torch.Size([]), device=None),
- url=NonTensorData(data=https://example.com/image.jpg, batch_size=torch.Size([]), device=None),
- text=None,
- data=None,
- mime_type=None,
- name=None,
- size=None,
- function_name=None,
- function_args=None,
- batch_size=torch.Size([]),
+ >>> spec = history_default_spec()
+ >>> print(spec)
+ Composite(
+ role: NonTensor(
+ shape=torch.Size([-1]),
+ space=None,
+ device=None,
+ dtype=None,
+ domain=None,
+ example_data=foo),
+ content: NonTensor(
+ shape=torch.Size([-1]),
+ space=None,
+ device=None,
+ dtype=None,
+ domain=None,
+ example_data=foo),
device=None,
- is_shared=False)
- >>> content = lazy_stack([content1, content2])
- >>> print(content)
- ContentBase(
- type=NonTensorStack(
- ['text', 'image'],
- batch_size=torch.Size([2]),
- device=None),
- url=None,
- data=None,
- mime_type=None,
- name=None,
- size=None,
- function_name=None,
- function_args=None,
- text=None,
- batch_size=torch.Size([2]),
+ shape=torch.Size([-1]))
+ >>> print(spec.zero())
+ History(
+ content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
+ role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
+ batch_size=torch.Size([1]),
device=None,
is_shared=False)
- >>> # A content is typically used in a History object. Usually, its batch dimension is
- >>> # one dimension greater than the History object.
- >>> history = History(role="user", content=content)
-
- """
-
- type: Literal[
- "text", "image", "audio", "video", "file", "function_call"
- ] # Required: "text", "image", "audio", "video", "file", "function_call"
-
- # Text content
- text: str | None = None
-
- # Media/file content (either URL or data)
- url: str | None = None # HTTP URL to content
- data: str | None = None # Base64 encoded content
- # Metadata
- mime_type: str | None = None # "image/jpeg", "audio/mp3", "application/pdf"
- name: str | None = None # Original filename or description
- size: int | None = None # File size in bytes
-
- # Function calling (for AI agents)
- function_name: str | None = None
- function_args: dict | None = None
-
-
-class History(TensorClass["nocast"]):
- """A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models.
-
- The `History` class provides a centralized API for managing conversational data, offering several advantages over
- traditional list-based approaches:
-
- - Centralized API for conversion to and from string formats, facilitating seamless integration with language models.
- - Efficient methods to append, extend, and reshape history elements, enabling dynamic construction of conversation
- trajectories, especially useful in reinforcement learning environments.
- - Interoperability with the `transformers` API, allowing for easy tokenization and preparation of input data.
- - **Assistant token masking support** across multiple model families for reinforcement learning applications.
-
- **Recent Changes:**
- - **ChatHistory Integration**: History objects are now used within :class:`~torchrl.modules.llm.policies.ChatHistory`
- containers for structured conversation management in LLM environments.
- - **Modular Wrapper Support**: Both vLLMWrapper and TransformersWrapper now use History objects when `input_mode="history"`
- is specified, providing consistent conversation state management.
- - **Environment Integration**: ChatEnv and related environments use History objects for state management and conversation tracking.
-
- .. note:: The `""` role is used to indicate that the element is a placeholder,
- for example when the tool call was not executed but a stack requires a certain number of elements
- per batch to have congruent shapes. The :meth:`~torchrl.data.llm.chat.History.apply_chat_template`
- method will remove the `` role from the history.
-
- **Assistant Token Masking Support:**
-
- The class supports assistant token masking across multiple model families, allowing you to identify which tokens
- in a conversation were generated by the assistant. This is crucial for reinforcement learning applications.
-
- **Supported Model Families:**
-
- - **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support
- - **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format
- - **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format
- - **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format
- - **Other models** (OPT, GPT, MPT, BLOOM, Pythia, Phi, etc.): Default `chatml_format` template
-
- **Example with Assistant Token Masking:**
-
- .. code-block:: python
-
- >>> from torchrl.data.llm.chat import History
- >>> from torchrl.modules.llm.policies import ChatHistory
- >>> from transformers import AutoTokenizer
- >>>
- >>> # Create a conversation history
- >>> history = History.from_chats([[
- ... {"role": "user", "content": "Hello"},
- ... {"role": "assistant", "content": "Hi there!"},
- ... {"role": "user", "content": "How are you?"},
- ... {"role": "assistant", "content": "I'm doing well, thanks!"}
- ... ]])
- >>>
- >>> # Create ChatHistory container for LLM wrapper
- >>> chat_history = ChatHistory(prompt=history)
- >>>
- >>> # Load any supported tokenizer
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
- >>>
- >>> # Apply chat template with assistant token masking
- >>> result = history.apply_chat_template(
- ... tokenizer=tokenizer,
- ... add_generation_prompt=False,
- ... return_dict=True,
- ... return_assistant_tokens_mask=True,
- ... )
- >>>
- >>> # The result contains an assistant_masks tensor
- >>> assistant_masks = result["assistant_masks"]
- >>> print(f"Assistant tokens: {assistant_masks.sum().item()}")
-
- **Integration with LLM Wrappers:**
-
- History objects work seamlessly with the new modular wrapper design:
-
- .. code-block:: python
-
- >>> from torchrl.modules.llm import TransformersWrapper
- >>> from torchrl.modules.llm.policies import ChatHistory
- >>>
- >>> # Create wrapper with history input mode
- >>> wrapper = TransformersWrapper(
- ... model, tokenizer=tokenizer,
- ... input_mode="history",
- ... generate=True,
- ... return_log_probs=True
- ... )
- >>>
- >>> # Use History with ChatHistory container
- >>> history = History.from_chats([[
- ... {"role": "user", "content": "Hello"},
- ... {"role": "assistant", "content": "Hi there!"}
- ... ]])
- >>> chat_history = ChatHistory(prompt=history)
- >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
- >>> print(result["history"].response) # New response from LLM
-
- Attributes:
- role (str): The role of the message sender.
- content (str): The content of the message.
- is_complete (bool): Whether the message was properly terminated with an end token. Defaults to `True`.
- tool_calls (list[dict] | None): Optional list of tool calls in the message.
- tool_responses (list[str] | None): Optional list of tool responses.
-
- Methods:
- apply_chat_template: converts the `History` object to str / tokens.
- append: append one element to the list of items along a given dimension.
- extend: extend the list of items along a given dimension.
-
- Examples:
- >>> # With tensordict < 0.10, we need to tell the lib that lists constitute batches
- >>> import tensordict
- >>> tensordict.set_list_to_stack(True).set()
- >>> import transformers
- >>> history0 = History(
- ... role='system',
- ... content='''CONTENT
- ... This is the setup''',
- ... )
- >>> history1 = History(
- ... role='user',
- ... content='''CONTENT
- ... This is the first user prompt''',
- ... )
- >>> history2 = History(
- ... role='assistant',
- ... content='''CONTENT
- ... This is the second prompt, the first for the assistant.''',
- ... )
- >>> history = torch.stack([history0, history1, history2])
- >>> assert history.role == ['system', 'user', 'assistant']
- >>> tokenizer = transformers.AutoTokenizer.from_pretrained("GPT2")
- >>> # Apply a template to pass the history to an LLM. Note that the output has
- >>> # an additional prompt to elict an answer from the LLM thanks to the 'add_generation_prompt' argument.
- >>> parsed_string = history.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
- >>> parsed_string
- <|im_start|>system
- CONTENT
- This is the setup<|im_end|>
-
- <|im_start|>user
- CONTENT
- This is the first user prompt<|im_end|>
-
- <|im_start|>assistant
- CONTENT
- This is the second prompt, the first for the assistant.<|im_end|>
-
- <|im_start|>assistant
-
- .. seealso::
- :class:`~torchrl.modules.llm.policies.ChatHistory`: Container for managing conversation data in LLM environments.
- :class:`~torchrl.modules.llm.policies.Text`: Container for text data.
- :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
"""
+ return _history_default_spec(History, shape)
- role: str | list[str] | list[list[str]]
- content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[
- list[ContentBase]
- ]
- is_complete: bool = True
- tool_calls: list[dict] | None = None
- tool_responses: list[str] | None = None
-
- def __post_init__(self):
- if not list_to_stack():
- raise RuntimeError(
- "Please set the list_to_stack to True using tensordict.set_list_to_stack(True).set() at the beginning of your script, "
- "or the LIST_TO_STACK=1 environment variable."
- )
-
- def apply_chat_template(
- self,
- *,
- tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa
- add_generation_prompt: bool = True,
- chat_template: str | None = None,
- chat_template_name: str | None = None,
- continue_final_message: bool = False,
- tokenize: bool | None = None,
- padding: bool | str = False,
- truncation: bool | str = False,
- return_tensors: str | None = None,
- return_dict: bool | None = None,
- return_assistant_tokens_mask: bool = False,
- **kwargs,
- ) -> str | list[str] | TensorDict:
- """Applies a chat template to the history.
-
- Keyword Args:
- tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
- add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
- chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
- chat_template_name (str, optional): The name of the chat template to use.
- Prevalent over `tokenizer.chat_template`. If `None`, the method will automatically detect the model family and use the appropriate template.
- Defaults to `None`.
- continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
- tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`.
- padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
- truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`.
- return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
- return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
- return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
- If `True`, the mask will be written to the `assistant_masks` key.
- For tokens generated by the assistant, the mask will contain `1`.
- For user and system tokens, the mask will contain `0`.
- This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
- Defaults to `False`.
-
- .. note:: Assistant token masking is supported across multiple model families:
- - **Qwen family**: Uses custom template with full tool calling support
- - **DialoGPT family**: Uses custom template for conversation format
- - **Falcon family**: Uses custom template for instruction format
- - **DeepSeek family**: Uses custom template with native format
- - **Other models**: Use the default `chatml_format` template
-
- The method automatically detects the model family and selects the appropriate template.
-
- **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
-
- Returns:
- The formatted history.
- """
- if chat_template is None:
- if chat_template_name is not None:
- chat_template = _CHAT_TEMPLATES[chat_template_name]
- chat_template_name = None
- elif tokenizer is None:
- raise RuntimeError(
- "You must specify a tokenizer to use when chat_template is not specified."
- )
- else:
- # Auto-detect model family and use appropriate template
- model_name = getattr(tokenizer, "name_or_path", "").lower()
-
- # First check for custom model family keywords
- custom_template_found = False
- for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
- if any(keyword.lower() in model_name for keyword in keywords):
- chat_template = _CHAT_TEMPLATES[template_name]
- chat_template_name = None
- custom_template_found = True
- break
-
- if not custom_template_found:
- # Fall back to built-in model family detection
- if "qwen" in model_name:
- # We prefer our implementation of the Qwen template,
- # since it accounts for the assistant's masking.
- chat_template = _CHAT_TEMPLATES["qwen"]
- chat_template_name = None
- elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
- # DialoGPT family - use our custom template
- chat_template = _CHAT_TEMPLATES["dialogpt"]
- chat_template_name = None
- elif "falcon" in model_name or "tiiuae/falcon" in model_name:
- # Falcon family - use our custom template
- chat_template = _CHAT_TEMPLATES["falcon"]
- chat_template_name = None
- elif "deepseek" in model_name:
- # DeepSeek family - use our custom template with generation keyword
- chat_template = _CHAT_TEMPLATES["deepseek"]
- chat_template_name = None
- elif "llama" in model_name:
- # Llama family - use our custom template
- chat_template = _CHAT_TEMPLATES["llama"]
- chat_template_name = None
- else:
- # For other models, check if their default template supports generation
- if (
- hasattr(tokenizer, "chat_template")
- and tokenizer.chat_template
- and "{% generation %}" in tokenizer.chat_template
- ):
- # Use the model's own template if it supports generation
- chat_template = tokenizer.chat_template
- else:
- # Use our default chatml_format template
- chat_template = _CHAT_TEMPLATES["chatml_format"]
- if chat_template is None:
- chat_template = _CHAT_TEMPLATES["chatml_format"]
- if tokenize is None:
- if return_assistant_tokens_mask or return_tensors is not None:
- tokenize = True
- else:
- tokenize = False
- if tokenize:
- if return_tensors is None:
- return_tensors = "pt"
- if return_dict is None and return_assistant_tokens_mask:
- return_dict = True
- elif return_dict is None:
- return_dict = False
-
- if self.ndim > 1:
- result = [
- self[i].apply_chat_template(
- tokenizer=tokenizer,
- add_generation_prompt=add_generation_prompt,
- chat_template=chat_template,
- chat_template_name=chat_template_name,
- tokenize=tokenize,
- padding=padding,
- truncation=truncation,
- return_tensors=return_tensors,
- continue_final_message=continue_final_message,
- return_dict=return_dict,
- return_assistant_tokens_mask=return_assistant_tokens_mask,
- **kwargs,
- )
- for i in range(self.batch_size[0])
- ]
- if return_dict:
- return lazy_stack(result)
- else:
- return result
- self_flat = self.view(-1)
- # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts
- self_flat = self_flat.tolist(tolist_first=True)
- # Remove the "" role
- self_flat = [item for item in self_flat if item["role"] != ""]
- result = tokenizer.apply_chat_template(
- conversation=self_flat,
- add_generation_prompt=add_generation_prompt,
- chat_template=chat_template,
- tokenize=tokenize,
- padding=padding,
- truncation=truncation,
- return_tensors=return_tensors,
- continue_final_message=continue_final_message,
- return_dict=return_dict,
- return_assistant_tokens_mask=return_assistant_tokens_mask,
- **kwargs,
- )
- if (
- return_assistant_tokens_mask
- and not isinstance(result, (torch.Tensor, list, str))
- and "assistant_masks" in result
- and "input_ids" in result
- ):
- rendered = tokenizer.apply_chat_template(
- conversation=self_flat,
- add_generation_prompt=add_generation_prompt,
- chat_template=chat_template,
- tokenize=False,
- continue_final_message=continue_final_message,
- )
- if (
- isinstance(rendered, str)
- and isinstance(result["assistant_masks"], torch.Tensor)
- and isinstance(result["input_ids"], torch.Tensor)
- ):
- result["assistant_masks"] = _fallback_assistant_tokens_mask(
- tokenizer=tokenizer,
- rendered=rendered,
- conversation=self_flat,
- input_ids=result["input_ids"],
- current_mask=result["assistant_masks"],
- )
- if not isinstance(result, (torch.Tensor, list, str)):
- result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1)
- # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result
- if self.batch_dims == 1:
- if result.batch_size[0] != 1:
- raise RuntimeError(
- f"Expected a batch size of 1, got {result.batch_size[0]}."
- )
- result = result.squeeze(0)
- return result
- @classmethod
- def from_text(
- cls,
- text: str | list[str],
- chat_template_name: str | None = None,
- # currently without effect
- chat_template: str | None = None,
- tokenizer: transformers.AutoTokenizer # noqa: F821
- | transformers.AutoProcessor # noqa: F821
- | None = None,
- ) -> History:
- r"""Inverts a chat template into a History object.
-
- Args:
- text (str | list[str]): The chat template to invert.
- chat_template_name (str, optional): The name of the chat template to use.
- tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use.
-
- Returns:
- History: The inverted History object.
-
- Examples:
- >>> from torchrl.data.llm.history import History
- >>> from transformers import AutoTokenizer
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
- >>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris, the capital of Germany is Berlin.\n\n"
- >>> history = History.from_text(text, tokenizer=tokenizer)
- >>> print(history)
- History(
- content=NonTensorStack(
- ['You are a helpful assistant.', 'Write a python s...,
- batch_size=torch.Size([3]),
- device=None),
- is_complete=NonTensorStack(
- [True, True, False],
- batch_size=torch.Size([3]),
- device=None),
- role=NonTensorStack(
- ['system', 'user', 'assistant'],
- batch_size=torch.Size([3]),
- device=None),
- tool_calls=None,
- tool_responses=None,
- batch_size=torch.Size([3]),
- device=None,
- is_shared=False)
- """
- if chat_template_name is None:
- if chat_template is not None:
- # TODO: find best match given template
- pass
-
- model_name = getattr(tokenizer, "name_or_path", "").lower()
- # First check for custom model family keywords
- custom_template_found = False
- for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
- if any(keyword.lower() in model_name for keyword in keywords):
- chat_template_name = template_name
- custom_template_found = True
- break
-
- if not custom_template_found:
- # Fall back to built-in model family detection
- if "qwen" in model_name:
- # We can automatically detect the template name from the tokenizer
- # and use the precoded parser.
- chat_template_name = "qwen"
- elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
- chat_template_name = "dialogpt"
- elif "falcon" in model_name or "tiiuae/falcon" in model_name:
- chat_template_name = "falcon"
- elif "deepseek" in model_name:
- chat_template_name = "deepseek"
- elif "llama" in model_name:
- chat_template_name = "llama"
- else:
- chat_template_name = "chatml_format"
-
- # Get the appropriate inverse parser function
- if chat_template_name in ("chatml_format",):
- func = cls._inv_chatml
- elif chat_template_name in ("qwen",):
- func = cls._inv_qwen
- elif chat_template_name in ("dialogpt",):
- func = cls._inv_dialogpt
- elif chat_template_name in ("falcon",):
- func = cls._inv_falcon
- elif chat_template_name in ("deepseek",):
- func = cls._inv_deepseek
- elif chat_template_name in ("llama",):
- func = cls._inv_llama
- elif chat_template_name in _CUSTOM_INVERSE_PARSERS:
- # Use custom inverse parser
- func = _CUSTOM_INVERSE_PARSERS[chat_template_name]
- else:
- raise NotImplementedError(
- f"chat_template_name '{chat_template_name}' is not supported. "
- "Supported templates: 'chatml_format', 'qwen', 'dialogpt', 'falcon', 'deepseek'. "
- "Use add_chat_template() to add custom templates."
- )
- if isinstance(text, list):
- list_of_histories = [func(t) for t in text]
- try:
- return lazy_stack(list_of_histories)
- except RuntimeError as e:
- raise RuntimeError(
- f"Failed to stack histories: {list_of_histories=}"
- ) from e
- return func(text)
-
- @classmethod
- def _inv_chatml(cls, text: str) -> History:
- """Inverts a chatml string into a History object.
-
- Args:
- text (str): The chatml string to invert.
-
- Returns:
- History: The inverted History object.
- """
- import json
-
- torchrl_logger.debug(f"Inverting chatml:\n{text}")
- # Find all complete blocks (ending with im_end or endoftext)
- complete_pattern = r"<\|im_start\|>(.*?)\n(.*?)<\|(im_end|endoftext)\|>"
- complete_matches = re.findall(complete_pattern, text, flags=re.DOTALL)
-
- # Find any incomplete block at the end
- incomplete_pattern = r"<\|im_start\|>(.*?)\n(.*?)$"
- incomplete_matches = []
- if complete_matches:
- # Look for incomplete block after the last complete one
- last_complete = complete_matches[-1]
- last_complete_text = f"<|im_start|>{last_complete[0]}\n{last_complete[1]}<|{last_complete[2]}|>"
- remaining_text = text[
- text.rindex(last_complete_text) + len(last_complete_text) :
- ]
- if remaining_text.strip():
- incomplete_match = re.search(
- incomplete_pattern, remaining_text, flags=re.DOTALL
- )
- if incomplete_match:
- incomplete_matches = [
- (incomplete_match.group(1), incomplete_match.group(2), None)
- ]
- else:
- # No complete blocks, check entire text for incomplete block
- incomplete_match = re.search(incomplete_pattern, text, flags=re.DOTALL)
- if incomplete_match:
- incomplete_matches = [
- (incomplete_match.group(1), incomplete_match.group(2), None)
- ]
-
- # Combine complete and incomplete matches
- matches = complete_matches + incomplete_matches
-
- # Define tool patterns - same as Qwen for consistency
- tool_call_pattern = re.compile(r"\n(.*?)\n", re.DOTALL)
- tool_response_pattern = re.compile(
- r"\n(.*?)\n", re.DOTALL
- )
-
- parsed_messages = []
- for match in matches:
- role = match[0].strip()
- content = match[1].strip()
- is_complete = match[2] is not None # None indicates incomplete
-
- # Initialize message dict
- message_dict = {
- "role": role,
- "content": content,
- "is_complete": is_complete,
- "tool_calls": None,
- "tool_responses": None,
- }
-
- # Find tool calls within the message
- tool_calls = tool_call_pattern.findall(content)
- if tool_calls:
- tool_calls_list = []
- for tool_call in tool_calls:
- try:
- tool_call_dict = json.loads(tool_call)
- tool_calls_list.append(tool_call_dict)
- except json.JSONDecodeError:
- continue
- if tool_calls_list:
- message_dict["tool_calls"] = tool_calls_list
-
- # Check for tool responses
- tool_responses = tool_response_pattern.findall(content)
- if tool_responses:
- message_dict["tool_responses"] = tool_responses
-
- parsed_messages.append(cls(**message_dict))
-
- if not parsed_messages:
- raise RuntimeError(
- f"Couldn't get a single item out of text {text}. A common cause "
- f"if that special tokens should not be omitted, did you set include_stop_str_in_output/skip_special_tokens=False?"
- )
-
- return lazy_stack(parsed_messages)
-
- @classmethod
- def _inv_qwen(cls, template):
- import json
-
- # Define regex patterns for different parts of the template
- message_pattern = re.compile(
- r"<\|im_start\|>(.*?)(?:<\|(im_end|endoftext)\|>|$)", re.DOTALL
- )
- tool_call_pattern = re.compile(r"\n(.*?)\n", re.DOTALL)
- tool_response_pattern = re.compile(
- r"\n(.*?)\n", re.DOTALL
- )
-
- # Find all messages and track if they end with a proper token
- messages = []
- is_complete_list = []
- for match in message_pattern.finditer(template):
- full_match = match.group(0)
- messages.append(match.group(1))
- # Check if the message ends with a proper token
- is_complete_list.append(
- full_match.endswith("<|im_end|>")
- or full_match.endswith("<|endoftext|>")
- )
-
- parsed_messages = []
- for message, is_complete in zip(messages, is_complete_list):
- # Split the message into role and content
- parts = message.split("\n", 1)
- if len(parts) < 2:
- continue
- role, content = parts[0], parts[1]
-
- # Initialize message dict
- message_dict = {
- "role": role.strip(),
- "content": content.strip(),
- "is_complete": is_complete,
- "tool_calls": None,
- "tool_responses": None,
- }
-
- # Find tool calls within the message
- tool_calls = tool_call_pattern.findall(content)
- if tool_calls:
- tool_calls_list = []
- for tool_call in tool_calls:
- try:
- tool_call_dict = json.loads(tool_call)
- tool_calls_list.append(tool_call_dict)
- except json.JSONDecodeError:
- continue
- if tool_calls_list:
- message_dict["tool_calls"] = tool_calls_list
-
- # Check for tool responses
- tool_responses = tool_response_pattern.findall(content)
- if tool_responses:
- message_dict["tool_responses"] = tool_responses
-
- parsed_messages.append(cls(**message_dict))
-
- if not parsed_messages:
- raise RuntimeError(
- f"Couldn't get a single item out of text {template}. A common cause "
- f"if that special tokens should not be omitted, did you set include_stop_str_in_output/skip_special_tokens=False?"
- )
-
- return lazy_stack(parsed_messages)
-
- @classmethod
- def _inv_dialogpt(cls, text: str) -> History:
- """Inverts a DialogPT string into a History object.
-
- Args:
- text (str): The DialogPT string to invert.
-
- Returns:
- History: The inverted History object.
- """
- torchrl_logger.debug(f"Inverting DialogPT:\n{text}")
-
- # DialogPT format is simple: alternating user/assistant messages
- # Split by lines and parse
- lines = text.strip().split("\n")
- parsed_messages = []
-
- for line in lines:
- line = line.strip()
- if not line:
- continue
-
- # Determine role based on content
- if line.startswith("Assistant:"):
- role = "assistant"
- content = line[len("Assistant:") :].strip()
- elif line.startswith("User:"):
- role = "user"
- content = line[len("User:") :].strip()
- else:
- # Default to user if no prefix
- role = "user"
- content = line
-
- message_dict = {
- "role": role,
- "content": content,
- "is_complete": True, # DialogPT doesn't have explicit end tokens
- "tool_calls": None,
- "tool_responses": None,
- }
-
- parsed_messages.append(cls(**message_dict))
-
- if not parsed_messages:
- raise RuntimeError(f"Couldn't get a single item out of text {text}.")
-
- return lazy_stack(parsed_messages)
-
- @classmethod
- def _inv_falcon(cls, text: str) -> History:
- """Inverts a Falcon string into a History object.
-
- Args:
- text (str): The Falcon string to invert.
-
- Returns:
- History: The inverted History object.
- """
- torchrl_logger.debug(f"Inverting Falcon:\n{text}")
-
- # Falcon format: "User: ... Assistant: ..."
- # Split by "User:" and "Assistant:" prefixes
- import re
-
- # Pattern to match User: and Assistant: messages
- pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)"
- matches = re.findall(pattern, text, re.DOTALL)
-
- parsed_messages = []
- for match in matches:
- if len(match) != 2:
- continue
- prefix, content = match
- content = content.strip()
- if not content:
- continue
-
- if prefix == "User:":
- role = "user"
- elif prefix == "Assistant:":
- role = "assistant"
- else:
- continue
-
- message_dict = {
- "role": role,
- "content": content,
- "is_complete": True, # Falcon doesn't have explicit end tokens
- "tool_calls": None,
- "tool_responses": None,
- }
-
- parsed_messages.append(cls(**message_dict))
-
- if not parsed_messages:
- raise RuntimeError(f"Couldn't get a single item out of text {text}.")
-
- return lazy_stack(parsed_messages)
-
- @classmethod
- def _inv_deepseek(cls, text: str) -> History:
- """Inverts a DeepSeek string into a History object.
-
- Args:
- text (str): The DeepSeek string to invert.
-
- Returns:
- History: The inverted History object.
- """
- torchrl_logger.debug(f"Inverting DeepSeek:\n{text}")
- import re
-
- # Remove leading/trailing special tokens (e.g.
- text = re.sub(r"^<[^>]+>", "", text) # Remove leading <...>
- text = re.sub(r"<[^>]+>$", "", text) # Remove trailing <...>
- # Remove any REDACTED_SPECIAL_TOKEN if present
- text = re.sub(r"REDACTED_SPECIAL_TOKEN", "", text)
- # Pattern to match User: and Assistant: messages
- pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)"
- matches = re.findall(pattern, text, re.DOTALL)
- parsed_messages = []
- for match in matches:
- if len(match) < 2:
- continue
- prefix, content = match[0], match[1]
- content = content.strip()
- if not content:
- continue
- if prefix == "User:":
- role = "user"
- elif prefix == "Assistant:":
- role = "assistant"
- else:
- continue
- message_dict = {
- "role": role,
- "content": content,
- "is_complete": True, # DeepSeek doesn't have explicit end tokens
- "tool_calls": None,
- "tool_responses": None,
- }
- parsed_messages.append(cls(**message_dict))
- if not parsed_messages:
- raise RuntimeError(f"Couldn't get a single item out of text {text}.")
- return lazy_stack(parsed_messages)
-
- @classmethod
- def _inv_llama(cls, text: str) -> History:
- import re
-
- messages = []
-
- # Remove BOS token if present
- if text.startswith("<|begin_of_text|>"):
- text = text[len("<|begin_of_text|>") :]
-
- # Pattern to match complete message blocks: <|header_start|>role<|header_end|>\n\ncontent<|eot|>
- complete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)<\|eot\|>"
- complete_matches = re.findall(complete_pattern, text, re.DOTALL)
-
- # Pattern to match incomplete message blocks: <|header_start|>role<|header_end|>\n\ncontent (without <|eot|>)
- incomplete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)$"
-
- # Find any incomplete message at the end
- incomplete_matches = []
- if complete_matches:
- # Look for incomplete message after the last complete one
- last_complete_end = text.rfind("<|eot|>")
- if last_complete_end != -1:
- remaining_text = text[last_complete_end + len("<|eot|>") :]
- if remaining_text.strip():
- incomplete_match = re.search(
- incomplete_pattern, remaining_text, re.DOTALL
- )
- if incomplete_match:
- incomplete_matches = [
- (
- incomplete_match.group(1),
- incomplete_match.group(2),
- False,
- )
- ]
- else:
- # No complete messages, check entire text for incomplete message
- incomplete_match = re.search(incomplete_pattern, text, re.DOTALL)
- if incomplete_match:
- incomplete_matches = [
- (incomplete_match.group(1), incomplete_match.group(2), False)
- ]
-
- # Process complete messages
- for role, content in complete_matches:
- if content.strip():
- messages.append(
- cls(role=role, content=content.strip(), is_complete=True)
- )
-
- # Process incomplete messages
- for role, content, is_complete in incomplete_matches:
- if content.strip():
- messages.append(
- cls(role=role, content=content.strip(), is_complete=is_complete)
- )
-
- if not messages:
- raise RuntimeError(f"Couldn't parse Llama format from text: {text}")
-
- from tensordict import lazy_stack
-
- return lazy_stack(messages)
-
- def append(
- self, history: History, *, inplace: bool = True, dim: int = -1
- ) -> History:
- """Appends a new history to the current one.
-
- Args:
- history (History): The new history to append.
- inplace (bool, optional): Whether to perform the operation in-place. Defaults to `True`.
- dim (int, optional): The dimension to append along. Defaults to -1.
-
- Returns:
- History: The appended History object.
- """
- # TODO: we should remove the role from the history before appending / extending
- # It works when keeping them, but it may lead to a lot of useless padding in between valid messages
- if not self.batch_dims:
- raise RuntimeError(
- "Cannot append an element to a batchless History. Call unsqueeze(dim=0) first on self."
- )
- if self.batch_dims != history.batch_dims + 1:
- raise RuntimeError(
- f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}."
- )
- dim = _maybe_correct_neg_dim(dim, self.batch_size)
- if inplace:
- if (
- isinstance(self._tensordict, LazyStackedTensorDict)
- and self._tensordict.stack_dim == dim
- ):
- td = history._tensordict
- if td.device != self.device:
- if self.device is None:
- td = td.copy().clear_device_()
- else:
- td = td.to(self.device)
- self._tensordict.append(td)
- return self
- else:
- td = history._tensordict
- if td.device != self.device:
- if self.device is None:
- td = td.copy().clear_device_()
- else:
- td = td.to(self.device)
- td = lazy_stack(list(self._tensordict.unbind(dim)) + [td], dim=dim)
- self.__dict__["_tensordict"] = td
- return self
- if history.device != self.device:
- if self.device is None:
- history = history.copy().clear_device_()
- else:
- history = history.to(self.device)
- return lazy_stack(list(self.unbind(dim)) + [history], dim=dim)
-
- def extend(
- self, history: History, *, inplace: bool = True, dim: int = 0
- ) -> History:
- if not self.batch_dims:
- raise RuntimeError(
- "Cannot add an element to a batchless History. Call unsqueeze(dim=0) first on self."
- )
- if self.batch_dims != history.batch_dims:
- raise RuntimeError(
- f"The new history to extend must have as many dimensions as self. Got self.ndim={self.ndim} and history.ndim={self.ndim}."
- )
- dim = _maybe_correct_neg_dim(dim, self.batch_size)
- # if self.ndim > 1 and dim >= self.ndim - 1:
- # # then we need to append each element independently
- # result = []
- # for hist, new_hist in zip(self.unbind(0), history.unbind(0)):
- # hist_c = hist.extend(new_hist, inplace=inplace, dim=dim - 1)
- # result.append(hist_c)
- # if inplace:
- # return self
- # return lazy_stack(result)
- if inplace:
- if (
- isinstance(self._tensordict, LazyStackedTensorDict)
- and self._tensordict.stack_dim == dim
- ):
- td = history._tensordict
- if td.device != self.device:
- if self.device is None:
- td = td.copy().clear_device_()
- else:
- td = td.to(self.device)
- self._tensordict.extend(td)
- return self
- else:
- td = lazy_stack(
- list(self._tensordict.unbind(dim))
- + list(history._tensordict.unbind(dim)),
- dim=dim,
- )
- if td.device != self.device:
- if self.device is None:
- td = td.copy().clear_device_()
- else:
- td = td.to(self.device)
- self.__dict__["_tensordict"] = td
- return self
- if history.device != self.device:
- if self.device is None:
- history = history.copy().clear_device_()
- else:
- history = history.to(self.device)
- return torch.stack(list(self.unbind(dim)) + list(history.unbind(dim)), dim=dim)
-
- @classmethod
- def default_spec(cls, shape=(-1,)):
- """A default spec to use in transforms / envs that return History objects.
-
- Args:
- shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1)` (variable length
- along the time dimension).
-
- Example:
- >>> import tensordict
- >>> from torchrl.data import History
- >>> tensordict.set_list_to_stack(True).set()
- >>>
- >>> history = History(role=["system", "user"], content=["a message", "another message"], batch_size=(2,))
- >>> spec = history.default_spec()
- >>> print(spec)
- Composite(
- role: NonTensor(
- shape=torch.Size([-1]),
- space=None,
- device=None,
- dtype=None,
- domain=None,
- example_data=foo),
- content: NonTensor(
- shape=torch.Size([-1]),
- space=None,
- device=None,
- dtype=None,
- domain=None,
- example_data=foo),
- device=None,
- shape=torch.Size([-1]))
- >>> print(spec.zero())
- History(
- content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
- role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
- batch_size=torch.Size([1]),
- device=None,
- is_shared=False)
-
- """
- from torchrl.data import Composite, NonTensor
-
- def get_default_value(field):
- if field.default is not dataclasses.MISSING:
- return field.default
- elif field.type in (str, "str"):
- return "foo"
- else:
- return None
-
- defaults = {
- k: NonTensor(
- example_data=get_default_value(cls.__dataclass_fields__[k]),
- shape=shape,
- )
- for k in cls.__dataclass_fields__
- }
-
- return Composite(defaults, shape=shape[:-1], data_cls=cls)
-
- @classmethod
- def from_chats(cls, chats: list[list[dict]]) -> History:
- """Create a History object from a list of chats.
-
- Args:
- chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries.
- """
- if isinstance(chats[0], dict):
- return lazy_stack([cls(**chat) for chat in chats])
- else:
- return lazy_stack([cls.from_chats(chat) for chat in chats])
+# `History` lives in tensordict, which cannot depend on torchrl's specs.
+# Attach the `default_spec` classmethod here so that the established
+# `History.default_spec()` API (mirroring ChatHistory/Text/Tokens) keeps
+# working once torchrl is imported.
+History.default_spec = classmethod(_history_default_spec)
diff --git a/torchrl/envs/llm/libs/mlgym.py b/torchrl/envs/llm/libs/mlgym.py
index 8527f069ccf..eec46702e89 100644
--- a/torchrl/envs/llm/libs/mlgym.py
+++ b/torchrl/envs/llm/libs/mlgym.py
@@ -23,7 +23,7 @@
from torchrl._utils import logger as torchrl_logger
from torchrl.data import Choice, Composite, NonTensor
-from torchrl.data.llm import History
+from torchrl.data.llm import History, history_default_spec
from torchrl.envs import ConditionalSkip, GymWrapper, Transform, TransformedEnv
if TYPE_CHECKING:
@@ -229,7 +229,7 @@ def set_environment_vars(
env.add_commands(command_files)
def transform_observation_spec(self, observation_spec: Composite) -> Composite:
- observation_spec["history"] = History.default_spec()
+ observation_spec["history"] = history_default_spec()
return observation_spec
def transform_action_spec(self, action_spec: Composite) -> Composite:
@@ -247,7 +247,7 @@ def transform_action_spec(self, action_spec: Composite) -> Composite:
)
def transform_state_spec(self, state_spec: Composite) -> Composite:
- state_spec["history"] = History.default_spec()
+ state_spec["history"] = history_default_spec()
return state_spec
diff --git a/torchrl/modules/llm/policies/common.py b/torchrl/modules/llm/policies/common.py
index 461310a3363..4174f30e8d3 100644
--- a/torchrl/modules/llm/policies/common.py
+++ b/torchrl/modules/llm/policies/common.py
@@ -21,7 +21,7 @@
from torch import distributions as D
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
-from torchrl.data.llm import History
+from torchrl.data.llm import History, history_default_spec
from torchrl.data.tensor_specs import Unbounded
from torchrl.modules.distributions.discrete import LLMMaskedCategorical
@@ -276,7 +276,7 @@ def default_spec(
if keys is None:
keys = ["prompt", "response", "full"]
return Composite(
- {k: History.default_spec(shape=shape + (-1,)) for k in keys},
+ {k: history_default_spec(shape + (-1,)) for k in keys},
shape=shape[:-1],
data_cls=cls,
step_mdp_static=True,