ra = geometry.unsqueeze(2) // [batch, a, 1, xyz]
k = self.kernel(rb - ra) // [batch, a, b, 6 * i, j]
return torch.einsum("zabij,zacij,zadij,zbcij,zbdij,zcdij,zcdj->zabi",
(*k.split(k.size(3) // 6, 3), features)) // [batch, a, b, channel]
class ApplyKernel(torch.nn.Module):
After Change
ra = geometry.unsqueeze(2) // [batch, a, 1, xyz]
k = self.kernel(rb - ra) // [batch, a, b, 6 * i, j]
k1, k2, k3, k4, k5, k6 = k.split(k.size(3) // 6, 3)
out = torch.einsum("zabij,zcdj->zabi", (k1, features)) // [batch, a, b, channel]
out += torch.einsum("zacij,zcdj->zabi", (k2, features)) // [batch, a, b, channel]
out += torch.einsum("zadij,zcdj->zabi", (k3, features)) // [batch, a, b, channel]
out += torch.einsum("zbcij,zcdj->zabi", (k4, features)) // [batch, a, b, channel]