1

我正在尝试使用某种受保护的除法,Tensorflow.where但不知何故,它似​​乎跳过了where语句中设置的条件。

主要思想是,当除法时x/y,如果y == 0.则除法的结果是x而不是抛出和错误。

我的代码如下:

def Pdivide(x,y):
    result = tf.where(y == 0., x, x/y) 
    return result

但不知何故,这个条件被跳过了:

>>> a = tf.Variable([1.7, 0.2, 0., 1.1, 0.9, 0.3, 23., -1.])
>>> b = tf.Variable([0., 0., 0., 1., 1., 0., 1., 1.])

>>>Pdivide(a,b)

>>>(inf, inf, nan, 1.1, 0.9, inf, 23, -1)

预期输出:

>>>(1.7, 0.2, 0., 1.1, 0.9, 0.3, 23, -1)

PS:使用eager执行。

4

1 回答 1

0

好的,答案显然很简单。

由于某种原因,张量元素无法与简单相比,==但使用tf.equal(y, 0.)可以解决问题并产生正确的输出。

于 2019-02-03T11:46:11.490 回答