2 回答
TA贡献1833条经验 获得超4个赞
我浏览了torchtext源代码以更好地了解sort_key在做什么,并了解了为什么我的原始想法不起作用。
我不确定这是否是最好的解决方案,但是我想出了一个可行的解决方案。我创建了一个tokenizer函数,如果它比最长的过滤器长度短,则填充文本,然后从那里创建BucketIterator。
FILTER_SIZES = [3,4,5]
spacy_en = spacy.load('en')
def tokenizer(text):
token = [t.text for t in spacy_en.tokenizer(text)]
if len(token) < FILTER_SIZES[-1]:
for i in range(0, FILTER_SIZES[-1] - len(token)):
token.append('<PAD>')
return token
TEXT = Field(sequential=True, tokenize=tokenizer, lower=True, tensor_type=torch.cuda.LongTensor)
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)
TA贡献1853条经验 获得超6个赞
尽管@ paul41的方法有效,但还是有些滥用。这样做的正确方法是使用preprocessing或postprocessing(相应地在数字化之前或之后)。这是一个示例postprocessing:
def get_pad_to_min_len_fn(min_length):
def pad_to_min_len(batch, vocab, min_length=min_length):
pad_idx = vocab.stoi['<pad>']
for idx, ex in enumerate(batch):
if len(ex) < min_length:
batch[idx] = ex + [pad_idx] * (min_length - len(ex))
return batch
return pad_to_min_len
FILTER_SIZES = [3,4,5]
min_len_padding = get_pad_to_min_len_fn(min_length=max(FILTER_SIZES))
TEXT = Field(sequential=True, use_vocab=True, lower=True, batch_first=True,
postprocessing=min_len_padding)
如果在主循环中定义了嵌套函数(例如min_length = max(FILTER_SIZES)),则需要将参数传递给内部函数,但如果可行,则可以在函数内部对参数进行硬编码。
添加回答
举报