Not compatible with huggingface dataset #289

Open
quang-anh-nguyen opened this issue Jul 6, 2023 · 0 comments

Hello, I was trying to use PromptDataLoader with an instance of the datasets.Dataset class, as shown in the code below.

train_loader = opr.PromptDataLoader(
    dataset=datasets['train'], 
    template=template, 
    tokenizer=tokenizer,
    tokenizer_wrapper_class=wrapper_plm
)

But I always get the following error:

NotImplementedError                       Traceback (most recent call last)
Cell In[253], line 3
      1 from openprompt.data_utils import InputExample
----> 3 train_loader = opr.PromptDataLoader(
      4     dataset=datasets['train'], 
      5     template=template, 
      6     tokenizer=tokenizer,
      7     tokenizer_wrapper_class=wrapper_plm
      8 )

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/openprompt/pipeline_base.py:100, in PromptDataLoader.__init__(self, dataset, template, tokenizer_wrapper, tokenizer, tokenizer_wrapper_class, verbalizer, max_seq_length, batch_size, shuffle, teacher_forcing, decoder_max_length, predict_eos_token, truncate_method, drop_last, **kwargs)
     96 assert hasattr(self.template, 'wrap_one_example'), "Your prompt has no function variable \
     97                                                  named wrap_one_example"
     99 # process
--> 100 self.wrap()
    101 self.tokenize()
    103 if self.shuffle:

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/openprompt/pipeline_base.py:129, in PromptDataLoader.wrap(self)
    127         self.wrapped_dataset.append(wrapped_example)
    128 else:
--> 129     raise NotImplementedError

NotImplementedError: 

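Looking at the source code, the cause appears to be the PromptDataLoader.wrap method, which only accepts a dataset that is a torch.utils.data.Dataset or a List[InputExample] and raises NotImplementedError for anything else. A datasets.Dataset is not a subclass of torch.utils.data.Dataset, so it falls through to the else branch. However, changing the class of my dataset would be very complicated.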

    def wrap(self):
        r"""A simple interface to pass the examples to prompt, and wrap the text with template.
        """
        if isinstance(self.raw_dataset, Dataset) or isinstance(self.raw_dataset, List):
            assert len(self.raw_dataset) > 0, 'The dataset to be wrapped is empty.'
            # for idx, example in tqdm(enumerate(self.raw_dataset),desc='Wrapping'):
            for idx, example in enumerate(self.raw_dataset):
                if self.verbalizer is not None and hasattr(self.verbalizer, 'wrap_one_example'): # some verbalizer may also process the example.
                    example = self.verbalizer.wrap_one_example(example)
                wrapped_example = self.template.wrap_one_example(example)
                self.wrapped_dataset.append(wrapped_example)
        else:
            raise NotImplementedError
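One way the check above could be relaxed is sketched below as a standalone function rather than as a patch to PromptDataLoader.wrap. The names wrap_dataset and to_example are hypothetical (not OpenPrompt API), and the verbalizer step is left out for brevity. The idea is to accept anything that has a length and can be iterated, and to convert dict rows (which is what a datasets.Dataset yields) before handing them to the template.

```python
from collections.abc import Mapping, Sized
from typing import Any, Callable, List, Optional


def wrap_dataset(
    raw_dataset: Any,
    wrap_one_example: Callable[[Any], Any],
    to_example: Optional[Callable[[Mapping], Any]] = None,
) -> List[Any]:
    # Hypothetical relaxation of the isinstance(Dataset) / isinstance(List)
    # check: anything with __len__ that can be iterated is accepted, which
    # covers lists, map-style torch Datasets and Hugging Face datasets.
    if not isinstance(raw_dataset, Sized):
        raise NotImplementedError(
            f"Unsupported dataset type: {type(raw_dataset).__name__}"
        )
    assert len(raw_dataset) > 0, "The dataset to be wrapped is empty."

    wrapped = []
    for example in raw_dataset:
        # A datasets.Dataset yields plain dicts, which the template cannot
        # wrap directly, so they are converted first (e.g. to InputExample).
        if isinstance(example, Mapping):
            if to_example is None:
                raise TypeError(
                    "Got dict rows; pass to_example to convert them first."
                )
            example = to_example(example)
        wrapped.append(wrap_one_example(example))
    return wrapped
```

Requiring a length keeps the same precondition the original code already has, since it calls len() in its assertion.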

Could you please make it compatible with datasets.Dataset? I believe many people use Hugging Face datasets. Thank you.
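In the meantime, a possible workaround is to convert the rows to a plain list of InputExample objects, which wrap() already accepts, since iterating over a datasets.Dataset yields one dict per row. A minimal sketch, assuming the data has "text" and "label" columns; the column names and the helper name rows_to_input_examples are hypothetical, so adjust them to your data.

```python
from typing import Any, Callable, Iterable, List, Mapping, Optional


def rows_to_input_examples(
    rows: Iterable[Mapping[str, Any]],
    text_a_key: str = "text",
    text_b_key: Optional[str] = None,
    label_key: Optional[str] = "label",
    example_cls: Optional[Callable[..., Any]] = None,
) -> List[Any]:
    """Convert dict-like rows (e.g. from a datasets.Dataset) into a list
    of InputExample objects. The default column names are assumptions."""
    if example_cls is None:
        # Imported lazily so the helper can also be used with a stand-in
        # class; assumes InputExample accepts guid/text_a/text_b/label.
        from openprompt.data_utils import InputExample
        example_cls = InputExample

    examples = []
    for idx, row in enumerate(rows):
        examples.append(
            example_cls(
                guid=str(idx),
                text_a=row[text_a_key],
                text_b=row[text_b_key] if text_b_key is not None else "",
                label=row[label_key] if label_key is not None else None,
            )
        )
    return examples
```

The result can then be passed as dataset=rows_to_input_examples(datasets['train']) in the PromptDataLoader call above.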
