Source code for geometry_utils

"""
Module: geometry_utils.py

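Batch-based utilities for computing geometric descriptors of molecules and fragments using PyTorch. Includes:
  - Best-fit plane / centroid
  - Planarity metrics (RMSD, max deviation, planarity score)
  - Global Steinhardt Q_l order parameters
  - Distances to special crystallographic planes (fractional)
  - Angles between bond vectors and special-plane normals (fractional)
  - Quaternion computation from rotation matrices
  - Atom-to-reference-point displacement vectors and distances
  - Bond rotatability flags
  - Bond angles (i-j-k) and torsion angles (i-j-k-l)
  - Pairwise heavy-atom vectors and distances within fragments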

Dependencies
------------
math
numpy
torch
typing
"""
import math
import numpy as np
import torch
from typing import List, Tuple, Dict, Any, Optional

# Integer normal vectors of the 13 special crystallographic plane families
# used by this module (one representative of each +/- pair). They are
# consumed by compute_distances_to_crystallographic_planes_frac_batch and
# compute_angles_between_bonds_and_crystallographic_planes_frac_batch, which
# both divide each row by its Euclidean norm before use.
_RAW_PLANE_NORMALS = torch.tensor([
    [1, 0, 0], [0, 1, 0], [0, 0, 1],
    [1, 1, 0], [1,-1, 0], [1, 0, 1], [1, 0,-1],
    [0, 1, 1], [0, 1,-1],
    [1, 1, 1], [1, 1,-1], [1,-1, 1], [1,-1,-1],
], dtype=torch.float32)
# Plane-spacing denominators. In the distance function, for a normal h and a
# denominator d, the planes considered are those where d * (h . x) is an
# integer, i.e. h . x = k / d for integer k.
_DENOMINATORS = torch.tensor([4,6], dtype=torch.float32)

def compute_distances_to_crystallographic_planes_frac_batch(
    atom_frac_coords: torch.Tensor,
    atom_mask: torch.BoolTensor,
    device: torch.device,
    plane_normals: Optional[torch.Tensor] = None,
    denominators: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute fractional distances of each atom to families of special
    crystallographic planes.

    For a plane normal ``h`` and a denominator ``d`` the plane family is
    ``h . x = k / d`` (``k`` integer). The distance from an atom at
    fractional position ``x`` to the nearest member of that family, measured
    along the unit normal in fractional space, is
    ``|d (h . x) - round(d (h . x))| / (d |h|)``.

    Parameters
    ----------
    atom_frac_coords : torch.Tensor, shape (B, A, 3)
        Fractional coordinates of atoms, padded to A slots per structure.
    atom_mask : torch.BoolTensor, shape (B, A)
        True for valid atoms, False for padding.
    device : torch.device
        Device on which to perform the computation.
    plane_normals : torch.Tensor, shape (K, 3), optional
        Plane normals to use (need not be unit length). Defaults to the
        module's 13 built-in normals ``_RAW_PLANE_NORMALS``.
    denominators : torch.Tensor, shape (D,), optional
        Plane-spacing denominators. Defaults to ``_DENOMINATORS`` (4 and 6).

    Returns
    -------
    torch.Tensor, shape (B, A, D * K)
        Absolute fractional distances (26 columns = 2 denominators x 13
        normals with the defaults). Columns are denominator-major: column
        ``d_idx * K + k_idx`` belongs to denominator ``d_idx`` and normal
        ``k_idx``. Padded atom slots are projected onto the origin and
        therefore get 0.
    """
    frac = atom_frac_coords.to(device)
    mask = atom_mask.to(device)
    dtype = frac.dtype

    if plane_normals is None:
        plane_normals = _RAW_PLANE_NORMALS
    if denominators is None:
        denominators = _DENOMINATORS
    normals = plane_normals.to(device=device, dtype=dtype)     # (K, 3)
    dens = denominators.to(device=device, dtype=dtype)         # (D,)

    norms = normals.norm(dim=1)                                 # (K,)
    unit_normals = normals / norms.unsqueeze(1)                 # (K, 3)

    # Signed projection of every atom on every unit normal; padding -> 0.
    proj = frac.matmul(unit_normals.t()) * mask.unsqueeze(-1).to(dtype)  # (B, A, K)

    # Consecutive planes are 1 / (|h| d) apart along the unit normal; scaling
    # by |h| d puts every plane of the family on an integer.
    scale = norms.unsqueeze(1) * dens.unsqueeze(0)              # (K, D)
    scaled = proj.unsqueeze(-1) * scale                         # (B, A, K, D)
    dist = (scaled - scaled.round()).abs() / scale              # (B, A, K, D)

    # Denominator-major flattening: (B, A, D, K) -> (B, A, D * K).
    return dist.transpose(2, 3).reshape(dist.shape[0], dist.shape[1], -1)
def compute_angles_between_bonds_and_crystallographic_planes_frac_batch(
    atom_frac_coords: torch.Tensor,
    bond_atom1: torch.LongTensor,
    bond_atom2: torch.LongTensor,
    bond_mask: torch.BoolTensor,
    device: torch.device,
    plane_normals: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute angles (in degrees) between each bond vector and plane normals.

    The angle is taken between the (fractional) bond vector and the unit
    plane normal, ignoring the sign of either, so it lies in [0, 90].

    Parameters
    ----------
    atom_frac_coords : torch.Tensor, shape (B, A, 3)
        Fractional coordinates of all atoms, padded to A per structure.
    bond_atom1 : torch.LongTensor, shape (B, M)
        Index of the first atom in each bond slot.
    bond_atom2 : torch.LongTensor, shape (B, M)
        Index of the second atom in each bond slot.
    bond_mask : torch.BoolTensor, shape (B, M)
        True for real bonds, False for padding.
    device : torch.device
        Device on which to perform the computation.
    plane_normals : torch.Tensor, shape (K, 3), optional
        Plane normals to compare against. Defaults to the module's 13
        built-in normals ``_RAW_PLANE_NORMALS``.

    Returns
    -------
    torch.Tensor, shape (B, M, K)
        Bond-normal angles in degrees (K = 13 by default); exactly 0 where
        ``bond_mask`` is False. A real bond of zero length yields 90.
    """
    coords = atom_frac_coords.to(device)          # (B, A, 3)
    idx1 = bond_atom1.to(device).long()           # (B, M)
    idx2 = bond_atom2.to(device).long()           # (B, M)
    mask = bond_mask.to(device)                   # (B, M)
    dtype = coords.dtype

    if plane_normals is None:
        plane_normals = _RAW_PLANE_NORMALS
    normals = plane_normals.to(device=device, dtype=dtype)            # (K, 3)
    unit_normals = normals / normals.norm(dim=1, keepdim=True)        # (K, 3)

    B, M = idx1.shape
    batch_idx = torch.arange(B, device=device).unsqueeze(1).expand(B, M)
    v1 = coords[batch_idx, idx1]                                      # (B, M, 3)
    v2 = coords[batch_idx, idx2]                                      # (B, M, 3)

    mask_f = mask.unsqueeze(-1).to(dtype)                             # (B, M, 1)
    bond_vecs = (v2 - v1) * mask_f                                    # (B, M, 3)
    bond_len = bond_vecs.norm(dim=2).clamp(min=1e-8)                  # (B, M)

    # |v . n| / |v| = |cos(theta)|, clamped against rounding before acos.
    projs = torch.einsum('bmc,fc->bmf', bond_vecs, unit_normals)      # (B, M, K)
    cos_theta = (projs.abs() / bond_len.unsqueeze(-1)).clamp(0.0, 1.0)
    angles_deg = torch.rad2deg(torch.acos(cos_theta))                 # (B, M, K)

    # Zero-length padded vectors would otherwise come out as 90 degrees.
    return angles_deg * mask_f
def compute_atom_vectors_to_point_batch(
    atom_coords: torch.Tensor,  # (B, N, 3)
    atom_frac_coords: torch.Tensor,  # (B, N, 3)
    atom_mask: torch.BoolTensor,  # (B, N)
    com_coords: torch.Tensor,  # (B, 3)
    com_frac_coords: torch.Tensor,  # (B, 3)
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Compute displacement vectors and Euclidean distances between each atom
    and a per-structure reference point (e.g. a centre of mass).

    Parameters
    ----------
    atom_coords : torch.Tensor, shape (B, N, 3)
        Cartesian coordinates, padded to N per fragment.
    atom_frac_coords : torch.Tensor, shape (B, N, 3)
        Fractional coordinates, padded to N per fragment.
    atom_mask : torch.BoolTensor, shape (B, N)
        True for real atoms, False for padding.
    com_coords : torch.Tensor, shape (B, 3)
        Reference points in Cartesian space.
    com_frac_coords : torch.Tensor, shape (B, 3)
        Reference points in fractional space.
    device : torch.device
        Device for computation.

    Returns
    -------
    Dict[str, torch.Tensor]
        'fragment_atom_to_com_dist' : shape (B, N)
            Euclidean distances in Cartesian space.
        'fragment_atom_to_com_frac_dist' : shape (B, N)
            Euclidean norms of the fractional displacement vectors.
        'fragment_atom_to_com_vec' : shape (B, N, 3)
            Cartesian displacement ``atom - point`` (points from the
            reference point towards the atom).
        'fragment_atom_to_com_frac_vec' : shape (B, N, 3)
            Fractional displacement ``atom - point``.
        All entries are 0 for padded atom slots.
    """
    atom_coords = atom_coords.to(device)
    atom_frac_coords = atom_frac_coords.to(device)
    atom_mask = atom_mask.to(device)
    com_coords = com_coords.to(device)
    com_frac_coords = com_frac_coords.to(device)

    # One float mask (Cartesian dtype) zeroes padding in both coordinate frames.
    mask3 = atom_mask.to(atom_coords.dtype).unsqueeze(-1)                       # (B, N, 1)

    vecs_cart = (atom_coords - com_coords.unsqueeze(1)) * mask3                # (B, N, 3)
    vecs_frac = (atom_frac_coords - com_frac_coords.unsqueeze(1)) * mask3      # (B, N, 3)

    return {
        'fragment_atom_to_com_dist': vecs_cart.norm(dim=-1),
        'fragment_atom_to_com_frac_dist': vecs_frac.norm(dim=-1),
        'fragment_atom_to_com_vec': vecs_cart,
        'fragment_atom_to_com_frac_vec': vecs_frac,
    }
def compute_bond_rotatability_batch(
    atom_symbols: List[List[str]],
    bond_atom1_idx: torch.LongTensor,
    bond_atom2_idx: torch.LongTensor,
    bond_is_cyclic: torch.BoolTensor,
    bond_types: List[List[str]],
    bond_is_rotatable_raw: List[List[bool]],
    device: torch.device,
) -> torch.BoolTensor:
    """
    Determine which bonds are rotatable using CSD-style rules.

    A bond slot ``j`` of structure ``b`` is flagged rotatable only if all of
    the following hold:

    1. both end atoms have at least two non-hydrogen neighbours (the partner
       atom counts), i.e. the bond is not terminal;
    2. the bond is not cyclic and its type is ``'single'`` (case-insensitive);
    3. neither end atom is an sp C/N (exactly two heavy neighbours and a
       triple bond), and the two ends are not both cumulated-double centres;
    4. the original flag ``bond_is_rotatable_raw[b][j]`` is truthy.

    Only the first ``len(bond_types[b])`` slots of each structure are
    evaluated. Slots whose atom indices are negative or not smaller than
    ``len(atom_symbols[b])`` are treated as padding and stay False.

    Parameters
    ----------
    atom_symbols : List[List[str]]
        Element symbols per atom for each structure (B x N).
    bond_atom1_idx : torch.LongTensor, shape (B, M)
        Index of the first atom in each bond slot.
    bond_atom2_idx : torch.LongTensor, shape (B, M)
        Index of the second atom in each bond slot.
    bond_is_cyclic : torch.BoolTensor, shape (B, M)
        True if the bond is in a ring.
    bond_types : List[List[str]]
        Bond types (e.g. 'single', 'double', 'triple') per slot.
    bond_is_rotatable_raw : List[List[bool]]
        Original CSD rotatability flags per slot.
    device : torch.device
        Device for the output tensor.

    Returns
    -------
    torch.BoolTensor, shape (B, M)
        True where the bond passes all checks and is rotatable.
    """
    B, M = bond_atom1_idx.shape
    result = torch.zeros((B, M), dtype=torch.bool, device=device)
    idx1_rows = bond_atom1_idx.tolist()
    idx2_rows = bond_atom2_idx.tolist()
    cyclic_rows = bond_is_cyclic.tolist()

    for b in range(B):
        slots = _rotatable_bond_slots(
            atom_symbols[b],
            idx1_rows[b],
            idx2_rows[b],
            cyclic_rows[b],
            bond_types[b],
            bond_is_rotatable_raw[b],
        )
        if slots:
            result[b, slots] = True
    return result


def _rotatable_bond_slots(
    symbols: List[str],
    idx1: List[int],
    idx2: List[int],
    cyclic: List[bool],
    types: List[str],
    raw: List[bool],
) -> List[int]:
    """Return the bond slots of one structure that pass every rotatability rule."""
    n_atoms = len(symbols)
    n_bonds = len(types)

    def is_real(a1: int, a2: int) -> bool:
        # Padding slots carry negative (or otherwise out-of-range) indices.
        return 0 <= a1 < n_atoms and 0 <= a2 < n_atoms

    # atom -> list of (neighbour atom, bond slot), real bonds only
    adj: Dict[int, List[Tuple[int, int]]] = {i: [] for i in range(n_atoms)}
    for j in range(n_bonds):
        a1, a2 = idx1[j], idx2[j]
        if is_real(a1, a2):
            adj[a1].append((a2, j))
            adj[a2].append((a1, j))

    def heavy_degree(atom: int) -> int:
        return sum(1 for nbr, _ in adj[atom] if symbols[nbr] != 'H')

    def is_sp(atom: int) -> bool:
        if symbols[atom] not in ('C', 'N'):
            return False
        has_triple = any(types[slot].lower() == 'triple' for _, slot in adj[atom])
        return heavy_degree(atom) == 2 and has_triple

    def is_cumulated_double(atom: int) -> bool:
        neigh = adj[atom]
        return len(neigh) == 2 and all(types[slot].lower() == 'double' for _, slot in neigh)

    slots: List[int] = []
    for j in range(n_bonds):
        a1, a2 = idx1[j], idx2[j]
        if not is_real(a1, a2):
            continue
        # 1) terminal bond
        if heavy_degree(a1) <= 1 or heavy_degree(a2) <= 1:
            continue
        # 2) ring or non-single bond
        if cyclic[j] or types[j].lower() != 'single':
            continue
        # 3) linear arrangement
        # NOTE(review): bond j is single at this point and is in the
        # adjacency of both ends, so is_cumulated_double() is always False
        # here; the second clause never fires. Kept to preserve behaviour.
        if is_sp(a1) or is_sp(a2) or (is_cumulated_double(a1) and is_cumulated_double(a2)):
            continue
        # 4) defer to the raw flag
        if raw[j]:
            slots.append(j)
    return slots
def compute_bond_angles_batch(
    atom_labels: List[List[str]],
    atom_coords: torch.Tensor,
    atom_mask: torch.BoolTensor,
    bond_atom1_idx: torch.LongTensor,
    bond_atom2_idx: torch.LongTensor,
    bond_mask: torch.BoolTensor,
    device: torch.device,
) -> Tuple[List[List[str]], torch.Tensor, torch.BoolTensor, torch.LongTensor]:
    """
    Compute all bond angles (i-j-k) in a batch.

    Every unordered pair of distinct real bonds (slot m < slot n) that shares
    an atom yields one angle. The shared atom is the vertex j, the other end
    of bond m is i and the other end of bond n is k.

    Parameters
    ----------
    atom_labels : List[List[str]]
        Atom labels per structure (B x N).
    atom_coords : torch.Tensor, shape (B, N, 3)
        Cartesian coordinates of atoms.
    atom_mask : torch.BoolTensor, shape (B, N)
        True for real atoms.
    bond_atom1_idx : torch.LongTensor, shape (B, M)
        Index of first atom in each bond slot (padding may be negative).
    bond_atom2_idx : torch.LongTensor, shape (B, M)
        Index of second atom in each bond slot (padding may be negative).
    bond_mask : torch.BoolTensor, shape (B, M)
        True for real bonds.
    device : torch.device
        Device for computation.

    Returns
    -------
    angle_ids : List[List[str]]
        Per-structure list of ``f"{label_i}-{label_j}-{label_k}"`` strings.
    angles : torch.Tensor, shape (B, P_max)
        Angle values in degrees (0 where padding).
    mask_ang : torch.BoolTensor, shape (B, P_max)
        True for real angles.
    idx_tensor : torch.LongTensor, shape (B, P_max, 3)
        Atom index triples (i, j, k) for each angle; -1 where padding.
    """
    coords = atom_coords.to(device)                  # (B, N, 3)
    mask_a = atom_mask.to(device)                    # (B, N)
    idx1 = bond_atom1_idx.to(device).long()          # (B, M)
    idx2 = bond_atom2_idx.to(device).long()          # (B, M)
    mask_b = bond_mask.to(device)                    # (B, M)
    B, M = idx1.shape

    # 1) Real bonds. Negative padding indices are clamped only so that
    #    gather stays in range; the explicit >= 0 test still rejects them.
    valid = (
        mask_b
        & (idx1 >= 0) & (idx2 >= 0)
        & mask_a.gather(1, idx1.clamp(min=0))
        & mask_a.gather(1, idx2.clamp(min=0))
    )                                                # (B, M)

    # 2) Unordered pairs of distinct real bonds, m < n.
    slots = torch.arange(M, device=device)
    upper = slots.unsqueeze(1) < slots.unsqueeze(0)                      # (M, M)
    pair_valid = valid.unsqueeze(2) & valid.unsqueeze(1) & upper         # (B, M, M)

    # 3) Which endpoints coincide (a = first atom, b = second atom).
    a_m, b_m = idx1.unsqueeze(2), idx2.unsqueeze(2)                      # bond m: (B, M, 1)
    a_n, b_n = idx1.unsqueeze(1), idx2.unsqueeze(1)                      # bond n: (B, 1, M)
    s_aa = (a_m == a_n) & pair_valid
    s_ab = (a_m == b_n) & pair_valid
    s_ba = (b_m == a_n) & pair_valid
    s_bb = (b_m == b_n) & pair_valid
    shares = s_aa | s_ab | s_ba | s_bb                                   # (B, M, M)

    # 4) Vertex and wings for each sharing pair.
    m_first = s_aa | s_ab              # shared atom is the first atom of bond m
    n_first = s_aa | s_ba              # shared atom is the first atom of bond n
    central = torch.where(m_first, a_m, b_m)
    wing1 = torch.where(m_first, b_m, a_m)
    wing2 = torch.where(n_first, b_n, a_n)

    # 5) Angles per structure (sum over both pair dims also works for M == 0).
    counts = shares.sum(dim=(1, 2))                                      # (B,)
    P_max = int(counts.max().item()) if B > 0 else 0

    # 6) Pack into a fixed-shape tensor + mask.
    idx_tensor = torch.full((B, P_max, 3), -1, dtype=torch.long, device=device)
    mask_ang = torch.zeros((B, P_max), dtype=torch.bool, device=device)
    for b in range(B):
        m_idx, n_idx = torch.nonzero(shares[b], as_tuple=True)           # each (P_b,)
        P_b = m_idx.numel()
        if P_b == 0:
            continue
        idx_tensor[b, :P_b, 0] = wing1[b, m_idx, n_idx]
        idx_tensor[b, :P_b, 1] = central[b, m_idx, n_idx]
        idx_tensor[b, :P_b, 2] = wing2[b, m_idx, n_idx]
        mask_ang[b, :P_b] = True

    # 7) All angles at once. Padding indices point at atom 0, giving
    #    zero-length vectors that are masked out below.
    safe = idx_tensor.clamp(min=0)
    b_idx = torch.arange(B, device=device).unsqueeze(1)
    p_i = coords[b_idx, safe[..., 0]]
    p_j = coords[b_idx, safe[..., 1]]
    p_k = coords[b_idx, safe[..., 2]]
    v1 = p_i - p_j
    v2 = p_k - p_j
    v1n = v1 / v1.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    v2n = v2 / v2.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    cos_theta = (v1n * v2n).sum(dim=-1).clamp(-1.0, 1.0)
    angles = torch.rad2deg(torch.acos(cos_theta)) * mask_ang.to(coords.dtype)

    # 8) Label strings, one bulk .tolist() per structure.
    angle_ids: List[List[str]] = []
    for b in range(B):
        labels = atom_labels[b]
        n_b = int(counts[b].item())
        angle_ids.append([
            f"{labels[i]}-{labels[j]}-{labels[k]}"
            for i, j, k in idx_tensor[b, :n_b].tolist()
        ])

    return angle_ids, angles, mask_ang, idx_tensor
def compute_torsion_angles_batch(
    atom_labels: List[List[str]],
    atom_coords: torch.Tensor,
    atom_mask: torch.BoolTensor,
    bond_atom1_idx: torch.Tensor,
    bond_atom2_idx: torch.Tensor,
    bond_mask: torch.BoolTensor,
    device: torch.device,
) -> Tuple[List[List[str]], torch.Tensor, torch.BoolTensor, torch.LongTensor]:
    """
    Compute all torsion angles (i-j-k-l) for each molecule in a batch.

    Each bonded pair j < k is a central bond; i ranges over neighbours of j
    other than k, and l over neighbours of k other than j. Quads with i == l
    (three-membered rings) are skipped because they are not proper torsions.
    Angles follow the IUPAC sign convention: looking along j -> k, a
    clockwise rotation taking bond j-i onto bond k-l is positive. The range
    is (-180, 180].

    Parameters
    ----------
    atom_labels : List[List[str]]
        Atom labels per structure (B x N).
    atom_coords : torch.Tensor, shape (B, N, 3)
        Cartesian coordinates.
    atom_mask : torch.BoolTensor, shape (B, N)
        True for real atoms.
    bond_atom1_idx : torch.Tensor, shape (B, M)
        Index of first atom in each bond slot (padding may be negative).
    bond_atom2_idx : torch.Tensor, shape (B, M)
        Index of second atom in each bond slot (padding may be negative).
    bond_mask : torch.BoolTensor, shape (B, M)
        True for real bonds.
    device : torch.device
        Device for computation.

    Returns
    -------
    torsion_ids : List[List[str]]
        Per-structure list of ``f"{l_i}-{l_j}-{l_k}-{l_l}"`` label strings.
    torsions : torch.Tensor, shape (B, T_max)
        Dihedral angles in degrees (0 where padding).
    mask_tor : torch.BoolTensor, shape (B, T_max)
        True for real torsions.
    idx_tensor : torch.LongTensor, shape (B, T_max, 4)
        Atom index quadruplets for each torsion; -1 where padding.
    """
    coords = atom_coords.to(device)                  # (B, N, 3)
    mask_a = atom_mask.to(device)                    # (B, N)
    idx1 = bond_atom1_idx.to(device).long()          # (B, M)
    idx2 = bond_atom2_idx.to(device).long()          # (B, M)
    mask_b = bond_mask.to(device)                    # (B, M)
    B, N, _ = coords.shape

    # 1) Real bonds. Negative padding indices are clamped only so that
    #    gather stays in range; the explicit >= 0 test still rejects them.
    valid = (
        mask_b
        & (idx1 >= 0) & (idx2 >= 0)
        & mask_a.gather(1, idx1.clamp(min=0))
        & mask_a.gather(1, idx2.clamp(min=0))
    )                                                # (B, M)

    # i == l is excluded for every central bond.
    distinct_il = ~torch.eye(N, dtype=torch.bool, device=device)        # (N, N)

    torsion_ids: List[List[str]] = []
    torsion_quads: List[torch.Tensor] = []
    max_T = 0

    for b in range(B):
        # 2) Adjacency matrix from real bonds.
        a = idx1[b][valid[b]]
        c = idx2[b][valid[b]]
        adj = torch.zeros((N, N), dtype=torch.bool, device=device)
        adj[a, c] = True
        adj[c, a] = True

        # 3) Central bonds j < k.
        j_pairs, k_pairs = torch.nonzero(torch.triu(adj, diagonal=1), as_tuple=True)
        P = j_pairs.numel()
        if P == 0:
            torsion_quads.append(torch.zeros((0, 4), dtype=torch.long, device=device))
            torsion_ids.append([])
            continue

        # 4) Neighbours of j (minus k) and of k (minus j).
        rows = torch.arange(P, device=device)
        neigh_i = adj[j_pairs].clone()                                   # (P, N)
        neigh_i[rows, k_pairs] = False
        neigh_l = adj[k_pairs].clone()                                   # (P, N)
        neigh_l[rows, j_pairs] = False

        # 5) All (i, l) combinations per central bond, skipping i == l.
        mask_il = neigh_i.unsqueeze(2) & neigh_l.unsqueeze(1) & distinct_il  # (P, N, N)
        p_idx, i_idx, l_idx = torch.nonzero(mask_il, as_tuple=True)
        quads = torch.stack([i_idx, j_pairs[p_idx], k_pairs[p_idx], l_idx], dim=1)  # (T_b, 4)
        torsion_quads.append(quads)

        # 6) Label strings.
        labels_b = atom_labels[b]
        torsion_ids.append([
            f"{labels_b[i]}-{labels_b[j]}-{labels_b[k]}-{labels_b[l]}"
            for i, j, k, l in quads.tolist()
        ])
        max_T = max(max_T, quads.size(0))

    # 7) Pad to (B, max_T).
    idx_tensor = torch.full((B, max_T, 4), -1, dtype=torch.long, device=device)
    mask_tor = torch.zeros((B, max_T), dtype=torch.bool, device=device)
    for b, quads in enumerate(torsion_quads):
        T_b = quads.size(0)
        if T_b > 0:
            idx_tensor[b, :T_b] = quads
            mask_tor[b, :T_b] = True

    # 8) Dihedral angles for all quads. Padding indices point at atom 0, so
    #    every bond vector is zero and atan2(0, 0) = 0 before masking.
    safe = idx_tensor.clamp(min=0)
    b_idx = torch.arange(B, device=device).unsqueeze(1)
    p1 = coords[b_idx, safe[..., 0]]
    p2 = coords[b_idx, safe[..., 1]]
    p3 = coords[b_idx, safe[..., 2]]
    p4 = coords[b_idx, safe[..., 3]]
    b1 = p2 - p1
    b2 = p3 - p2
    b3 = p4 - p3
    n1 = torch.cross(b1, b2, dim=-1)
    n2 = torch.cross(b2, b3, dim=-1)
    b2u = b2 / b2.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    x = (n1 * n2).sum(dim=-1)
    # IUPAC sign: y = b2_hat . (n1 x n2)
    y = (torch.cross(n1, n2, dim=-1) * b2u).sum(dim=-1)
    torsions = torch.rad2deg(torch.atan2(y, x)) * mask_tor.to(coords.dtype)

    return torsion_ids, torsions, mask_tor, idx_tensor
def compute_quaternions_from_rotation_matrices(
    R: torch.Tensor,  # (B, 3, 3)
    device: torch.device,
) -> torch.Tensor:
    """
    Convert rotation matrices into unit quaternions [w, x, y, z] with w >= 0.

    Uses Shepperd's method. For each matrix, the largest of 4w^2, 4x^2, 4y^2
    and 4z^2 (all obtainable from the diagonal) is chosen, and the other
    components are derived by dividing by that one. This avoids the
    cancellation of the trace-only formula, which breaks down near 180-degree
    rotations (w -> 0) and yields NaN at exactly 180 degrees.

    Parameters
    ----------
    R : torch.Tensor, shape (B, 3, 3)
        Proper rotation matrices (R^T R = I, det = +1). Any leading batch
        shape ``(..., 3, 3)`` is accepted.
    device : torch.device
        Device for computation.

    Returns
    -------
    torch.Tensor, shape (B, 4)
        Unit quaternions (same leading shape as ``R``). For 180-degree
        rotations, where w = 0, the sign of the vector part is not
        canonicalised.
    """
    R = R.to(device)
    m00, m01, m02 = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    m10, m11, m12 = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    m20, m21, m22 = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]

    # 4 * q_c^2 for c in (w, x, y, z)
    diag = torch.stack((
        1.0 + m00 + m11 + m22,
        1.0 + m00 - m11 - m22,
        1.0 - m00 + m11 - m22,
        1.0 - m00 - m11 + m22,
    ), dim=-1)                                                     # (..., 4)

    # Row c equals 4 * q_c * (w, x, y, z).
    candidates = torch.stack((
        torch.stack((diag[..., 0], m21 - m12, m02 - m20, m10 - m01), dim=-1),
        torch.stack((m21 - m12, diag[..., 1], m01 + m10, m02 + m20), dim=-1),
        torch.stack((m02 - m20, m01 + m10, diag[..., 2], m12 + m21), dim=-1),
        torch.stack((m10 - m01, m02 + m20, m12 + m21, diag[..., 3]), dim=-1),
    ), dim=-2)                                                     # (..., 4, 4)

    # The largest component has q_c^2 >= 1/4, so the chosen row never vanishes.
    best = diag.argmax(dim=-1)                                     # (...)
    gather_idx = best[..., None, None].expand(*best.shape, 1, 4)
    quats = candidates.gather(-2, gather_idx).squeeze(-2)          # (..., 4)
    quats = quats / quats.norm(dim=-1, keepdim=True)

    # q and -q encode the same rotation; pick the one with w >= 0.
    return torch.where(quats[..., :1] < 0, -quats, quats)
def compute_global_steinhardt_order_parameters_batch(
    atom_to_com_vecs: torch.Tensor,  # (B, N, 3)
    atom_mask: torch.BoolTensor,  # (B, N)
    atom_weights: Optional[torch.Tensor],  # (B, N) or None
    device: torch.device,
    l_values: Optional[List[int]] = None,
    eps: float = 1e-12,
) -> torch.Tensor:
    """
    Compute global Steinhardt Q_l order parameters for a batch.

    q_lm = sum_i w_i Y_lm(r_i) / sum_i w_i and
    Q_l = sqrt(4 pi / (2l + 1) * sum_{m=-l..l} |q_lm|^2). The m < 0 terms
    are folded in by doubling the m > 0 terms (|q_l,-m| = |q_lm|).

    Parameters
    ----------
    atom_to_com_vecs : torch.Tensor, shape (B, N, 3)
        Positions relative to the center of mass.
    atom_mask : torch.BoolTensor, shape (B, N)
        True for real atoms, False for padding.
    atom_weights : torch.Tensor or None, shape (B, N)
        Per-atom weights, or None for uniform weights.
    device : torch.device
        Device for computation.
    l_values : List[int], optional
        Degrees l at which to compute Q_l. Defaults to [2, 4, 6, 8, 10].
    eps : float, optional
        Small constant to avoid division by zero.

    Returns
    -------
    torch.Tensor, shape (B, len(l_values))
        Q_l values for each batch entry.

    Notes
    -----
    NOTE(review): an atom located exactly at the origin gets cos(theta) = 0
    and phi = atan2(0, 0) = 0, so it contributes an arbitrary direction
    rather than being ignored.
    """
    if l_values is None:
        l_values = [2, 4, 6, 8, 10]

    vecs = atom_to_com_vecs.to(device)
    mask = atom_mask.to(device)
    dtype = vecs.dtype
    B = vecs.shape[0]
    mask_f = mask.to(dtype)

    # Weights, with padding zeroed out.
    if atom_weights is not None:
        atom_weights = atom_weights.to(device)
        w = atom_weights * mask.to(atom_weights.dtype)
    else:
        w = mask_f
    W_sum = w.sum(dim=1).clamp(min=eps)                            # (B,)

    # Spherical angles of every position vector.
    x, y, z = vecs.unbind(dim=-1)
    r = torch.linalg.norm(vecs, dim=-1)                            # (B, N)
    cos_theta = torch.clamp(z / (r + eps), -1.0, 1.0)              # (B, N)
    phi = torch.atan2(y, x)                                        # (B, N)

    Qs = torch.zeros((B, len(l_values)), dtype=dtype, device=device)
    for col, l in enumerate(l_values):
        sum_m = torch.zeros(B, dtype=dtype, device=device)
        for m in range(l + 1):
            norm_lm = math.sqrt(
                (2 * l + 1) / (4 * math.pi)
                * math.factorial(l - m) / math.factorial(l + m)
            )
            # Weighted, normalised P_l^m; padding contributes nothing.
            P_weighted = _assoc_legendre(l, m, cos_theta) * mask_f * norm_lm * w
            real_part = (P_weighted * torch.cos(m * phi)).sum(dim=1) / W_sum
            imag_part = (P_weighted * torch.sin(m * phi)).sum(dim=1) / W_sum
            multiplicity = 1.0 if m == 0 else 2.0
            sum_m = sum_m + multiplicity * (real_part.pow(2) + imag_part.pow(2))
        Qs[:, col] = torch.sqrt((4 * math.pi) / (2 * l + 1) * sum_m)
    return Qs


def _assoc_legendre(
    l: int,
    m: int,
    x: torch.Tensor,  # (B, N), dtype float
) -> torch.Tensor:
    """
    Compute the associated Legendre polynomial P_l^m(x) via upward recurrence.

    Includes the Condon-Shortley phase (-1)^m.

    Parameters
    ----------
    l : int
        Degree of the polynomial.
    m : int
        Order of the polynomial (0 <= m; P_l^m is identically 0 for m > l).
    x : torch.Tensor, shape (B, N)
        Input values in the interval [-1, 1].

    Returns
    -------
    torch.Tensor, shape (B, N)
        Evaluated P_l^m(x) values.

    Raises
    ------
    ValueError
        If ``m`` is negative.
    """
    if m < 0:
        raise ValueError(f"_assoc_legendre requires m >= 0, got m={m}")
    if m > l:
        return torch.zeros_like(x)

    # Base case: P_m^m(x) = (-1)^m (2m-1)!! (1 - x^2)^(m/2)
    if m == 0:
        p_mm = torch.ones_like(x)
    else:
        double_factorial = math.prod(range(1, 2 * m, 2))
        # clamp guards against tiny negative values from rounding
        p_mm = ((-1) ** m) * double_factorial * (1 - x ** 2).clamp(min=0).pow(m / 2)
    if l == m:
        return p_mm

    # P_{m+1}^m(x) = x (2m + 1) P_m^m(x)
    p_m1m = x * (2 * m + 1) * p_mm
    if l == m + 1:
        return p_m1m

    # (ll - m) P_ll^m = (2ll - 1) x P_{ll-1}^m - (ll + m - 1) P_{ll-2}^m
    prev2, prev1 = p_mm, p_m1m
    for ll in range(m + 2, l + 1):
        prev2, prev1 = prev1, ((2 * ll - 1) * x * prev1 - (ll + m - 1) * prev2) / (ll - m)
    return prev1
def compute_best_fit_plane_batch(
    coords: torch.Tensor,  # (B, N, 3)
    mask: torch.BoolTensor,  # (B, N)
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Compute best-fit plane normals and centroids via SVD.

    Parameters
    ----------
    coords : torch.Tensor, shape (B, N, 3)
        Cartesian coordinates of atoms, padded to N.
    mask : torch.BoolTensor, shape (B, N)
        True for real atoms, False for padding.
    device : torch.device
        Device for computation.

    Returns
    -------
    Dict[str, torch.Tensor]
        'fragment_plane_centroid' : shape (B, 3)
            Centroid of the valid atoms of each entry.
        'fragment_plane_normal' : shape (B, 3)
            Unit normal of the least-squares plane (right singular vector of
            the smallest singular value), oriented so that z >= 0.
        Both are NaN for entries with no valid atoms. With fewer than three
        atoms the normal is any unit vector orthogonal to the atoms' span.
    """
    coords = coords.to(device)
    mask = mask.to(device)
    B, N, _ = coords.shape

    # Centroids. Empty entries divide by 1 instead of 0, so the SVD input
    # stays finite; they are reported as NaN at the end.
    mask_f = mask.unsqueeze(-1).to(coords.dtype)                   # (B, N, 1)
    counts = mask_f.sum(dim=1)                                     # (B, 1)
    empty = counts.squeeze(-1) == 0                                # (B,)
    centroids = (coords * mask_f).sum(dim=1) / counts.clamp(min=1)  # (B, 3)

    # Centre the atoms and zero the padding.
    centered = (coords - centroids.unsqueeze(1)) * mask_f          # (B, N, 3)

    # Zero rows leave the right singular vectors unchanged, and they
    # guarantee Vh is (B, 3, 3) even when N < 3.
    if N < 3:
        centered = torch.cat([centered, centered.new_zeros(B, 3 - N, 3)], dim=1)

    _, _, Vh = torch.linalg.svd(centered, full_matrices=False)     # (B, 3, 3)
    normals = Vh[:, -1, :]                                         # (B, 3)
    normals = torch.where(normals[:, 2:3] < 0, -normals, normals)  # z >= 0

    empty_rows = empty.unsqueeze(-1)                               # (B, 1)
    return {
        'fragment_plane_centroid': centroids.masked_fill(empty_rows, float('nan')),
        'fragment_plane_normal': normals.masked_fill(empty_rows, float('nan')),
    }
def compute_planarity_metrics_batch(
    coords: torch.Tensor,  # (B, N, 3)
    mask: torch.BoolTensor,  # (B, N)
    normals: torch.Tensor,  # (B, 3) — unit plane normals, z ≥ 0
    centroids: torch.Tensor,  # (B, 3) — centroids of valid atoms
    device: torch.device,
    decay_width: float = 0.5,
) -> Dict[str, torch.Tensor]:
    """
    Compute planarity metrics (RMSD, max deviation, planarity score) for a batch.

    Parameters
    ----------
    coords : torch.Tensor, shape (B, N, 3)
        Cartesian coordinates of atoms, padded to N.
    mask : torch.BoolTensor, shape (B, N)
        True for real atoms.
    normals : torch.Tensor, shape (B, 3)
        Plane normals (unit vectors).
    centroids : torch.Tensor, shape (B, 3)
        Centroids of valid atoms (a point on each plane).
    device : torch.device
        Device for computation.
    decay_width : float, optional
        Width parameter for the exponential planarity score.

    Returns
    -------
    Dict[str, torch.Tensor]
        'fragment_planarity_rmsd' : shape (B,)
            Root-mean-square distance of the valid atoms from the plane.
        'fragment_planarity_max_dev' : shape (B,)
            Maximum absolute distance of a valid atom from the plane.
        'fragment_planarity_score' : shape (B,)
            ``exp(-rmsd / decay_width)``.
        All three are NaN for entries with fewer than 3 valid atoms.
    """
    coords = coords.to(device)
    mask = mask.to(device)
    normals = normals.to(device)
    centroids = centroids.to(device)
    mask_f = mask.to(coords.dtype)                                          # (B, N)

    # Unsigned point-to-plane distances; padding is centred onto 0.
    centered = (coords - centroids.unsqueeze(1)) * mask_f.unsqueeze(-1)     # (B, N, 3)
    dists = (centered * normals.unsqueeze(1)).sum(dim=-1).abs()             # (B, N)

    counts = mask_f.sum(dim=1)                                              # (B,)
    rmsd = torch.sqrt((dists.pow(2) * mask_f).sum(dim=1) / counts)
    max_dev = dists.masked_fill(~mask, float("-inf")).max(dim=1).values
    planarity_score = torch.exp(-rmsd / decay_width)

    # A plane is only meaningful with at least 3 atoms.
    too_few = counts < 3
    return {
        'fragment_planarity_rmsd': rmsd.masked_fill(too_few, float("nan")),
        'fragment_planarity_max_dev': max_dev.masked_fill(too_few, float("nan")),
        'fragment_planarity_score': planarity_score.masked_fill(too_few, float("nan")),
    }
def compute_fragment_pairwise_vectors_and_distances_batch(
    coords: torch.Tensor,
    mask: torch.BoolTensor,
    heavy_mask: torch.BoolTensor,
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Compute pairwise displacement vectors and Euclidean distances between
    non-hydrogen atoms in each fragment.

    Parameters
    ----------
    coords : torch.Tensor, shape (F, A, 3)
        Cartesian coordinates for each of F fragments, padded to A atoms.
    mask : torch.BoolTensor, shape (F, A)
        True for real atom slots, False for padding.
    heavy_mask : torch.BoolTensor, shape (F, A)
        True for heavy (non-H) atom slots, False otherwise.
    device : torch.device
        Device on which to perform the computation.

    Returns
    -------
    Dict[str, torch.Tensor]
        'fragment_atom_pair_dist' : shape (F, A, A)
            Entry ``[f, i, j]`` is the distance between atoms i and j of
            fragment f if both are real heavy atoms, 0 otherwise.
        'fragment_atom_pair_vec' : shape (F, A, A, 3)
            Entry ``[f, i, j]`` is ``coords[f, i] - coords[f, j]`` (pointing
            from atom j to atom i) for real heavy-atom pairs, 0 otherwise.
        'fragment_atom_pair_idx' : shape (P,)
            Fragment index of each unique real heavy-atom pair (i < j).
        'fragment_atom_pair_atom1_idx' : shape (P,)
            The i index of each such pair.
        'fragment_atom_pair_atom2_idx' : shape (P,)
            The j index of each such pair.

    Notes
    -----
    The three index tensors are in the row-major order produced by
    ``torch.nonzero`` over the (F, A, A) upper-triangular validity mask.
    """
    coords = coords.to(device)          # (F, A, 3)
    mask = mask.to(device)              # (F, A)
    heavy = heavy_mask.to(device)       # (F, A)

    # Real heavy-atom slots.
    valid = mask & heavy                                            # (F, A)

    # diffs[f, i, j] = coords[f, i] - coords[f, j]
    diffs = coords.unsqueeze(2) - coords.unsqueeze(1)               # (F, A, A, 3)
    dists = diffs.norm(dim=-1)                                      # (F, A, A)

    # Zero any pair that involves hydrogen or padding.
    pair_valid = valid.unsqueeze(2) & valid.unsqueeze(1)            # (F, A, A)
    distances = dists * pair_valid.to(dists.dtype)
    vectors = diffs * pair_valid.unsqueeze(-1).to(diffs.dtype)

    # Unique pairs i < j.
    upper_valid = torch.triu(pair_valid, diagonal=1)                # (F, A, A)
    frag_idx, i_idx, j_idx = torch.nonzero(upper_valid, as_tuple=True)

    return {
        'fragment_atom_pair_atom1_idx': i_idx,
        'fragment_atom_pair_atom2_idx': j_idx,
        'fragment_atom_pair_idx': frag_idx,
        'fragment_atom_pair_dist': distances,
        'fragment_atom_pair_vec': vectors,
    }