Source code for fragment_utils

"""
Module: fragment_utils.py

Utilities for batch processing of molecular fragments, including rigid‐fragment
identification and computation of per‐fragment properties.

Dependencies
------------
torch
typing
"""
import torch
from typing import List, Tuple, Dict, Union

def identify_rigid_fragments_batch(
    atom_mask: torch.BoolTensor,          # (B, N)
    bond_atom1: torch.LongTensor,         # (B, M)
    bond_atom2: torch.LongTensor,         # (B, M)
    bond_is_rotatable: torch.BoolTensor,  # (B, M)
    device: torch.device,
) -> torch.LongTensor:
    """
    Identify rigid fragments in a batch via iterative label propagation.

    Two atoms belong to the same rigid fragment when a path of
    non-rotatable bonds connects them. Every real atom starts with its own
    index as label; each iteration pushes the smaller label of a bond's two
    endpoints onto both endpoints. After convergence every atom carries the
    smallest atom index of its connected component (its "root"), and the
    roots are renumbered to consecutive fragment IDs.

    Parameters
    ----------
    atom_mask : torch.BoolTensor of shape (B, N)
        True for real atoms, False for padding slots.
    bond_atom1 : torch.LongTensor of shape (B, M)
        First-atom indices for each bond (-1 for padding).
    bond_atom2 : torch.LongTensor of shape (B, M)
        Second-atom indices for each bond (-1 for padding).
    bond_is_rotatable : torch.BoolTensor of shape (B, M)
        True if the bond is rotatable; non-rotatable bonds join fragments.
    device : torch.device
        Device to perform computation on (e.g. 'cuda').

    Returns
    -------
    frag_id : torch.LongTensor of shape (B, N)
        Fragment ID for each atom: 0..K-1 for the real atoms of a structure,
        numbered in order of each fragment's lowest atom index; -1 for padding.

    Notes
    -----
    A bond joins atoms only if it is non-rotatable and BOTH endpoints are
    non-negative indices of real (unmasked) atoms. A bond with a single
    padded endpoint is ignored rather than being attached to atom 0.
    """
    B, N = atom_mask.shape
    if N == 0:
        return torch.empty((B, 0), dtype=torch.long, device=device)

    atom_mask = atom_mask.to(device).bool()
    bond_atom1 = bond_atom1.to(device).long()
    bond_atom2 = bond_atom2.to(device).long()
    bond_is_rot = bond_is_rotatable.to(device).bool()

    atom_index = torch.arange(N, device=device).unsqueeze(0).expand(B, N)

    # 1) every real atom is initially its own root; padding slots get -1
    labels = torch.where(atom_mask, atom_index, torch.full_like(atom_index, -1))

    # 2) bonds that actually tie two real atoms together.
    #    Indices are clamped only so gather/scatter stay in range;
    #    invalid bonds are neutralised below.
    u_idx = bond_atom1.clamp(min=0)
    v_idx = bond_atom2.clamp(min=0)
    valid_bond = (
        (bond_atom1 >= 0)
        & (bond_atom2 >= 0)
        & ~bond_is_rot
        & atom_mask.gather(1, u_idx)
        & atom_mask.gather(1, v_idx)
    )  # (B, M)

    # Labels are < N, so an invalid bond proposing N never wins the amin
    BIG = torch.tensor(N, device=device, dtype=labels.dtype)

    # At most N-1 hops are needed (longest chain); stop once nothing changes
    for _ in range(N):
        mn = torch.minimum(labels.gather(1, u_idx), labels.gather(1, v_idx))
        mn = torch.where(valid_bond, mn, BIG)
        prev = labels.clone()
        labels.scatter_reduce_(1, u_idx, mn, reduce='amin', include_self=True)
        labels.scatter_reduce_(1, v_idx, mn, reduce='amin', include_self=True)
        if torch.equal(labels, prev):
            break

    # 3) a root is a real atom whose label equals its own index. Ranking roots
    #    by index gives consecutive IDs, identical to sorted-unique remapping.
    is_root = atom_mask & (labels == atom_index)
    root_rank = torch.cumsum(is_root.long(), dim=1) - 1
    frag_id = root_rank.gather(1, labels.clamp(min=0))
    return torch.where(atom_mask, frag_id, torch.full_like(frag_id, -1))
