from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
if torch.is_tensor(alpha):
alpha = alpha.item()
if torch.is_tensor(beta):
beta = beta.item()
if torch.is_tensor(gamma):
gamma = gamma.item()
return torch.tensor(wigner_D_matrix(order, alpha, beta, gamma), dtype=torch.get_default_dtype() if dtype is None else dtype)
After Change
abc[i] = x.item()
if dtype is None:
dtype = x.dtype
if device is None:
device = x.device
if dtype is None:
dtype = torch.get_default_dtype()
return torch.tensor(wigner_D_matrix(order, *abc), dtype=dtype, device=device)