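I figured it out myself: the model's state has to be reset after each translation, otherwise the next call starts from whatever state the previous one left behind. The following code works for me: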
import trax


def translate(model, sentence, vocab_dir, vocab_file):
    empty_state = model.state  # save the empty state
    # Tokenize the input sentence.
    tokenized_sentence = next(trax.data.tokenize(iter([sentence]), vocab_dir=vocab_dir,
                                                 vocab_file=vocab_file))
    # Greedy decoding (temperature=0.0); [0] drops the batch dimension, [:-1] drops EOS.
    tokenized_translation = trax.supervised.decoding.autoregressive_sample(
        model, tokenized_sentence[None, :], temperature=0.0)[0][:-1]
    translation = trax.data.detokenize(tokenized_translation, vocab_dir=vocab_dir,
                                       vocab_file=vocab_file)
    model.state = empty_state  # reset the state so the next call starts fresh
    return translation


# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(input_vocab_size=33300, d_model=512, d_ff=2048, n_heads=8,
                                n_encoder_layers=6, n_decoder_layers=6, max_len=2048,
                                mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

print(translate(model, 'It is nice to learn new things today!',
                vocab_dir='gs://trax-ml/vocabs/', vocab_file='ende_32k.subword'))
print(translate(model, 'I would like to try another example.',
                vocab_dir='gs://trax-ml/vocabs/', vocab_file='ende_32k.subword'))
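The same save-and-restore pattern can be packaged as a context manager, so the state is put back even if decoding raises an exception halfway through. This is a minimal sketch, not part of trax: `preserved_state` and `DummyModel` are hypothetical names, and `DummyModel` only stands in for a stateful model so the helper can be run without trax. Like the answer, it assumes the model's state is replaced on each call rather than mutated in place, so keeping a reference to the old state is enough.

```python
import contextlib


@contextlib.contextmanager
def preserved_state(model):
    """Save model.state on entry and restore it on exit, even if the body raises."""
    saved = model.state  # keep a reference to the current (e.g. empty) state
    try:
        yield model
    finally:
        model.state = saved  # put the saved state back


# Hypothetical stand-in for a stateful model, used only to demonstrate the helper.
class DummyModel:
    def __init__(self):
        self.state = ()

    def step(self, token):
        # Each call replaces the state with a longer tuple, like a growing decoding cache.
        self.state = self.state + (token,)
        return len(self.state)


model = DummyModel()
with preserved_state(model):
    model.step(1)
    model.step(2)
print(model.state)  # → ()
```

Inside `translate`, the tokenize/decode/detokenize steps could then sit in a `with preserved_state(model):` block instead of saving and resetting `model.state` by hand.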