def prepare_fragments_batch(
    atom_fragment_id: torch.LongTensor,
    atom_coords: torch.Tensor,
    atom_frac_coords: torch.Tensor,
    atom_weights: torch.Tensor,
    atom_charges: torch.Tensor,
    atom_symbol_codes: torch.LongTensor,
    code_to_element: List[str],
    code_H: int,
    device: torch.device,
) -> Dict[str, Union[torch.Tensor, List[List[str]]]]:
    """
    Assemble per-fragment padded tensors and chemical formulas for a batch.

    Fragments are flattened across the batch: flat index 0..F-1 enumerates
    structure 0's fragments in local-ID order, then structure 1's, and so on.
    Inside each fragment row, atoms appear in ascending original atom index.

    Parameters
    ----------
    atom_fragment_id : torch.LongTensor of shape (B, N)
        Fragment index per atom (-1 for padding).
    atom_coords : torch.Tensor of shape (B, N, 3)
        Cartesian coordinates, padded to N atoms.
    atom_frac_coords : torch.Tensor of shape (B, N, 3)
        Fractional coordinates, padded similarly.
    atom_weights : torch.Tensor of shape (B, N)
        Atomic weights (zero for padding).
    atom_charges : torch.Tensor of shape (B, N)
        Partial charges (zero for padding).
    atom_symbol_codes : torch.LongTensor of shape (B, N)
        Integer element codes per atom.
    code_to_element : List[str]
        Mapping from element code to symbol.
    code_H : int
        Integer code corresponding to hydrogen.
    device : torch.device
        Device to perform computation on.

    Returns
    -------
    dict with keys:
        fragment_structure_id : torch.LongTensor of shape (F,)
            Structure index of each flat fragment.
        n_fragments : torch.LongTensor of shape (B,)
            Fragments per structure (highest fragment ID + 1, 0 if none).
        fragment_local_id : torch.LongTensor of shape (F,)
        fragment_n_atoms : torch.LongTensor of shape (F,)
        fragment_atom_coords : torch.Tensor of shape (F, max_A, 3)
        fragment_atom_frac_coords : torch.Tensor of shape (F, max_A, 3)
        fragment_atom_weight : torch.Tensor of shape (F, max_A)
        fragment_atom_charge : torch.Tensor of shape (F, max_A)
        fragment_atom_mask : torch.BoolTensor of shape (F, max_A)
        fragment_atom_heavy_mask : torch.BoolTensor of shape (F, max_A)
            True for real atoms whose code differs from ``code_H``.
        fragment_formula : List[List[str]]
            Length-B list; entry b lists structure b's fragment formulas in
            local-ID order. Each formula concatenates ``<symbol><count>``
            (count always written, e.g. "C1O1"), sorted by element symbol.

    Padded floating tensors keep float64 inputs in float64. Other inputs use
    the default float dtype. When the batch has no real atoms, F == 0 and
    max_A == 0.
    """
    from collections import Counter

    def _float_dtype(t: torch.Tensor) -> torch.dtype:
        # float64 stays float64; everything else falls back to the default dtype
        return torch.promote_types(t.dtype, torch.get_default_dtype())

    # Move tensors to device
    atom_fragment_id = atom_fragment_id.to(device).long()
    atom_coords = atom_coords.to(device)
    atom_frac_coords = atom_frac_coords.to(device)
    atom_weights = atom_weights.to(device)
    atom_charges = atom_charges.to(device)
    atom_symbol_codes = atom_symbol_codes.to(device)

    B, N = atom_fragment_id.shape

    # Fragments per structure = highest local ID + 1 (0 for an all-padding row)
    if N > 0:
        n_frags_t = (atom_fragment_id.max(dim=1).values + 1).clamp(min=0)
    else:
        n_frags_t = torch.zeros(B, dtype=torch.long, device=device)
    n_frags = n_frags_t.tolist()

    # Flat fragment index -> (structure, local fragment ID)
    struct_ids = [b for b, n in enumerate(n_frags) for _ in range(n)]
    frag_local_ids = [f for n in n_frags for f in range(n)]
    F = len(struct_ids)
    struct_ids_t = torch.tensor(struct_ids, dtype=torch.long, device=device)
    frag_local_ids_t = torch.tensor(frag_local_ids, dtype=torch.long, device=device)

    # Real atoms in row-major (structure, atom) order, mapped to flat fragment index
    atom_b, atom_a = torch.nonzero(atom_fragment_id >= 0, as_tuple=True)
    first_frag = torch.cumsum(n_frags_t, dim=0) - n_frags_t  # (B,)
    atom_frag = first_frag[atom_b] + atom_fragment_id[atom_b, atom_a]
    # Stable sort groups atoms by fragment and keeps ascending atom order inside
    atom_frag, order = torch.sort(atom_frag, stable=True)
    atom_b, atom_a = atom_b[order], atom_a[order]

    frag_n_atoms = torch.bincount(atom_frag, minlength=F)  # (F,)
    max_A = int(frag_n_atoms.max()) if F > 0 else 0
    # Column of each atom within its fragment's padded row
    first_atom = torch.cumsum(frag_n_atoms, dim=0) - frag_n_atoms
    slot = torch.arange(atom_frag.numel(), device=device) - first_atom[atom_frag]

    # Allocate padded fragment tensors
    frag_coords = torch.zeros((F, max_A, 3), dtype=_float_dtype(atom_coords), device=device)
    frag_frac = torch.zeros((F, max_A, 3), dtype=_float_dtype(atom_frac_coords), device=device)
    frag_weights = torch.zeros((F, max_A), dtype=_float_dtype(atom_weights), device=device)
    frag_charges = torch.zeros((F, max_A), dtype=_float_dtype(atom_charges), device=device)
    frag_atom_mask = torch.zeros((F, max_A), dtype=torch.bool, device=device)
    frag_heavy_mask = torch.zeros((F, max_A), dtype=torch.bool, device=device)

    # Scatter every real atom into its (fragment, slot) cell in one shot
    codes = atom_symbol_codes[atom_b, atom_a]
    frag_coords[atom_frag, slot] = atom_coords[atom_b, atom_a].to(frag_coords.dtype)
    frag_frac[atom_frag, slot] = atom_frac_coords[atom_b, atom_a].to(frag_frac.dtype)
    frag_weights[atom_frag, slot] = atom_weights[atom_b, atom_a].to(frag_weights.dtype)
    frag_charges[atom_frag, slot] = atom_charges[atom_b, atom_a].to(frag_charges.dtype)
    frag_atom_mask[atom_frag, slot] = True
    frag_heavy_mask[atom_frag, slot] = codes != code_H

    # Formulas on CPU: one transfer, then split the fragment-sorted codes
    element_codes = codes.tolist()
    fragment_formulas: List[List[str]] = [[] for _ in range(B)]
    start = 0
    for b, n_atoms in zip(struct_ids, frag_n_atoms.tolist()):
        counts = Counter(element_codes[start:start + n_atoms])
        start += n_atoms
        ordered = sorted(counts, key=lambda c: code_to_element[c])
        fragment_formulas[b].append(
            ''.join(f"{code_to_element[c]}{counts[c]}" for c in ordered)
        )

    return {
        'fragment_structure_id': struct_ids_t,
        'n_fragments': n_frags_t,
        'fragment_local_id': frag_local_ids_t,
        'fragment_n_atoms': frag_n_atoms,
        'fragment_atom_coords': frag_coords,
        'fragment_atom_frac_coords': frag_frac,
        'fragment_atom_weight': frag_weights,
        'fragment_atom_charge': frag_charges,
        'fragment_atom_mask': frag_atom_mask,
        'fragment_atom_heavy_mask': frag_heavy_mask,
        'fragment_formula': fragment_formulas,
    }
