
Classification example: I don't understand the mask handling when encoding sentences #158

Open
cwqJim2023 opened this issue Nov 14, 2023 · 1 comment

Comments

@cwqJim2023

When asking a question, please provide the following information where possible:

Basic information

  • Operating system:
  • Python version:
  • PyTorch version:
  • bert4torch version:
  • Pretrained model loaded:

Core code

# Paste your core code here

def collate_fn(batch):
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    for text, label in batch:
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_labels.append([label])

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    batch_labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
    return [batch_token_ids, batch_segment_ids], batch_labels.flatten()
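For context, a minimal sketch of what `sequence_padding` does conceptually: it right-pads every sequence in the batch to the batch's maximum length with a pad value (0 by default). This is a hypothetical simplified re-implementation for illustration only; the real `bert4torch` helper has more options (fixed length, left padding, etc.).

```python
def sequence_padding_sketch(sequences, value=0):
    """Right-pad variable-length token-id lists to the batch max length.

    Simplified illustration of sequence_padding; pad value 0 is assumed
    to be the [PAD] token id, as in standard BERT vocabularies.
    """
    max_len = max(len(s) for s in sequences)
    return [s + [value] * (max_len - len(s)) for s in sequences]

batch = [[101, 2769, 102], [101, 2769, 1962, 102]]
padded = sequence_padding_sketch(batch)
# shorter rows are right-padded with 0 up to length 4
```

The padded result can then be wrapped in `torch.tensor(...)` exactly as in the `collate_fn` above.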

Loading the datasets

train_dataloader = DataLoader(MyDataset(['E:/data/corpus/sentence_classification/sentiment/sentiment.train.data']), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset(['E:/data/corpus/sentence_classification/sentiment/sentiment.valid.data']), batch_size=batch_size, collate_fn=collate_fn)
test_dataloader = DataLoader(MyDataset(['E:/data/corpus/sentence_classification/sentiment/sentiment.test.data']), batch_size=batch_size, collate_fn=collate_fn)

For token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen), how is the mask handled?

Output

# Paste your debug output here

What I tried

Paste your own attempts here

@cwqJim2023 cwqJim2023 changed the title from "Classification example: I don't see the mask when encoding sentences" to "Classification example: I don't understand the mask handling when encoding sentences" Nov 14, 2023
@Tongjilibo
Owner

The mask is handled inside the framework. The default pad token id is 0; if your config.json file sets pad_token_id, or you pass pad_token_id when calling build_transformer_model(), the attention_mask is computed automatically from that token id.
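In other words, the attention mask can be recovered from the padded token ids alone: every position whose id differs from pad_token_id is a real token. A minimal sketch of that computation, assuming pad_token_id is 0 and using example token ids (not from a real tokenizer):

```python
import torch

pad_token_id = 0  # assumption: [PAD] maps to id 0, as in standard BERT vocabs

# Example padded batch: row 0 has one padding position at the end
token_ids = torch.tensor([[101, 2769, 102, 0],
                          [101, 2769, 1962, 102]])

# 1 where the position holds a real token, 0 where it is padding
attention_mask = (token_ids != pad_token_id).long()
```

This is why `collate_fn` only needs to return padded `token_ids` and `segment_ids`: the framework derives the mask internally from the pad token id, so no explicit mask tensor has to be built in user code.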
