为了账号安全,请及时绑定邮箱和手机立即绑定

torchtext BucketIterator最小填充

torchtext BucketIterator最小填充

尚方宝剑之说 2021-04-08 18:39:22
我正在尝试在torchtext中使用BucketIterator.splits函数从csv文件中加载数据以在CNN中使用。除非我的批处理中最长的句子比最大的过滤器大小短,否则一切都正常。在我的示例中,我使用了大小分别为3、4和5的过滤器,因此,如果最长的句子没有至少5个单词,则会出现错误。有没有一种方法可以让BucketIterator动态设置批次的填充,还可以设置最小填充长度?这是我用于BucketIterator的代码:train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)我希望有一种方法可以设置sort_key或类似的最小长度?我尝试了这个,但是不起作用:FILTER_SIZES = [3,4,5]train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text) if len(x.text) >= FILTER_SIZES[-1] else FILTER_SIZES[-1], batch_size=batch_size, repeat=False, device=device) 
查看完整描述

2 回答

?
潇潇雨雨

TA贡献1833条经验 获得超4个赞

我浏览了torchtext源代码以更好地了解sort_key在做什么,并了解了为什么我的原始想法不起作用。


我不确定这是否是最好的解决方案,但是我想出了一个可行的解决方案。我创建了一个tokenizer函数,如果它比最长的过滤器长度短,则填充文本,然后从那里创建BucketIterator。


FILTER_SIZES = [3,4,5]

spacy_en = spacy.load('en')


def tokenizer(text):

    token = [t.text for t in spacy_en.tokenizer(text)]

    if len(token) < FILTER_SIZES[-1]:

        for i in range(0, FILTER_SIZES[-1] - len(token)):

            token.append('<PAD>')

    return token


TEXT = Field(sequential=True, tokenize=tokenizer, lower=True, tensor_type=torch.cuda.LongTensor)


train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)


查看完整回答
反对 回复 2021-04-27
?
墨色风雨

TA贡献1853条经验 获得超6个赞

尽管@ paul41的方法有效,但还是有些滥用。这样做的正确方法是使用preprocessing或postprocessing(相应地在数字化之前或之后)。这是一个示例postprocessing:


def get_pad_to_min_len_fn(min_length):

    def pad_to_min_len(batch, vocab, min_length=min_length):

        pad_idx = vocab.stoi['<pad>']

        for idx, ex in enumerate(batch):

            if len(ex) < min_length:

                batch[idx] = ex + [pad_idx] * (min_length - len(ex))

        return batch

    return pad_to_min_len


FILTER_SIZES = [3,4,5]

min_len_padding = get_pad_to_min_len_fn(min_length=max(FILTER_SIZES))


TEXT = Field(sequential=True, use_vocab=True, lower=True, batch_first=True, 

             postprocessing=min_len_padding)

如果在主循环中定义了嵌套函数(例如min_length = max(FILTER_SIZES)),则需要将参数传递给内部函数,但如果可行,则可以在函数内部对参数进行硬编码。


查看完整回答
反对 回复 2021-04-27
  • 2 回答
  • 0 关注
  • 479 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信