""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import os from minigpt4.common.registry import registry from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder from minigpt4.datasets.datasets.laion_dataset import LaionDataset from minigpt4.datasets.datasets.cc_combine_dataset import CCCombineDataset, CCAlignDataset @registry.register_builder("cc_combine") class CCCombineBuilder(BaseDatasetBuilder): train_dataset_cls = CCCombineDataset DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_combine/defaults.yaml"} def _download_ann(self): pass def _download_vis(self): pass def build(self): self.build_processors() build_info = self.config.build_info datasets = dict() split = "train" # create datasets # [NOTE] return inner_datasets (wds.DataPipeline) dataset_cls = self.train_dataset_cls datasets[split] = dataset_cls( vis_processor=self.vis_processors[split], text_processor=self.text_processors[split], location=build_info.storage, ).inner_dataset return datasets @registry.register_builder("laion") class LaionBuilder(BaseDatasetBuilder): train_dataset_cls = LaionDataset DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} def _download_ann(self): pass def _download_vis(self): pass def build(self): self.build_processors() build_info = self.config.build_info datasets = dict() split = "train" # create datasets # [NOTE] return inner_datasets (wds.DataPipeline) dataset_cls = self.train_dataset_cls datasets[split] = dataset_cls( vis_processor=self.vis_processors[split], text_processor=self.text_processors[split], location=build_info.storage, ).inner_dataset return datasets @registry.register_builder("cc_align") class CCAlignBuilder(BaseDatasetBuilder): train_dataset_cls = CCAlignDataset DATASET_CONFIG_DICT = { "default": "configs/datasets/cc_combine/align.yaml", }