0

我正在使用内置tf.nn.seq2seq.embedding_attention_seq2seq()函数,但参数有一些问题feed_previous,在训练期间,groundtruth 被输入解码器,而在测试期间,我们将最后一个时间步的输出输入解码器。问题是,一旦我设置了feed_previous参数,我就无法更改该参数。我想在每个 epoch 测试我的模型,我应该怎么做?

4

1 回答 1

0

文档中,您可以为 feed_previous 提供一个布尔张量。

feed_previous = tf.placeholder(tf.bool)
model = tf.nn.seq2seq.embedding_attention_seq2seq(..feed_previous=feed_previous...)
sess.run(loss, feed_dict={feed_previous=is_training, ...})
于 2017-01-10T00:52:13.513 回答