diff --git a/configs/instructblip/README.md b/configs/instructblip/README.md
new file mode 100644
index 00000000000..8a324e3c67a
--- /dev/null
+++ b/configs/instructblip/README.md
@@ -0,0 +1,53 @@
+# InstructBLIP
+
+> [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500)
+
+
+
+## Abstract
+
+Large-scale pre-training and instruction tuning have been successful at creating general-purpose language models with broad competence. However, building general-purpose vision-language models is challenging due to the rich input distributions and task diversity resulting from the additional visual input. Although
+vision-language pretraining has been widely studied, vision-language instruction tuning remains under-explored. In this paper, we conduct a systematic and comprehensive study on vision-language instruction tuning based on the pretrained BLIP-2 models. We gather 26 publicly available datasets, covering a wide variety of tasks and capabilities, and transform them into instruction tuning format. Additionally, we introduce an instruction-aware Query Transformer, which extracts informative features tailored to the given instruction. Trained on 13 held-in datasets, InstructBLIP attains state-of-the-art zero-shot performance across all 13 held-out datasets, substantially outperforming BLIP-2 and larger Flamingo models. Our models also lead to state-of-the-art performance when finetuned on individual downstream tasks (e.g., 90.7% accuracy on ScienceQA questions with image contexts). Furthermore, we qualitatively demonstrate the advantages of InstructBLIP over concurrent multimodal models. All InstructBLIP models are open-sourced.
+
+
+
+
+
+
+## How to use it?
+
+
+
+**Use the model**
+
+```python
+from mmpretrain import inference_model
+
+result = inference_model('instructblip-vicuna7b_3rdparty-zeroshot_caption', 'demo/cat-dog.png')
+print(result)
+# {'pred_caption': 'a blanket next to each other in the grass\na cute puppy and kitten wallpapers'}
+```
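+
+If you caption many images, rebuilding the model on every call is wasteful. The sketch below keeps the model in memory by using MMPretrain's `ImageCaptionInferencer`; it assumes the inferencer resolves this model name the same way `inference_model` does.
+
+```python
+from mmpretrain import ImageCaptionInferencer
+
+# Build the model once and reuse it for repeated calls.
+inferencer = ImageCaptionInferencer(
+    model='instructblip-vicuna7b_3rdparty-zeroshot_caption', device='cuda')
+results = inferencer(['demo/cat-dog.png'], batch_size=1)
+print(results[0]['pred_caption'])
+```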
+
+
+
+## Models and results
+
+For the Vicuna model, please refer to the [MiniGPT-4 page](https://github.com/Vision-CAIR/MiniGPT-4) for guidelines on preparing the weights, then point the config at your local copy as in the sketch below.
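+
+The config hard-codes an internal cluster path for the Vicuna weights. Below is a minimal sketch of redirecting it at runtime with MMEngine's `Config`; the paths are placeholders you should replace with your own:
+
+```python
+from mmengine.config import Config
+
+cfg = Config.fromfile('configs/instructblip/instructblip-vicuna7b_8xb32_caption.py')
+# The tokenizer and the language model must point at the same Vicuna-7B folder.
+cfg.model.llm_tokenizer.name_or_path = '/path/to/vicuna-7b'  # placeholder
+cfg.model.text_backbone.name_or_path = '/path/to/vicuna-7b'  # placeholder
+```
+
+Alternatively, simply edit the two `name_or_path` fields in the config file directly.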
+
+### Pretrained models
+
+| Model | Params (M) | Flops (G) | Config | Download |
+| :-------------------------------------------------- | :--------: | :-------: | :----------------------------------------------: | :--------------------------------------------------------------------------------: |
+| `instructblip-vicuna7b_3rdparty-zeroshot_caption`\* | 8121.32 | N/A | [config](instructblip-vicuna7b_8xb32_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/instructblip/instruct-blip_vicuna7b_trimmed.pth) |
+
+*Models with * are converted from the [official repo](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip). The config files of these models are only for inference. We haven't reproduced the training results.*
+
+## Citation
+
+```bibtex
+@article{dai2023instructblip,
+ title={InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning},
+ author={Dai, Wenliang and Li, Junnan and Li, Dongxu and Tiong, Anthony Meng Huat and Zhao, Junqi and Wang, Weisheng and Li, Boyang and Fung, Pascale and Hoi, Steven},
+ journal={arXiv preprint arXiv:2305.06500},
+ year={2023}
+}
+```
diff --git a/configs/instructblip/instructblip-vicuna7b_8xb32_caption.py b/configs/instructblip/instructblip-vicuna7b_8xb32_caption.py
new file mode 100644
index 00000000000..5ce024c48d1
--- /dev/null
+++ b/configs/instructblip/instructblip-vicuna7b_8xb32_caption.py
@@ -0,0 +1,77 @@
+_base_ = [
+ '../_base_/datasets/coco_caption.py',
+ '../_base_/default_runtime.py',
+]
+
+# model settings
+model = dict(
+ type='InstructBlipCaption',
+ llm_tokenizer=dict(
+ type='LlamaTokenizer',
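+ # NOTE: replace this internal path with the path to your local Vicuna-7B weights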
+ name_or_path=
+ '/mnt/petrelfs/share_data/liuyuan/llm_weights/vicuna_weights_7b'),
+ vision_encoder=dict(
+ type='BEiTViT',
+ # eva-g without the final layer
+ arch=dict(
+ embed_dims=1408,
+ num_layers=39,
+ num_heads=16,
+ feedforward_channels=6144,
+ ),
+ img_size=224,
+ patch_size=14,
+ out_indices=-2,
+ layer_scale_init_value=0.0,
+ use_abs_pos_emb=True,
+ use_rel_pos_bias=False,
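+ # freeze the whole EVA-g backbone (all 39 layers)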
+ frozen_stages=39,
+ final_norm=False,
+ use_shared_rel_pos_bias=False,
+ out_type='raw',
+ pretrained= # noqa
+ 'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth' # noqa
+ ),
+ text_backbone=dict(
+ type='AutoModelForCausalLM',
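+ # NOTE: point this at the same local Vicuna-7B weights as the tokenizer above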
+ name_or_path=
+ '/mnt/petrelfs/share_data/liuyuan/llm_weights/vicuna_weights_7b'),
+ Qformer=dict(
+ type='Qformer',
+ model_style='bert-base-uncased',
+ vision_model_width=1408,
+ add_cross_attention=True,
+ cross_attention_freq=2,
+ num_query_token=32),
+ prompt='Write a short description for the image.',
+ max_txt_len=30)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))
+
+param_scheduler = [
+ dict(
+ type='CosineAnnealingLR',
+ by_epoch=True,
+ begin=0,
+ end=10,
+ )
+]
+
+train_cfg = dict(max_epochs=10)
+val_cfg = dict()
+test_cfg = dict()
+
+# dataset settings
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='Resize',
+ scale=(224, 224),
+ interpolation='bicubic',
+ backend='pillow'),
+ dict(type='PackInputs', meta_keys=['image_id']),
+]
+
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
diff --git a/configs/instructblip/metafile.yml b/configs/instructblip/metafile.yml
new file mode 100644
index 00000000000..345269eeed0
--- /dev/null
+++ b/configs/instructblip/metafile.yml
@@ -0,0 +1,33 @@
+Collections:
+ - Name: InstructBLIP
+ Metadata:
+ Training Data:
+ - COCO
+ - VG
+ - CC3M
+ - CC12M
+ - SBU
+ - LAION-400M
+ Architecture:
+ - Transformer
+ - Q-Former
+ Paper:
+ Title: 'InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning'
+ URL: https://arxiv.org/abs/2305.06500
+ README: configs/instructblip/README.md
+
+Models:
+ - Name: instructblip-vicuna7b_3rdparty-zeroshot_caption
+ Metadata:
+ FLOPs: null
+ Parameters: xxx
+ In Collection: InstructBLIP
+ Results:
+ - Task: Image Caption
+ Dataset: COCO
+ Metrics: null
+ Weights: https://download.openmmlab.com/mmclassification/v1/instructblip/instruct-blip_vicuna7b_trimmed.pth
+ Config: configs/instructblip/instructblip-vicuna7b_8xb32_caption.py
+ Converted From:
+ Weights: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth
+ Code: https://github.com/salesforce/LAVIS
diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py
index 072c0f84f72..0714f1b9c0d 100644
--- a/mmpretrain/models/multimodal/__init__.py
+++ b/mmpretrain/models/multimodal/__init__.py
@@ -6,6 +6,7 @@
from .blip2 import * # noqa: F401,F403
from .chinese_clip import * # noqa: F401, F403
from .flamingo import * # noqa: F401, F403
+ from .instructblip import * # noqa: F401,F403
from .llava import * # noqa: F401, F403
from .minigpt4 import * # noqa: F401, F403
from .ofa import * # noqa: F401, F403
@@ -17,5 +18,6 @@
register_multimodal_placeholder([
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
- 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter'
+ 'OFA', 'ChineseCLIP', 'InstructBlipCaption', 'MiniGPT4', 'Llava',
+ 'Otter'
], MODELS)
diff --git a/mmpretrain/models/multimodal/instructblip/__init__.py b/mmpretrain/models/multimodal/instructblip/__init__.py
new file mode 100644
index 00000000000..41a0ec7d868
--- /dev/null
+++ b/mmpretrain/models/multimodal/instructblip/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .instructblip_caption import InstructBlipCaption
+
+__all__ = ['InstructBlipCaption']
diff --git a/mmpretrain/models/multimodal/instructblip/instructblip_caption.py b/mmpretrain/models/multimodal/instructblip/instructblip_caption.py
new file mode 100644
index 00000000000..a9c2a348149
--- /dev/null
+++ b/mmpretrain/models/multimodal/instructblip/instructblip_caption.py
@@ -0,0 +1,271 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import torch
+from mmengine.model import BaseModel
+from torch import nn
+from transformers import BertTokenizer
+
+from mmpretrain.registry import MODELS, TOKENIZER
+from mmpretrain.structures import DataSample
+
+
+@MODELS.register_module()
+class InstructBlipCaption(BaseModel):
+ """InstructBlip Caption.
+
+ Module for InstructBlip Caption task.
+
+ Args:
+ vision_encoder (dict): The config dict for vision backbone.
+ text_backbone (dict): The config dict for text backbone.
+ Qformer (dict): The config dict for multimodal backbone.
+ llm_tokenizer (Optional[dict]): The config for the LLM tokenizer.
+ Defaults to None.
+ prompt (str): Prompt used for training and eval.
+ Defaults to ''.
+ max_txt_len (int): Max text length of input text. Defaults to 256.
+ end_sym (str): The symbol marking the end of a generated sequence.
+ Defaults to '\n'.
+ num_captions (int): Number of captions to be generated for each
+ image. Defaults to 1.
+ qformer_text_input (bool): Whether to also feed the instruction text
+ to the Q-Former. Defaults to True.
+ data_preprocessor (Optional[dict]): The config for preprocessing input
+ data. If None or no specified type, it will use
+ "MultiModalDataPreprocessor" as type.
+ See :class:`MultiModalDataPreprocessor` for more details.
+ Defaults to None.
+ init_cfg (Optional[dict]): the config to control the initialization.
+ Defaults to None.
+ """
+ _no_split_modules = ['BEiTViT', 'BertLayer']
+
+ def __init__(self,
+ vision_encoder: dict,
+ text_backbone: dict,
+ Qformer: dict,
+ llm_tokenizer: Optional[dict] = None,
+ prompt: str = '',
+ max_txt_len: int = 256,
+ end_sym: str = '\n',
+ num_captions: int = 1,
+ qformer_text_input=True,
+ data_preprocessor: Optional[dict] = None,
+ init_cfg: Optional[dict] = None) -> None:
+ if data_preprocessor is None:
+ data_preprocessor = {}
+ if isinstance(data_preprocessor, dict):
+ data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+ data_preprocessor = MODELS.build(data_preprocessor)
+
+ super().__init__(
+ init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+ # build vision model
+ vision_encoder_weight = vision_encoder.pop('pretrained', None)
+ self.vision_encoder = MODELS.build(vision_encoder)
+ self.ln_vision = nn.LayerNorm(self.vision_encoder.embed_dims)
+
+ if vision_encoder_weight is not None:
+ from mmengine.runner.checkpoint import load_checkpoint
+ load_checkpoint(self.vision_encoder, vision_encoder_weight)
+
+ # build Qformer
+ self.tokenizer = BertTokenizer.from_pretrained(
+ 'bert-base-uncased', truncation_side='left')
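+ # '[DEC]' is the decoding bos token of the BLIP-2 style Q-Former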
+ self.tokenizer.add_special_tokens({'bos_token': '[DEC]'})
+ self.Qformer = MODELS.build(Qformer)
+
+ if not qformer_text_input:
+ self.Qformer.bert.embeddings.word_embeddings = None
+ self.Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ else:
+ self.Qformer.resize_token_embeddings(len(self.tokenizer))
+ self.Qformer.cls = None
+
+ # build language model
+ self.llm_tokenizer = TOKENIZER.build(llm_tokenizer)
+ self.text_backbone = MODELS.build(text_backbone)
+
+ self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+ # the '</s>' special tokens follow the LAVIS reference implementation
+ self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
+ self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
+ self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
+ # self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
+
+ self.text_backbone.resize_token_embeddings(len(self.llm_tokenizer))
+ self.eos_token_id = self.llm_tokenizer(
+ '\n', add_special_tokens=False).input_ids[0]
+
+ # freeze the text backbone
+ for _, param in self.text_backbone.named_parameters():
+ param.requires_grad = False
+
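+ # learnable query tokens, shared across images, following BLIP-2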
+ self.query_tokens = nn.Parameter(
+ torch.zeros(1, self.Qformer.bert.config.query_length,
+ self.Qformer.bert.config.hidden_size))
+ self.query_tokens.data.normal_(
+ mean=0.0, std=self.Qformer.bert.config.initializer_range)
+
+ # build linear projection layer
+ self.llm_proj = nn.Linear(self.Qformer.config.hidden_size,
+ self.text_backbone.config.hidden_size)
+
+ self.prompt = prompt
+ self.max_txt_len = max_txt_len
+ self.end_sym = end_sym
+ self.end_token_id = self.llm_tokenizer.encode(end_sym)[-1]
+ self.num_captions = num_captions
+ prompt_tokens = self.llm_tokenizer(prompt, return_tensors='pt')
+ self.prompt_length = prompt_tokens.attention_mask.sum(1)
+ self.qformer_text_input = qformer_text_input
+
+ if hasattr(self, 'register_load_state_dict_post_hook'):
+ self.register_load_state_dict_post_hook(self._ignore_llm_keys_hook)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ data_samples: Optional[List] = None,
+ mode: str = 'loss',
+ ) -> List[DataSample]:
+ """The unified entry for a forward process in both training and test.
+ The method should accept two modes: "predict" and "loss":
+
+ - "predict": Forward and return the predictions, which are fully
+ processed to a list of :obj:`DataSample`.
+ - "loss": Forward and return a dict of losses according to the given
+ inputs and data samples.
+
+ Note that this method doesn't handle back propagation or optimizer
+ updating, which are done in :meth:`train_step`.
+
+ Args:
+ images (torch.Tensor): The pre-processed image tensor of shape
+ (N, C, ...).
+ data_samples (List[DataSample], optional): The annotation data of
+ every sample. Defaults to None.
+ mode (str): Return what kind of value. Defaults to 'loss'.
+
+ Returns:
+ The return type depends on ``mode``.
+ - If ``mode="loss"``, return a dict of tensor.
+ - If ``mode="predict"``, return a list of :obj:`DataSample`.
+ """
+ if mode == 'loss':
+ return self.loss(images, data_samples)
+ elif mode == 'predict':
+ return self.predict(images, data_samples)
+ else:
+ raise RuntimeError(f'Invalid mode "{mode}".')
+
+ def predict(self,
+ images: torch.Tensor,
+ data_samples: Optional[list] = None,
+ **kwargs) -> List[DataSample]:
+ """Predict captions from a batch of inputs.
+
+ Args:
+ images (torch.Tensor): The input tensor with shape
+ (N, C, ...) in general.
+ data_samples (List[DataSample], optional): The annotation
+ data of every sample. Defaults to None.
+ **kwargs: Other keyword arguments. They are not used in this
+ method.
+
+ Returns:
+ List[DataSample]: Return list of data samples.
+ """
+ self.llm_tokenizer.padding_side = 'left'
+
+ # extract image features from the vision encoder
+ image_embeds = self.ln_vision(self.vision_encoder(images)[0])
+ image_atts = torch.ones(
+ image_embeds.size()[:-1],
+ dtype=torch.long,
+ ).to(images.device)
+
+ prompt = [self.prompt] * image_embeds.size(0)
+
+ # distill image features to query tokens
+ query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1)
+
+ if self.qformer_text_input:
+ text_Qformer = self.tokenizer(
+ prompt,
+ padding='longest',
+ truncation=True,
+ max_length=self.max_txt_len,
+ return_tensors='pt',
+ ).to(images.device)
+ query_atts = torch.ones(
+ query_tokens.size()[:-1], dtype=torch.long).to(images.device)
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],
+ dim=1)
+
+ if self.qformer_text_input:
+ query_outputs = self.Qformer.bert(
+ text_Qformer.input_ids,
+ attention_mask=Qformer_atts,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ else:
+ query_outputs = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
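+ # keep only the query positions and project them into the LLM embedding space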
+ inputs_llama = self.llm_proj(
+ query_outputs.last_hidden_state[:, :query_tokens.size(1), :])
+ attns_llama = torch.ones(
+ inputs_llama.size()[:-1], dtype=torch.long).to(images.device)
+
+ llama_tokens = self.llm_tokenizer(
+ prompt, padding='longest', return_tensors='pt').to(images.device)
+
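+ # embed the prompt tokens and prepend the projected query features to them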
+ inputs_embeds = self.text_backbone.get_input_embeddings()(
+ llama_tokens.input_ids)
+ inputs_embeds = torch.cat([inputs_llama, inputs_embeds], dim=1)
+ attention_mask = torch.cat([attns_llama, llama_tokens.attention_mask],
+ dim=1)
+
+ outputs = self.text_backbone.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ do_sample=False,
+ top_p=0.9,
+ temperature=1.,
+ num_beams=5,
+ max_new_tokens=self.max_txt_len,
+ min_length=1,
+ repetition_penalty=1.5,
+ length_penalty=1.0,
+ num_return_sequences=self.num_captions,
+ )
+
+ output_text = self.llm_tokenizer.batch_decode(
+ outputs[:, self.prompt_length:], skip_special_tokens=True)
+ output_text = [text.strip() for text in output_text]
+
+ out_data_samples = []
+ if data_samples is None:
+ data_samples = [None for _ in range(len(output_text))]
+
+ for data_sample, decode_token in zip(data_samples, output_text):
+ if data_sample is None:
+ data_sample = DataSample()
+ data_sample.pred_caption = decode_token
+ out_data_samples.append(data_sample)
+
+ return out_data_samples
+
+ @staticmethod
+ def _ignore_llm_keys_hook(module, incompatible_keys):
+ """Avoid warning missing keys of the LLM model."""
+ import re
+ llm_pattern = '^text_backbone'
+ for key in list(incompatible_keys.missing_keys):
+ if re.match(llm_pattern, key):
+ incompatible_keys.missing_keys.remove(key)
diff --git a/model-index.yml b/model-index.yml
index 3fb3d0457d6..4f25902ff1d 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -82,3 +82,4 @@ Import:
- configs/minigpt4/metafile.yml
- configs/llava/metafile.yml
- configs/otter/metafile.yml
+ - configs/instructblip/metafile.yml