of the amplitude and phase RBMS
:rtype: list[torch.Tensor, torch.Tensor]
rotated_grad = [
torch.zeros(
2, getattr(self, net).num_pars, dtype=torch.double, device=self.device
)
for net in self.networks
]
Upsi, Upsi_v, v = unitaries.rotate_psi_inner_prod(
self, basis, sample, include_extras=True
)
grad_vp0 = self.rbm_am.effective_energy_gradient(v, reduce=False)
grad_vp1 = self.rbm_ph.effective_energy_gradient(v, reduce=False)
// since grad_vp0/1 are real, can just treat the scalar multiplication
// and addition as a matrix multiplication
torch.matmul(Upsi_v, grad_vp0, out=rotated_grad[0])
torch.matmul(Upsi_v, grad_vp1, out=rotated_grad[1])
grad = [
cplx.real(cplx.scalar_divide(rotated_grad[0], Upsi)), // Real
-cplx.imag(cplx.scalar_divide(rotated_grad[1], Upsi)), // Imaginary