@@ -235,8 +235,93 @@ def build(self):
235235 return datasets
236236
237237
class MultiModalDatasetBuilder(BaseDatasetBuilder):
    """
    Builder for datasets that combine several input modalities.

    ``self.data_type`` lists the modalities handled by the builder (for
    example text, images, video, or audio).  For each modality the builder
    looks up a ``<prefix>_processor`` config entry and a ``build_info``
    entry carrying a ``storage`` path, where ``<prefix>`` is ``"vis"`` for
    any modality whose name contains ``"image"`` and the modality name
    itself otherwise.
    """

    # Dataset classes instantiated by ``build``; both default to None here.
    # NOTE(review): presumably concrete builders override these - confirm.
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        """Initialise the base builder and normalise ``data_type`` to a list."""
        super().__init__(cfg)
        # A single modality given as a plain string becomes a one-item list
        # so the rest of the class can always iterate over modalities.
        if isinstance(self.data_type, str):
            self.data_type = [self.data_type]

    @staticmethod
    def _modality_prefix(modality):
        """Return the key prefix for *modality* (``"vis"`` for image-like names)."""
        return "vis" if "image" in modality else modality

    def _build_processor(self, cfg_name):
        """
        Build the train/eval processors described by config entry *cfg_name*.

        Returns a dict with keys ``"train"`` and ``"eval"``.  Both values are
        ``None`` when the config has no entry named *cfg_name*.
        """
        cfg = self.config.get(cfg_name)
        return {
            split: self._build_proc_from_cfg(cfg.get(split))
            if cfg is not None
            else None
            for split in ['train', 'eval']
        }

    def build_processors(self):
        """
        Populate ``self.text_processors`` and ``self.processors``.

        ``self.processors`` is indexed as ``[split][modality]`` with split in
        ``{"train", "eval"}``.  A modality without a processor config entry
        gets ``None``, mirroring how ``_build_processor`` treats a missing
        entry instead of failing on ``None.get``.
        """
        self.text_processors = self._build_processor("text_processor")

        # Resolve every modality's processor config once, up front.
        proc_cfgs = {
            modality: self.config.get(f"{self._modality_prefix(modality)}_processor")
            for modality in self.data_type
        }

        self.processors = {
            split: {
                modality: self._build_proc_from_cfg(cfg.get(split))
                if cfg is not None
                else None
                for modality, cfg in proc_cfgs.items()
            }
            for split in ['train', 'eval']
        }

    def _download_multimodal(self, modality):
        """Warn when the storage path configured for *modality* does not exist."""
        storage_path = utils.get_cache_path(self.config.build_info.get(modality).storage)
        if not os.path.exists(storage_path):
            warnings.warn(f"The specified path {storage_path} for {modality} inputs does not exist.")

    def _download_data(self):
        """Run the annotation download step, then check each modality's storage."""
        self._download_ann()
        for modality in self.data_type:
            self._download_multimodal(modality)

    def _get_absolute_path(self, path):
        """Return *path* unchanged if absolute, else resolve it via ``utils.get_cache_path``."""
        if not os.path.isabs(path):
            return utils.get_cache_path(path)
        return path

    def build(self):
        """
        Instantiate one dataset per ``train`` / ``val`` / ``test`` annotation split.

        Returns:
            dict mapping split name to the constructed dataset.  Splits with
            any other name are ignored.

        Raises:
            TypeError: if the dataset class needed for a split is still ``None``.
        """
        self.build_processors()
        build_info = self.config.build_info
        datasets = {}

        for split, info in build_info.annotations.items():
            if split not in ["train", "val", "test"]:
                continue

            is_train = split == "train"
            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            if dataset_cls is None:
                # Same exception type as calling None would raise, but with a
                # message that points at the actual misconfiguration.
                missing = "train_dataset_cls" if is_train else "eval_dataset_cls"
                raise TypeError(
                    f"{type(self).__name__}.{missing} is not set; "
                    f"cannot build the '{split}' split."
                )

            dataset_args = self._get_dataset_args(info, is_train)
            datasets[split] = dataset_cls(**dataset_args)

        return datasets

    def _get_dataset_args(self, info, is_train):
        """
        Assemble the keyword arguments passed to the dataset constructor.

        Args:
            info: annotation config for one split; ``info.storage`` is
                iterated as a collection of annotation paths.
            is_train: selects the ``"train"`` processors when true, the
                ``"eval"`` processors otherwise.

        Returns:
            dict containing any ``build_info.kwargs``, one
            ``<prefix>_processor`` and ``<prefix>_root`` per modality,
            ``text_processor``, ``ann_paths`` and ``modalities``.
        """
        dataset_args = dict(self.config.build_info.get('kwargs', {}))
        proc_split = "train" if is_train else "eval"

        for modality in self.data_type:
            prefix = self._modality_prefix(modality)
            dataset_args[f"{prefix}_processor"] = self.processors[proc_split][modality]
            mm_path = self._get_absolute_path(self.config.build_info.get(modality).storage)
            dataset_args[f"{prefix}_root"] = mm_path

        dataset_args['text_processor'] = self.text_processors[proc_split]
        dataset_args["ann_paths"] = [self._get_absolute_path(path) for path in info.storage]
        dataset_args['modalities'] = self.data_type

        # Conform to base
        # NOTE(review): 'test_processor' looks like a typo for 'text_processor'
        # (which is always set above); kept as-is so the kwargs handed to the
        # dataset are unchanged - confirm before renaming.
        for key in ['vis_processor', 'vis_root', 'test_processor']:
            dataset_args.setdefault(key, None)

        return dataset_args
324+
def load_dataset_config(cfg_path):
    """
    Load the configuration file at *cfg_path* and return its first dataset entry.

    The file is read with ``OmegaConf.load`` and the first value under its
    top-level ``datasets`` node is returned.

    Raises:
        ValueError: if the ``datasets`` node is missing or has no entries.
    """
    datasets_cfg = OmegaConf.load(cfg_path).datasets
    # Fail with a descriptive error rather than leaking a bare StopIteration
    # from next() (or an AttributeError on a missing node).
    if datasets_cfg is None or len(datasets_cfg) == 0:
        raise ValueError(f"No dataset entries found under 'datasets' in {cfg_path}.")
    return next(iter(datasets_cfg.values()))
0 commit comments