Description of the Bug
Loading the BoolQ dataset with `DATASET['BoolQ']` fails with `KeyError: 'label'`.
Environment Information
- GCC version: 11.4.0
- Torch version: 2.4.0+cu121
- Linux system version: Ubuntu 22.04
- CUDA version: 12.5
- Torch's CUDA version (as per `torch.version.cuda`): 12.1
To Reproduce
Following section 3 ("Prepare the dataset") of the Chinese README:
https://github.com/OpenBMB/ModelCenter/blob/main/README-ZH.md#3-%E5%87%86%E5%A4%87%E6%95%B0%E6%8D%AE%E9%9B%86
```python
import bmtrain as bmt
from model_center.dataset.bertdataset import DATASET
from model_center.dataset import DistributedDataLoader
from model_center.tokenizer import BertTokenizer

# Initialize BMTrain before calling bmt.rank() / bmt.world_size()
bmt.init_distributed(seed=0)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
splits = ['train', 'dev']
dataset = {}
for split in splits:
    # rank and world_size are passed through to the dataset reader
    dataset[split] = DATASET['BoolQ']('your_path_to_dataset', split, bmt.rank(), bmt.world_size(), tokenizer, max_encoder_length=512)

batch_size = 64
train_dataloader = DistributedDataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
dev_dataloader = DistributedDataLoader(dataset['dev'], batch_size=batch_size, shuffle=False)
```
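To see whether the BoolQ files under the dataset path actually contain a `label` field, a small key count over the raw records can help. This is a minimal sketch that assumes the files are JSON Lines (one JSON object per line); `field_counts` and `field_counts_in_file` are hypothetical helpers, not part of ModelCenter, and the exact file names that `read_data` opens are not shown in this report.

```python
import json
from collections import Counter


def field_counts(lines):
    """Return (row_count, Counter of top-level keys) for JSON Lines text.

    Hypothetical diagnostic helper, not part of ModelCenter.
    """
    counts = Counter()
    rows = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        record = json.loads(line)
        counts.update(record.keys())  # one count per key present in this record
        rows += 1
    return rows, counts


def field_counts_in_file(path):
    """Hypothetical helper: run field_counts over a JSON Lines file on disk."""
    with open(path, encoding="utf-8") as f:
        return field_counts(f)
```

For example, `field_counts_in_file("your_path_to_dataset/<file for the split>")` returns the number of rows and how many of them carry each key. One possible cause, not confirmed here: the SuperGLUE release of BoolQ stores the answer under `label`, while the original BoolQ release names it `answer`, so files from the latter would show `answer` and no `label`.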
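Expected Behavior
The `train` and `dev` splits should load without error. Instead, constructing `DATASET['BoolQ']` raises the following error: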
```
KeyError                                  Traceback (most recent call last)
Cell In[8], line 10
      7 dataset = {}
      9 for split in splits:
---> 10     dataset[split] = DATASET['BoolQ']('/home/bmtrain/bert', split, bmt.rank(), bmt.world_size(), tokenizer, max_encoder_length=512)
     12 batch_size = 64
     13 train_dataloader = DistributedDataLoader(dataset['train'], batch_size=batch_size, shuffle=True)

File ~/.local/lib/python3.10/site-packages/model_center/dataset/bertdataset/superglue.py:92, in BoolQ_Dataset.__init__(self, path, split, rank, world_size, tokenizer, max_encoder_length)
     90 from tqdm import tqdm
     91 for row in self.read_data("BoolQ", path, split, rank, world_size):
---> 92     label = 1 if row["label"]==True else 0
     93     text_a = row['passage']
     94     text_b = row['question']

KeyError: 'label'
```
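Line 92 indexes the row with `row["label"]`, so any record without that key raises `KeyError` before the passage and question are read. As a local workaround sketch only (a hypothetical helper, not ModelCenter's code), the lookup below keeps that line's `== True` comparison but also accepts `answer`, assuming the files might follow the original BoolQ naming. Otherwise it fails with a message listing the keys that are present:

```python
def boolq_label(row):
    """Map a BoolQ record's boolean answer to 1/0.

    Hypothetical helper: reads the SuperGLUE field name `label` and, as an
    assumed fallback, the original BoolQ field name `answer`.
    """
    for key in ("label", "answer"):
        if key in row:
            # Same comparison as superglue.py line 92
            return 1 if row[key] == True else 0  # noqa: E712
    raise KeyError(
        f"BoolQ row has neither 'label' nor 'answer'; keys present: {sorted(row)}"
    )
```

Whether `answer` is the right fallback depends on which BoolQ release the files come from.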
Screenshots
No response
Additional Information
No response