diff --git a/dig/auggraph/dataset/aug_dataset.py b/dig/auggraph/dataset/aug_dataset.py index 9c199678..22cd0c5c 100644 --- a/dig/auggraph/dataset/aug_dataset.py +++ b/dig/auggraph/dataset/aug_dataset.py @@ -183,9 +183,9 @@ def __getitem__(self, index): while pos_index == index: pos_index = random.sample(self.label_to_index_list[anchor_label], 1)[0] - neg_label = random.sample(self.label_to_index_list.keys(), 1)[0] + neg_label = random.sample(list(self.label_to_index_list.keys()), 1)[0] while neg_label == anchor_label: - neg_label = random.sample(self.label_to_index_list.keys(), 1)[0] + neg_label = random.sample(list(self.label_to_index_list.keys()), 1)[0] neg_index = random.sample(self.label_to_index_list[neg_label], 1)[0] pos_data, neg_data = self.dataset[pos_index], self.dataset[neg_index] @@ -201,4 +201,4 @@ def __len__(self): Returns: The number of samples in the original dataset. """ - return len(self.dataset) \ No newline at end of file + return len(self.dataset)