:rtype: list[torch.Tensor, torch.Tensor]
rotated_grad = [
torch.zeros(
2, getattr(self, net).num_pars, dtype=torch.double, device=self.device
)
for net in self.networks
]
Upsi, Upsi_v, v = unitaries.rotate_psi_inner_prod(
self, basis, sample, include_extras=True
)
grad_vp0 = self.rbm_am.effective_energy_gradient(v, reduce=False)
grad_vp1 = self.rbm_ph.effective_energy_gradient(v, reduce=False)
// since grad_vp0/1 are real, can just treat the scalar multiplication
// and addition as a matrix multiplication
torch.matmul(Upsi_v, grad_vp0, out=rotated_grad[0])
torch.matmul(Upsi_v, grad_vp1, out=rotated_grad[1])
grad = [
After Change
vr = v.reshape(-1, v.shape[-1])
raw_grads = [
self.am_grads(vr).reshape(2, *v.shape[:-1], -1),
self.ph_grads(vr).reshape(2, *v.shape[:-1], -1),
]
rotated_grad = [cplx.einsum("s...,s...g->...g", Upsi_v, g) for g in raw_grads]
grad = [
cplx.real(cplx.scalar_divide(rotated_grad[0], Upsi)),