import torch import math def flat_to_sym(V, N): A = torch.zeros(N, N, dtype=V.dtype, device=V.device) idxs = torch.tril_indices(N, N, device=V.device) A[idxs.unbind()] = V A[idxs[1, :], idxs[0, :]] = V return A def block_LDL(H, b, check_nan=True): n = H.shape[0] assert (n % b == 0) m = n // b try: L = torch.linalg.cholesky(H) except: return None DL = torch.diagonal(L.reshape(m, b, m, b), dim1=0, dim2=2).permute(2, 0, 1) D = (DL @ DL.permute(0, 2, 1)).cpu() DL = torch.linalg.inv(DL) L = L.view(n, m, b) for i in range(m): L[:, i, :] = L[:, i, :] @ DL[i, :, :] if check_nan and L.isnan().any(): return None L = L.reshape(n, n) return (L, D.to(DL.device)) def approx_int_sqrt(n): p = int(math.floor(math.sqrt(n))) while (n % p != 0): p -= 1 return (p, n // p) def regularize_H(H, n, sigma_reg): H.div_(torch.diag(H).mean()) idx = torch.arange(n) H[idx,idx] += sigma_reg return H # return H / torch.diag(H).mean() + sigma_reg * torch.eye(n, device=H.device)