我正在尝试使用某种受保护的除法,Tensorflow.where
但不知何故,它似乎跳过了where
语句中设置的条件。
主要思想是,当除法时x/y
,如果y == 0.
则除法的结果是x
而不是抛出和错误。
我的代码如下:
def Pdivide(x,y):
result = tf.where(y == 0., x, x/y)
return result
但不知何故,这个条件被跳过了:
>>> a = tf.Variable([1.7, 0.2, 0., 1.1, 0.9, 0.3, 23., -1.])
>>> b = tf.Variable([0., 0., 0., 1., 1., 0., 1., 1.])
>>>Pdivide(a,b)
>>>(inf, inf, nan, 1.1, 0.9, inf, 23, -1)
预期输出:
>>>(1.7, 0.2, 0., 1.1, 0.9, 0.3, 23, -1)
PS:使用eager
执行。