def compute_center_of_mass_batch(
    atom_coords: torch.Tensor,       # (B, N, 3)
    atom_frac_coords: torch.Tensor,  # (B, N, 3)
    atom_weights: torch.Tensor,      # (B, N)
    atom_mask: torch.BoolTensor,     # (B, N)
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Compute Cartesian and fractional centers of mass for each fragment.

    Parameters
    ----------
    atom_coords : torch.Tensor of shape (B, N, 3)
        Cartesian coordinates, padded to N atoms.
    atom_frac_coords : torch.Tensor of shape (B, N, 3)
        Fractional coordinates, padded similarly.
    atom_weights : torch.Tensor of shape (B, N)
        Atomic weights (zero for padding).
    atom_mask : torch.BoolTensor of shape (B, N)
        True for real atoms, False for padding.
    device : torch.device
        Device to perform computation on.

    Returns
    -------
    dict with keys:
        fragment_com_coords : torch.Tensor of shape (B, 3)
            Cartesian center of mass per fragment.
        fragment_com_frac_coords : torch.Tensor of shape (B, 3)
            Fractional center of mass per fragment.
        A fragment with zero total (unmasked) mass gets the zero vector
        instead of NaN.
    """
    # 1) move everything to device
    atom_coords = atom_coords.to(device)
    atom_frac_coords = atom_frac_coords.to(device)
    atom_weights = atom_weights.to(device)
    atom_mask = atom_mask.to(device).bool()

    # 2) select (rather than multiply by the mask) so that inf/NaN stored in
    #    padding slots cannot leak into the sums as NaN (inf * 0 == NaN)
    keep = atom_mask.unsqueeze(-1)  # (B, N, 1)
    w = torch.where(atom_mask, atom_weights, torch.zeros_like(atom_weights))  # (B, N)
    cart = torch.where(keep, atom_coords, torch.zeros_like(atom_coords))
    frac = torch.where(keep, atom_frac_coords, torch.zeros_like(atom_frac_coords))

    # 3) total mass per fragment; zero-mass fragments divide by 1 so they
    #    yield a zero COM (their weighted sum is zero), mirroring the centroid guard
    total_mass = w.sum(dim=1, keepdim=True)  # (B, 1)
    safe_mass = torch.where(total_mass != 0, total_mass, torch.ones_like(total_mass))

    # 4) weighted sums -> COMs
    w3 = w.unsqueeze(-1)  # (B, N, 1), broadcasts over the coordinate axis
    com_coords = (cart * w3).sum(dim=1) / safe_mass
    com_frac_coords = (frac * w3).sum(dim=1) / safe_mass

    return {
        'fragment_com_coords': com_coords,
        'fragment_com_frac_coords': com_frac_coords
    }
def compute_centroid_batch(
    atom_coords: torch.Tensor,       # (B, N, 3)
    atom_frac_coords: torch.Tensor,  # (B, N, 3)
    atom_mask: torch.BoolTensor,     # (B, N)
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Compute geometric centroids in Cartesian and fractional coordinates.

    Parameters
    ----------
    atom_coords : torch.Tensor of shape (B, N, 3)
        Cartesian coordinates, padded to N atoms.
    atom_frac_coords : torch.Tensor of shape (B, N, 3)
        Fractional coordinates, padded similarly.
    atom_mask : torch.BoolTensor of shape (B, N)
        True for real atoms, False for padding.
    device : torch.device
        Device to perform computation on.

    Returns
    -------
    dict with keys:
        fragment_cen_coords : torch.Tensor of shape (B, 3)
            Cartesian centroid (unweighted mean of real atoms) per fragment.
        fragment_cen_frac_coords : torch.Tensor of shape (B, 3)
            Fractional centroid per fragment.
        A fragment with no real atoms gets the zero vector.
    """
    # 1) move to device
    atom_coords = atom_coords.to(device)
    atom_frac_coords = atom_frac_coords.to(device)
    atom_mask = atom_mask.to(device).bool()

    # 2) zero out padding by selection, so inf/NaN in padding slots cannot
    #    become NaN the way multiplying by a 0/1 mask would
    keep = atom_mask.unsqueeze(-1)  # (B, N, 1)
    cart = torch.where(keep, atom_coords, torch.zeros_like(atom_coords))
    frac = torch.where(keep, atom_frac_coords, torch.zeros_like(atom_frac_coords))

    # 3) real-atom count per fragment, clamped to avoid division by zero
    count = atom_mask.to(atom_coords.dtype).sum(dim=1, keepdim=True)  # (B, 1)
    count = torch.clamp(count, min=1.0)

    # 4) average over real atoms
    centroid_coords = cart.sum(dim=1) / count
    centroid_frac_coords = frac.sum(dim=1) / count

    return {
        'fragment_cen_coords': centroid_coords,
        'fragment_cen_frac_coords': centroid_frac_coords
    }
def compute_inertia_tensor_batch(
    atom_coords: torch.Tensor,    # (B, N, 3)
    atom_weights: torch.Tensor,   # (B, N)
    atom_mask: torch.BoolTensor,  # (B, N)
    com_coords: torch.Tensor,     # (B, 3)
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Compute each fragment's inertia tensor, eigenvalues, and oriented eigenvectors.

    Parameters
    ----------
    atom_coords : torch.Tensor of shape (B, N, 3)
        Cartesian coordinates, padded to N atoms.
    atom_weights : torch.Tensor of shape (B, N)
        Atomic weights, zero for padding.
    atom_mask : torch.BoolTensor of shape (B, N)
        True for real atoms.
    com_coords : torch.Tensor of shape (B, 3)
        Pre-computed center of mass coordinates.
    device : torch.device
        Device to perform computation on.

    Returns
    -------
    dict with keys:
        fragment_inertia_tensors : torch.Tensor of shape (B, 3, 3)
            I = sum_i w_i [ |r_i|^2 I_3 - r_i r_i^T ], r_i relative to the COM.
        fragment_inertia_eigvals : torch.Tensor of shape (B, 3)
            Eigenvalues in ascending order.
        fragment_inertia_eigvecs : torch.Tensor of shape (B, 3, 3)
            Eigenvectors as columns. Each column is first sign-fixed so that
            its largest-magnitude component is >= 0. Then the third column is
            negated wherever needed to make the frame right-handed (det > 0).
    """
    # Move to device
    atom_coords = atom_coords.to(device)
    atom_weights = atom_weights.to(device)
    atom_mask = atom_mask.to(device).bool()
    com_coords = com_coords.to(device)

    # Select (not multiply) so inf/NaN in padding slots cannot turn into NaN
    w = torch.where(atom_mask, atom_weights, torch.zeros_like(atom_weights))  # (B, N)
    r = atom_coords - com_coords.unsqueeze(1)                                   # (B, N, 3)
    r = torch.where(atom_mask.unsqueeze(-1), r, torch.zeros_like(r))

    r2 = (r * r).sum(dim=-1)                   # (B, N)
    outer = r.unsqueeze(-1) * r.unsqueeze(-2)  # (B, N, 3, 3)
    eye = torch.eye(3, dtype=r.dtype, device=r.device)

    # Per-atom contributions w * [ (r.r) I_3 - r (x) r ], summed over atoms
    terms = w[..., None, None] * (r2[..., None, None] * eye - outer)
    inertia_tensors = terms.sum(dim=1)         # (B, 3, 3)

    eigvals, eigvecs = torch.linalg.eigh(inertia_tensors)  # ascending eigenvalues

    # Canonical sign per column: largest-|.| component >= 0 (argmax ties -> first)
    pivot = eigvecs.abs().argmax(dim=-2, keepdim=True)    # (B, 1, 3)
    sign = eigvecs.gather(-2, pivot).sign()
    sign = torch.where(sign == 0, torch.ones_like(sign), sign)
    eigvecs = eigvecs * sign

    # Right-handed frame: negate the third column where det < 0 (not in place)
    handed = 1.0 - 2.0 * (torch.linalg.det(eigvecs) < 0).to(eigvecs.dtype)  # (B,)
    eigvecs = torch.cat(
        [eigvecs[..., :2], eigvecs[..., 2:] * handed[..., None, None]], dim=-1
    )

    return {
        'fragment_inertia_tensors': inertia_tensors,
        'fragment_inertia_eigvals': eigvals,
        'fragment_inertia_eigvecs': eigvecs
    }
def compute_quadrupole_tensor_batch(
    atom_coords: torch.Tensor,    # (B, N, 3)
    atom_charges: torch.Tensor,   # (B, N)
    atom_mask: torch.BoolTensor,  # (B, N)
    com_coords: torch.Tensor,     # (B, 3)
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Compute each fragment's quadrupole tensor, eigenvalues, and eigenvectors.

    Parameters
    ----------
    atom_coords : torch.Tensor of shape (B, N, 3)
        Cartesian coordinates, padded to N atoms.
    atom_charges : torch.Tensor of shape (B, N)
        Atomic charges, zero for padding.
    atom_mask : torch.BoolTensor of shape (B, N)
        True for real atoms.
    com_coords : torch.Tensor of shape (B, 3)
        Pre-computed center of mass coordinates (expansion origin).
    device : torch.device
        Device to perform computation on.

    Returns
    -------
    dict with keys:
        fragment_quadrupole_tensors : torch.Tensor of shape (B, 3, 3)
            Q = sum_i q_i [ 3 r_i r_i^T - |r_i|^2 I_3 ], r_i relative to the COM.
        fragment_quadrupole_eigvals : torch.Tensor of shape (B, 3)
            Eigenvalues of Q in ascending order.
        fragment_quadrupole_eigvecs : torch.Tensor of shape (B, 3, 3)
            Eigenvectors as columns. Each column is first sign-fixed so that
            its largest-magnitude component is >= 0. Then the third column is
            negated wherever needed to make the frame right-handed (det > 0).
    """
    # Move to device
    atom_coords = atom_coords.to(device)
    atom_charges = atom_charges.to(device)
    atom_mask = atom_mask.to(device).bool()
    com_coords = com_coords.to(device)

    # Select (not multiply) so inf/NaN in padding slots cannot turn into NaN
    q = torch.where(atom_mask, atom_charges, torch.zeros_like(atom_charges))  # (B, N)
    r = atom_coords - com_coords.unsqueeze(1)                                   # (B, N, 3)
    r = torch.where(atom_mask.unsqueeze(-1), r, torch.zeros_like(r))

    r2 = (r * r).sum(dim=-1)                   # (B, N)
    outer = r.unsqueeze(-1) * r.unsqueeze(-2)  # (B, N, 3, 3)
    eye = torch.eye(3, dtype=r.dtype, device=r.device)

    # Per-atom quadrupole q [ 3 (r (x) r) - |r|^2 I_3 ], summed over atoms
    terms = q[..., None, None] * (3.0 * outer - r2[..., None, None] * eye)
    quad_tensors = terms.sum(dim=1)            # (B, 3, 3)

    eigvals, eigvecs = torch.linalg.eigh(quad_tensors)  # ascending eigenvalues

    # Canonical sign per column: largest-|.| component >= 0 (argmax ties -> first)
    pivot = eigvecs.abs().argmax(dim=-2, keepdim=True)  # (B, 1, 3)
    sign = eigvecs.gather(-2, pivot).sign()
    sign = torch.where(sign == 0, torch.ones_like(sign), sign)
    eigvecs = eigvecs * sign

    # Right-handed frame: negate the third column where det < 0 (not in place)
    handed = 1.0 - 2.0 * (torch.linalg.det(eigvecs) < 0).to(eigvecs.dtype)  # (B,)
    eigvecs = torch.cat(
        [eigvecs[..., :2], eigvecs[..., 2:] * handed[..., None, None]], dim=-1
    )

    return {
        'fragment_quadrupole_tensors': quad_tensors,
        'fragment_quadrupole_eigvals': eigvals,
        'fragment_quadrupole_eigvecs': eigvecs
    }