1

我有一个函数需要在二维火炬张量的每个条目处计算,它取决于两个轴的索引值。现在,我只能将其实现为嵌套的 for 循环,在两个轴上进行迭代。这很慢(并且需要执行> 10 ^ 5次),我想加快它以获得更好的缩放。

vs = 200
nt = 12
b = torch.ones(vs)/vs
n_kw = torch.rand((nt, vs))
n_k = torch.rand((nt,))

def estimate_p(nt, vs, n_kw, n_k):
    p = torch.zeros((nt, vs))
    
    for i in range(0, nt):
        for j in range(0, vs):
            p[i,j] = (n_kw[i, j] + b[j])/(n_k[i] + torch.sum(b))
    return p

有没有办法根据i,j索引对这个/地图进行矢量化?

4

1 回答 1

1

尝试播放广播

def estimate_p(nt, vs, n_kw, n_k):
    return (n_kw + b) / (n_k + b.sum()).unsqueeze(-1)
于 2021-01-27T19:00:55.203 回答