geometry_utils module
Module: geometry_utils.py
- Batch-based utilities for computing geometric descriptors of molecules and fragments using PyTorch. Includes:
Best-fit plane / centroid
Planarity metrics (RMSD, max deviation, planarity score)
Global Steinhardt Q_l order parameters
Distances to special crystallographic planes (fractional)
Angles between bonds and special planes (fractional)
Quaternion computation from rotation matrices
Dependencies
math, numpy, torch, typing
- geometry_utils.compute_distances_to_crystallographic_planes_frac_batch(atom_frac_coords, atom_mask, device)[source]
Compute fractional distances of each atom to the 26 special crystallographic planes.
- Parameters:
atom_frac_coords (torch.Tensor, shape (B, A, 3)) – Fractional coordinates of atoms, padded to A slots per structure.
atom_mask (torch.BoolTensor, shape (B, A)) – True for valid atoms, False for padding.
device (torch.device) – Device on which to perform the computation.
- Returns:
Absolute fractional distances to each plane (13 normals × 2 denominators).
- Return type:
torch.Tensor, shape (B, A, 26)
- geometry_utils.compute_angles_between_bonds_and_crystallographic_planes_frac_batch(atom_frac_coords, bond_atom1, bond_atom2, bond_mask, device)[source]
Compute angles (in degrees) between each bond vector and the 13 crystallographic plane normals.
- Parameters:
atom_frac_coords (torch.Tensor, shape (B, A, 3)) – Fractional coordinates of all atoms, padded to A per structure.
bond_atom1 (torch.LongTensor, shape (B, M)) – Index of the first atom in each bond slot.
bond_atom2 (torch.LongTensor, shape (B, M)) – Index of the second atom in each bond slot.
bond_mask (torch.BoolTensor, shape (B, M)) – True for real bonds, False for padding.
device (torch.device) – Device on which to perform the computation.
- Returns:
Bond–plane angles in degrees; zeros where bond_mask is False.
- Return type:
torch.Tensor, shape (B, M, 13)
- geometry_utils.compute_atom_vectors_to_point_batch(atom_coords, atom_frac_coords, atom_mask, com_coords, com_frac_coords, device)[source]
Compute displacement vectors and Euclidean distances from each atom to a reference point.
- Parameters:
atom_coords (torch.Tensor, shape (B, N, 3)) – Cartesian coordinates, padded to N per fragment.
atom_frac_coords (torch.Tensor, shape (B, N, 3)) – Fractional coordinates, padded to N per fragment.
atom_mask (torch.BoolTensor, shape (B, N)) – True for real atoms, False for padding.
com_coords (torch.Tensor, shape (B, 3)) – Reference points in Cartesian space.
com_frac_coords (torch.Tensor, shape (B, 3)) – Reference points in fractional space.
device (torch.device) – Device for computation.
- Returns:
dists_cart (torch.Tensor, shape (B, N)) – Euclidean distances in Cartesian space.
dists_frac (torch.Tensor, shape (B, N)) – Euclidean distances in fractional space.
vecs_cart (torch.Tensor, shape (B, N, 3)) – Cartesian displacement vectors (atom → point).
vecs_frac (torch.Tensor, shape (B, N, 3)) – Fractional displacement vectors.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- geometry_utils.compute_bond_rotatability_batch(atom_symbols, bond_atom1_idx, bond_atom2_idx, bond_is_cyclic, bond_types, bond_is_rotatable_raw, device)[source]
Determine which bonds are rotatable based on CSD criteria.
- Parameters:
atom_symbols (List[List[str]]) – Element symbols per atom for each structure (B × N).
bond_atom1_idx (torch.LongTensor, shape (B, M)) – Index of the first atom in each bond slot.
bond_atom2_idx (torch.LongTensor, shape (B, M)) – Index of the second atom in each bond slot.
bond_is_cyclic (torch.BoolTensor, shape (B, M)) – True if the bond is in a ring.
bond_types (List[List[str]]) – Bond types (e.g. ‘single’, ‘double’, ‘triple’) per slot.
bond_is_rotatable_raw (List[List[bool]]) – Original CSD rotatability flags per slot.
device (torch.device) – Device for the output tensor.
- Returns:
True where the bond passes all checks and is rotatable.
- Return type:
torch.BoolTensor, shape (B, M)
- geometry_utils.compute_bond_angles_batch(atom_labels, atom_coords, atom_mask, bond_atom1_idx, bond_atom2_idx, bond_mask, device)[source]
Compute all bond angles (i–j–k) in a batch.
- Parameters:
atom_labels (List[List[str]]) – Atom labels per structure (B × N).
atom_coords (torch.Tensor, shape (B, N, 3)) – Cartesian coordinates of atoms.
atom_mask (torch.BoolTensor, shape (B, N)) – True for real atoms.
bond_atom1_idx (torch.LongTensor, shape (B, M)) – Index of first atom in each bond slot.
bond_atom2_idx (torch.LongTensor, shape (B, M)) – Index of second atom in each bond slot.
bond_mask (torch.BoolTensor, shape (B, M)) – True for real bonds.
device (torch.device) – Device for computation.
- Returns:
angle_ids (List[List[str]]) – Per-structure list of “i–j–k” strings.
angles (torch.Tensor, shape (B, P_max)) – Angle values in degrees (0 where padding).
mask_ang (torch.BoolTensor, shape (B, P_max)) – True for real angles.
idx_tensor (torch.LongTensor, shape (B, P_max, 3)) – Atom index triples for each angle.
- Return type:
Tuple[List[List[str]], torch.Tensor, torch.BoolTensor, torch.LongTensor]
- geometry_utils.compute_torsion_angles_batch(atom_labels, atom_coords, atom_mask, bond_atom1_idx, bond_atom2_idx, bond_mask, device)[source]
Compute all torsion angles (i–j–k–l) for each molecule in a batch.
- Parameters:
atom_labels (List[List[str]]) – Atom labels per structure (B × N).
atom_coords (torch.Tensor, shape (B, N, 3)) – Cartesian coordinates.
atom_mask (torch.BoolTensor, shape (B, N)) – True for real atoms.
bond_atom1_idx (torch.Tensor, shape (B, M)) – Index of first atom in each bond slot.
bond_atom2_idx (torch.Tensor, shape (B, M)) – Index of second atom in each bond slot.
bond_mask (torch.BoolTensor, shape (B, M)) – True for real bonds.
device (torch.device) – Device for computation.
- Returns:
torsion_ids (List[List[str]]) – Per-structure list of “i–j–k–l” strings.
torsions (torch.Tensor, shape (B, T_max)) – Dihedral angles in degrees (0 where padding).
mask_tor (torch.BoolTensor, shape (B, T_max)) – True for real torsions.
idx_tensor (torch.LongTensor, shape (B, T_max, 4)) – Atom index quadruplets for each torsion.
- Return type:
Tuple[List[List[str]], torch.Tensor, torch.BoolTensor, torch.LongTensor]
- geometry_utils.compute_quaternions_from_rotation_matrices(R, device)[source]
Convert rotation matrices into unit quaternions [w, x, y, z] with w ≥ 0.
- Parameters:
R (torch.Tensor, shape (B, 3, 3)) – Proper rotation matrices (RᵀR = I, det=+1).
device (torch.device) – Device for computation.
- Returns:
Unit quaternions corresponding to each rotation matrix.
- Return type:
torch.Tensor, shape (B, 4)
- geometry_utils.compute_global_steinhardt_order_parameters_batch(atom_to_com_vecs, atom_mask, atom_weights, device, l_values=[2, 4, 6, 8, 10], eps=1e-12)[source]
Compute global Steinhardt Q_l order parameters for a batch.
- Parameters:
atom_to_com_vecs (torch.Tensor, shape (B, N, 3)) – Positions relative to the center of mass.
atom_mask (torch.BoolTensor, shape (B, N)) – True for real atoms, False for padding.
atom_weights (torch.Tensor or None, shape (B, N)) – Per-atom weights, or None for uniform weights.
device (torch.device) – Device for computation.
l_values (List[int], optional) – List of ℓ values at which to compute Q_l.
eps (float, optional) – Small constant to avoid division by zero.
- Returns:
Q_l values for each batch entry.
- Return type:
torch.Tensor, shape (B, len(l_values))
- geometry_utils.compute_best_fit_plane_batch(coords, mask, device)[source]
Compute best-fit plane normals and centroids via SVD.
- Parameters:
coords (torch.Tensor, shape (B, N, 3)) – Cartesian coordinates of atoms, padded to N.
mask (torch.BoolTensor, shape (B, N)) – True for real atoms, False for padding.
device (torch.device) – Device for computation.
- Returns:
normals (torch.Tensor, shape (B, 3)) – Unit normals of the best-fit planes (z ≥ 0).
centroids (torch.Tensor, shape (B, 3)) – Centroid of valid atoms per batch entry.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- geometry_utils.compute_planarity_metrics_batch(coords, mask, normals, centroids, device, decay_width=0.5)[source]
Compute planarity metrics (RMSD, max deviation, planarity score) for a batch.
- Parameters:
coords (torch.Tensor, shape (B, N, 3)) – Cartesian coordinates of atoms, padded to N.
mask (torch.BoolTensor, shape (B, N)) – True for real atoms.
normals (torch.Tensor, shape (B, 3)) – Plane normals (unit vectors).
centroids (torch.Tensor, shape (B, 3)) – Centroids of valid atoms.
device (torch.device) – Device for computation.
decay_width (float, optional) – Width parameter for exponential planarity score.
- Returns:
rmsd (torch.Tensor, shape (B,)) – Root‐mean‐square deviation from the plane.
max_dev (torch.Tensor, shape (B,)) – Maximum absolute deviation from the plane.
planarity_score (torch.Tensor, shape (B,)) – Exponential planarity score exp(–rmsd/decay_width).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- geometry_utils.compute_fragment_pairwise_vectors_and_distances_batch(coords, mask, heavy_mask, device)[source]
Compute pairwise displacement vectors and Euclidean distances between non-hydrogen atoms in each fragment.
- Parameters:
coords (torch.Tensor, shape (F, A, 3)) – Cartesian coordinates for each of F fragments, padded to A atoms.
mask (torch.BoolTensor, shape (F, A)) – True for real atom slots, False for padding.
heavy_mask (torch.BoolTensor, shape (F, A)) – True for heavy (non-H) atom slots, False otherwise.
device (torch.device) – Device on which to perform the computation.
- Returns:
distances (torch.Tensor, shape (F, A, A)) – Pairwise Euclidean distances. Entry [f,i,j] is the distance between atom i and j in fragment f if both are heavy & real; zero otherwise.
vectors (torch.Tensor, shape (F, A, A, 3)) – Pairwise displacement vectors: coords[j]−coords[i] for each heavy-atom pair, zero for any non-heavy or padding atom involved.
atom1_idx (torch.LongTensor, shape (P,)) – The “i” index of each unique heavy-atom pair (i<j), across all fragments.
atom2_idx (torch.LongTensor, shape (P,)) – The “j” index of each unique heavy-atom pair (i<j), across all fragments.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor]
Notes
The order of entries in atom1_idx and atom2_idx matches the order you’d get by iterating through torch.nonzero(pair_valid & (j>i)).
Geometric Calculations and Descriptors
The geometry_utils module provides GPU-accelerated batch computations for molecular and crystallographic geometric descriptors. All functions operate on PyTorch tensors and support batch processing for high-throughput analysis.
Key Features:
Bond geometry analysis - angles, torsions, planarity metrics
Crystallographic calculations - distances to special planes, order parameters
Molecular descriptors - inertia tensors, quaternions, shape parameters
GPU acceleration - optimized PyTorch operations for large datasets
Batch processing - efficient handling of multiple structures simultaneously
Core Functions
Bond Angle Calculations
- geometry_utils.compute_bond_angles_batch(atom_labels, atom_coords, atom_mask, bond_atom1_idx, bond_atom2_idx, bond_mask, device)[source]
Compute all bond angles (i–j–k) in a batch.
Comprehensive Bond Angle Analysis
Computes all unique bond angles (i–j–k) where atom j is the central vertex connected to both i and k.
Algorithm:
Build adjacency graph from bond connectivity
Identify all valid angle triplets (i–j–k)
Compute vectorized angle calculations using dot products
Return organized results with proper masking
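The four steps above can be sketched for a single structure in plain Python. This is a minimal illustration rather than the module's batched implementation: the helper name `enumerate_bond_angles` is hypothetical, and listing each angle once with i < k is an assumed convention. It also assumes no zero-length bonds.

```python
import math
from itertools import combinations

def enumerate_bond_angles(coords, bonds):
    """Return [((i, j, k), angle_deg)] for every pair of bonds sharing atom j.

    coords: list of (x, y, z) tuples; bonds: list of (a, b) index pairs.
    Hypothetical helper for illustration only.
    """
    # Step 1: adjacency sets built from the bond connectivity
    neighbors = {}
    for a, b in bonds:
        neighbors.setdefault(a, set()).add(b)
        neighbors.setdefault(b, set()).add(a)

    results = []
    for j in sorted(neighbors):
        # Step 2: each unordered pair of neighbours of j defines one angle i-j-k
        for i, k in combinations(sorted(neighbors[j]), 2):
            u = [coords[i][d] - coords[j][d] for d in range(3)]
            v = [coords[k][d] - coords[j][d] for d in range(3)]
            # Step 3: angle from the normalised dot product, clamped to [-1, 1]
            dot = sum(p * q for p, q in zip(u, v))
            norm = math.sqrt(sum(p * p for p in u)) * math.sqrt(sum(q * q for q in v))
            cos_t = max(-1.0, min(1.0, dot / norm))
            # Step 4: collect the index triple together with its angle
            results.append(((i, j, k), math.degrees(math.acos(cos_t))))
    return results

# Bent three-atom chain: a single angle centred on atom 1
coords = [(1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
angles = enumerate_bond_angles(coords, [(0, 1), (1, 2)])
```

A per-structure reference like this can be handy for spot-checking individual entries of the padded `angles` tensor.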
Parameters:
atom_labels (List[List[str]]) - Atom labels per structure for identification
atom_coords (torch.Tensor, shape (B, N, 3)) - Cartesian coordinates
atom_mask (torch.BoolTensor, shape (B, N)) - Valid atom indicators
bond_atom1_idx (torch.LongTensor, shape (B, M)) - First bond atom indices
bond_atom2_idx (torch.LongTensor, shape (B, M)) - Second bond atom indices
bond_mask (torch.BoolTensor, shape (B, M)) - Valid bond indicators
device (torch.device) - GPU/CPU device for computation
Returns:
angle_ids (List[List[str]]) - Angle identifiers as “i–j–k” strings
angles (torch.Tensor, shape (B, P_max)) - Angle values in degrees
mask_ang (torch.BoolTensor, shape (B, P_max)) - Valid angle indicators
idx_tensor (torch.LongTensor, shape (B, P_max, 3)) - Atom index triplets
Usage Example:
import torch
from geometry_utils import compute_bond_angles_batch

# Sample molecular data
atom_labels = [['C1', 'C2', 'C3', 'H1']]
coords = torch.tensor([[[0.0, 0.0, 0.0],             # C1
                        [1.4, 0.0, 0.0],             # C2
                        [2.1, 1.2, 0.0],             # C3
                        [3.2, 1.2, 0.0]]]).float()   # H1
atom_mask = torch.tensor([[True, True, True, True]])

# Bond connectivity: C1-C2, C2-C3, C3-H1
bond_atom1 = torch.tensor([[0, 1, 2]])
bond_atom2 = torch.tensor([[1, 2, 3]])
bond_mask = torch.tensor([[True, True, True]])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
angle_ids, angles, mask_ang, idx_tensor = compute_bond_angles_batch(
    atom_labels, coords, atom_mask, bond_atom1, bond_atom2, bond_mask, device
)

print(f"Found {mask_ang.sum().item()} valid angles:")
for i, angle_id in enumerate(angle_ids[0]):
    if mask_ang[0, i]:
        print(f"  {angle_id}: {angles[0, i]:.1f}°")
Performance Notes:
Scales as O(B × M²) where M is the maximum number of bonds per structure
GPU acceleration provides 10-50× speedup over CPU for large batches
Memory usage: ~4 bytes per angle × batch_size × max_angles
Torsion Angle Calculations
- geometry_utils.compute_torsion_angles_batch(atom_labels, atom_coords, atom_mask, bond_atom1_idx, bond_atom2_idx, bond_mask, device)[source]
Compute all torsion angles (i–j–k–l) for each molecule in a batch.
Dihedral Angle Computation
Calculates all valid torsion (dihedral) angles for molecular conformations using the four-atom sequence i–j–k–l.
Parameters:
atom_labels (List[List[str]]) - Atom identification labels
atom_coords (torch.Tensor, shape (B, N, 3)) - 3D atomic coordinates
atom_mask (torch.BoolTensor, shape (B, N)) - Valid atom mask
bond_atom1_idx (torch.LongTensor, shape (B, M)) - Bond connectivity indices (first atom)
bond_atom2_idx (torch.LongTensor, shape (B, M)) - Bond connectivity indices (second atom)
bond_mask (torch.BoolTensor, shape (B, M)) - Valid bond mask
device (torch.device) - Computation device
Returns:
torsion_ids (List[List[str]]) - Torsion identifiers as “i–j–k–l”
torsions (torch.Tensor, shape (B, T_max)) - Dihedral angles in degrees (-180° to +180°)
mask_tor (torch.BoolTensor, shape (B, T_max)) - Valid torsion mask
idx_tensor (torch.LongTensor, shape (B, T_max, 4)) - Four-atom index sets
Mathematical Implementation:
Uses the standard dihedral angle formula with cross products:
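\[\phi = \operatorname{atan2}\left((\vec{n}_1 \times \vec{n}_2) \cdot \hat{b}_2,\; \vec{n}_1 \cdot \vec{n}_2\right)\]
where \(\vec{n}_1 = \vec{b}_1 \times \vec{b}_2\) and \(\vec{n}_2 = \vec{b}_2 \times \vec{b}_3\) are the normals of the i–j–k and j–k–l planes, \(\vec{b}_1, \vec{b}_2, \vec{b}_3\) are taken here as the consecutive bond vectors i→j, j→k and k→l, and \(\hat{b}_2 = \vec{b}_2 / |\vec{b}_2|\).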
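The formula can be checked on a single quadruplet with a few lines of plain Python. This is a minimal sketch, assuming the bond vectors run i→j, j→k, k→l; the function name `dihedral_deg` and its helpers are illustrative and not part of geometry_utils.

```python
import math

def _sub(a, b):
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])

def _cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])

def _dot(a, b):
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]

def dihedral_deg(p_i, p_j, p_k, p_l):
    """Signed dihedral i-j-k-l in degrees, in [-180, 180]."""
    b1, b2, b3 = _sub(p_j, p_i), _sub(p_k, p_j), _sub(p_l, p_k)
    n1, n2 = _cross(b1, b2), _cross(b2, b3)   # normals of the two planes
    b2_len = math.sqrt(_dot(b2, b2))
    b2_hat = (b2[0] / b2_len, b2[1] / b2_len, b2[2] / b2_len)
    y = _dot(_cross(n1, n2), b2_hat)          # sine-like term
    x = _dot(n1, n2)                          # cosine-like term
    return math.degrees(math.atan2(y, x))

# Same central bond j-k, three positions of the last atom
cis = dihedral_deg((1, 1, 0), (1, 0, 0), (0, 0, 0), (0, 1, 0))
trans = dihedral_deg((1, 1, 0), (1, 0, 0), (0, 0, 0), (0, -1, 0))
perp = dihedral_deg((1, 1, 0), (1, 0, 0), (0, 0, 0), (0, 0, 1))
```

Using atan2 rather than acos keeps the sign of the angle and avoids loss of precision near 0° and 180°.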
Bond Rotatability Analysis
- geometry_utils.compute_bond_rotatability_batch(atom_symbols, bond_atom1_idx, bond_atom2_idx, bond_is_cyclic, bond_types, bond_is_rotatable_raw, device)[source]
Determine which bonds are rotatable based on CSD criteria.
Rotatable Bond Classification
Identifies rotatable bonds based on chemical environment and structural constraints.
Rotatable Bond Criteria:
Non-hydrogen connectivity - Both atoms have ≥2 non-hydrogen neighbors
Single bond type - Must be a single covalent bond
Non-cyclic - Bond is not part of a ring system
Non-linear - Neither atom is sp-hybridized or in cumulated double bonds
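As a rough illustration, the first three criteria can be written as a single-bond check in plain Python. This sketch rests on assumptions not confirmed by the source: the helper name `is_rotatable` is hypothetical, the neighbor count is taken to include the bond partner itself, and the linearity test (fourth criterion) is omitted.

```python
def is_rotatable(bond, bond_type, in_ring, symbols, bonds):
    """Apply the first three criteria to one bond (a, b).

    symbols: element symbol per atom; bonds: all (a, b) pairs in the structure.
    Hypothetical helper; the sp/cumulene check is not implemented here.
    """
    # Criteria 2 and 3: single bond, not part of a ring
    if bond_type != "single" or in_ring:
        return False

    def heavy_neighbors(atom):
        # Count bonded non-hydrogen atoms (assumed to include the bond partner)
        count = 0
        for a, b in bonds:
            if atom in (a, b):
                other = b if a == atom else a
                if symbols[other] != "H":
                    count += 1
        return count

    # Criterion 1: both ends carry at least two non-hydrogen neighbours
    return all(heavy_neighbors(atom) >= 2 for atom in bond)

# Four-carbon chain C0-C1-C2-C3 with one explicit H on C0
symbols = ["C", "C", "C", "C", "H"]
bonds = [(0, 1), (1, 2), (2, 3), (0, 4)]
central = is_rotatable((1, 2), "single", False, symbols, bonds)
terminal = is_rotatable((0, 1), "single", False, symbols, bonds)
ring_bond = is_rotatable((1, 2), "single", True, symbols, bonds)
```

Under these assumptions only the central bond qualifies, because the terminal carbon has a single heavy neighbor.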
Parameters:
atom_symbols (List[List[str]]) - Atomic symbols (C, N, O, etc.)
bond_atom1_idx (torch.LongTensor, shape (B, M)) - Bond connectivity (first atom index)
bond_atom2_idx (torch.LongTensor, shape (B, M)) - Bond connectivity (second atom index)
bond_is_cyclic (torch.BoolTensor, shape (B, M)) - Ring membership flags
bond_types (List[List[str]]) - Bond type annotations (‘single’, ‘double’, etc.)
bond_is_rotatable_raw (List[List[bool]]) - Original CSD rotatability flags
device (torch.device) - Computation device
Returns:
torch.BoolTensor, shape (B, M) - True for rotatable bonds
Applications:
Drug-like property assessment (rotatable-bond counts, as used in Veber's rules)
Conformational flexibility analysis
Molecular dynamics preparation
Structure-activity relationship studies
Planarity Analysis
- geometry_utils.compute_best_fit_plane_batch(coords, mask, device)[source]
Compute best-fit plane normals and centroids via SVD.
Best-Fit Plane Computation
Determines the least-squares plane through a set of atoms via SVD.
Parameters:
coords (torch.Tensor, shape (B, N, 3)) - Atomic coordinates
mask (torch.BoolTensor, shape (B, N)) - Valid atom indicators
device (torch.device) - Computation device
Returns:
normals (torch.Tensor, shape (B, 3)) - Unit normal vectors, oriented so that z ≥ 0
centroids (torch.Tensor, shape (B, 3)) - Plane centroids
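For a dependency-free cross-check on one fragment, the plane normal can also be obtained without SVD. The sketch below swaps in a different technique: it builds the 3×3 scatter matrix of the centered points and extracts its smallest-eigenvalue direction by repeated squaring of a shifted matrix (a power-iteration variant). The function name `best_fit_plane` is hypothetical, and the sketch assumes the points do not all coincide.

```python
import math

def best_fit_plane(points, squarings=60):
    """Return (unit normal with z >= 0, centroid) of the least-squares plane."""
    n = len(points)
    centroid = tuple(sum(p[d] for p in points) / n for d in range(3))
    centered = [tuple(p[d] - centroid[d] for d in range(3)) for p in points]

    # Scatter matrix C = sum of outer products of the centered points
    C = [[sum(q[r] * q[c] for q in centered) for c in range(3)] for r in range(3)]
    trace = C[0][0] + C[1][1] + C[2][2]

    # M = trace*I - C shares C's eigenvectors with the eigenvalue order reversed,
    # so high powers of M are dominated by C's smallest-eigenvalue direction v.
    M = [[(trace if r == c else 0.0) - C[r][c] for c in range(3)] for r in range(3)]
    for _ in range(squarings):
        M = [[sum(M[r][k] * M[k][c] for k in range(3)) for c in range(3)]
             for r in range(3)]
        scale = max(abs(x) for row in M for x in row)
        M = [[x / scale for x in row] for row in M]

    # M is now numerically proportional to v v^T; its largest column is parallel to v
    cols = [tuple(M[r][c] for r in range(3)) for c in range(3)]
    v = max(cols, key=lambda col: sum(x * x for x in col))
    length = math.sqrt(sum(x * x for x in v))
    v = tuple(x / length for x in v)

    # Orient the normal so that its z component is non-negative
    if v[2] < 0:
        v = tuple(-x for x in v)
    return v, centroid

# Four points lying exactly in the plane z = x
tilted = [(0, 0, 0), (1, 0, 1), (0, 1, 0), (1, 1, 1)]
normal, centroid = best_fit_plane(tilted)
```

For exactly planar input the recovered normal is perpendicular to every centered point, which makes this a convenient sanity check for the batched SVD result.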
- geometry_utils.compute_planarity_metrics_batch(coords, mask, normals, centroids, device, decay_width=0.5)[source]
Compute planarity metrics (RMSD, max deviation, planarity score) for a batch.
Comprehensive Planarity Assessment
Computes multiple planarity descriptors for molecular fragments and rings.
Planarity Metrics:
RMSD - Root-mean-square deviation from best-fit plane
Max deviation - Maximum atomic displacement from plane
Planarity score - exp(–rmsd/decay_width), in (0, 1]: 1 = perfectly planar, approaching 0 as the deviation grows
Thickness - Distance between extreme atoms perpendicular to plane
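The first three metrics follow directly from the signed distances of the atoms to the plane. A minimal single-fragment sketch, assuming a unit normal and using the exp(–rmsd/decay_width) score from the function reference; the name `planarity_metrics` is illustrative only.

```python
import math

def planarity_metrics(points, normal, centroid, decay_width=0.5):
    """Return (rmsd, max_dev, score) for points relative to a plane."""
    # Signed distance of each point from the plane, measured along the unit normal
    devs = [sum((p[d] - centroid[d]) * normal[d] for d in range(3)) for p in points]
    rmsd = math.sqrt(sum(x * x for x in devs) / len(devs))
    max_dev = max(abs(x) for x in devs)
    score = math.exp(-rmsd / decay_width)  # 1.0 for a perfectly planar fragment
    return rmsd, max_dev, score

# Four atoms puckered by +/-0.1 about the z = 0 plane
ring = [(1, 0, 0.1), (0, 1, -0.1), (-1, 0, 0.1), (0, -1, -0.1)]
rmsd, max_dev, score = planarity_metrics(ring, (0, 0, 1), (0, 0, 0))
```

With the default decay_width of 0.5, an RMSD of 0.1 gives a score of exp(-0.2), so the score falls off slowly for small deviations.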
Parameters:
coords (torch.Tensor, shape (B, N, 3)) - Atomic coordinates
mask (torch.BoolTensor, shape (B, N)) - Valid atom mask
normals (torch.Tensor, shape (B, 3)) - Best-fit plane normals (unit vectors)
centroids (torch.Tensor, shape (B, 3)) - Plane centroids
device (torch.device) - Computation device
decay_width (float, optional, default 0.5) - Width parameter of the exponential planarity score
Returns:
rmsd (torch.Tensor, shape (B,)) - RMSD values in Ångstroms
max_dev (torch.Tensor, shape (B,)) - Maximum deviations in Ångstroms
planarity_score (torch.Tensor, shape (B,)) - Planarity scores exp(–rmsd/decay_width)
Usage Example:
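import torch
from geometry_utils import compute_best_fit_plane_batch, compute_planarity_metrics_batch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Analyze planarity of aromatic rings.
# extract_ring_coordinates and create_valid_atom_mask stand in for your own data preparation.
ring_coords = extract_ring_coordinates(molecule)
ring_mask = create_valid_atom_mask(ring_coords)

# Fit the plane first, then measure deviations from it
normals, centroids = compute_best_fit_plane_batch(ring_coords, ring_mask, device)
rmsd, max_dev, score = compute_planarity_metrics_batch(
    ring_coords, ring_mask, normals, centroids, device
)

# Classify ring planarity. score = exp(-rmsd/decay_width), so 1 means perfectly planar;
# the cut-offs below are illustrative.
for i, s in enumerate(score):
    if s > 0.9:
        print(f"Ring {i}: Highly planar (score: {s:.3f})")
    elif s > 0.7:
        print(f"Ring {i}: Moderately planar (score: {s:.3f})")
    else:
        print(f"Ring {i}: Non-planar (score: {s:.3f})")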
Crystallographic Analysis
- geometry_utils.compute_distances_to_crystallographic_planes_frac_batch(atom_frac_coords, atom_mask, device)[source]
Compute fractional distances of each atom to the 26 special crystallographic planes.
Special Plane Distance Analysis
Computes fractional distances from atoms to the 26 special crystallographic planes used in structure analysis.
Special Planes:
Primary planes: (100), (010), (001) - unit cell faces
Diagonal planes: (110), (101), (011) and negatives - face diagonals
Body diagonal planes: (111) and variants - space diagonals
Multiple denominators: 4 and 6 for enhanced precision
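The count of 13 normals matches the number of distinct directions with Miller indices drawn from {-1, 0, 1}: 26 non-zero index triples, halved because (hkl) and its negative describe the same family of planes. The enumeration below illustrates this; the function name is hypothetical, and whether the module uses exactly this set or this ordering is an assumption.

```python
from itertools import product

def special_plane_normals():
    """Enumerate the 13 directions from {-1, 0, 1}^3, one per +/- pair."""
    normals = []
    for h, k, l in product((1, 0, -1), repeat=3):
        if (h, k, l) == (0, 0, 0):
            continue
        # Keep the representative whose first non-zero index is positive
        first = next(x for x in (h, k, l) if x != 0)
        if first > 0:
            normals.append((h, k, l))
    return normals

normals = special_plane_normals()
axes = [n for n in normals if sum(map(abs, n)) == 1]            # (100)-type
face_diagonals = [n for n in normals if sum(map(abs, n)) == 2]  # (110)-type
body_diagonals = [n for n in normals if sum(map(abs, n)) == 3]  # (111)-type
```

The three groups contain 3, 6 and 4 directions respectively, and pairing each of the 13 with two denominators gives the 26 output channels.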
Parameters:
atom_frac_coords (torch.Tensor, shape (B, A, 3)) - Fractional coordinates
atom_mask (torch.BoolTensor, shape (B, A)) - Valid atom indicators
device (torch.device) - Computation device
Returns:
torch.Tensor, shape (B, A, 26) - Fractional distances to each special plane
Applications:
Packing motif characterization
Symmetry analysis and space group validation
Crystal engineering and polymorph prediction
Structure factor analysis for powder diffraction
- geometry_utils.compute_angles_between_bonds_and_crystallographic_planes_frac_batch(atom_frac_coords, bond_atom1, bond_atom2, bond_mask, device)[source]
Compute angles (in degrees) between each bond vector and the 13 crystallographic plane normals.
Bond-Plane Angular Analysis
Calculates angles between molecular bonds and crystallographic plane normals for orientation analysis.
Parameters:
atom_frac_coords (torch.Tensor, shape (B, A, 3)) - Fractional atomic coordinates
bond_atom1 (torch.LongTensor, shape (B, M)) - First bond atom indices
bond_atom2 (torch.LongTensor, shape (B, M)) - Second bond atom indices
bond_mask (torch.BoolTensor, shape (B, M)) - Valid bond indicators
device (torch.device) - Computation device
Returns:
torch.Tensor, shape (B, M, 13) - Bond-plane angles in degrees
Interpretation:
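0° - Bond parallel to plane
90° - Bond perpendicular to plane
Intermediate values - Oblique orientations
These readings treat the value as the angle between the bond and the plane itself. If the angle is instead measured against the plane normal, as the function summary is worded, the two limiting cases swap (0° = bond along the normal, i.e. perpendicular to the plane).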
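The angle to a plane and the angle to its normal are complementary, which is easy to confuse when reading the output. The sketch below computes both for one bond vector. It is illustrative only: the name `bond_plane_angles` is hypothetical, the vectors are treated as plain 3-vectors in the fractional frame, and folding the result into [0°, 90°] is an assumed convention.

```python
import math

def bond_plane_angles(bond_vec, normal):
    """Return (angle_to_normal, angle_to_plane) in degrees, both in [0, 90]."""
    dot = sum(b * n for b, n in zip(bond_vec, normal))
    norm = (math.sqrt(sum(b * b for b in bond_vec))
            * math.sqrt(sum(n * n for n in normal)))
    # abs() folds antiparallel directions together: a plane normal has no preferred sign
    cos_t = min(1.0, abs(dot) / norm)
    to_normal = math.degrees(math.acos(cos_t))
    return to_normal, 90.0 - to_normal

# A bond lying in the (001) plane, and one pointing along its normal
in_plane = bond_plane_angles((1.0, 1.0, 0.0), (0, 0, 1))
along_normal = bond_plane_angles((0.0, 0.0, 2.0), (0, 0, 1))
```

Checking one known in-plane bond against the module's output is a quick way to establish which of the two conventions the returned tensor follows.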
Order Parameter Analysis
- geometry_utils.compute_global_steinhardt_order_parameters_batch(atom_to_com_vecs, atom_mask, atom_weights, device, l_values=[2, 4, 6, 8, 10], eps=1e-12)[source]
Compute global Steinhardt Q_l order parameters for a batch.
Global Steinhardt Q_l Order Parameters
Computes rotationally invariant order parameters that characterize local atomic environments and overall structural organization.
Mathematical Background:
Steinhardt order parameters are based on spherical harmonics expansion:
\[Q_l = \sqrt{\frac{4\pi}{2l+1} \sum_{m=-l}^{l} |q_{lm}|^2}\]
where \(q_{lm}\) are the averaged spherical harmonic coefficients.
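For uniform weights, Q_l can be evaluated without spherical harmonics at all: by the addition theorem, the sum over m reduces to Legendre polynomials of the angles between atom directions, giving Q_l² = (1/N²) Σ_ij P_l(r̂_i · r̂_j). The plain-Python sketch below uses that identity as an independent reference. It is not the module's implementation; the function names are hypothetical, and skipping atoms within eps of the origin is an assumption about how eps might be used.

```python
import math

def legendre(l, x):
    """Legendre polynomial P_l(x) via the three-term recurrence."""
    p_prev, p = 1.0, x
    if l == 0:
        return p_prev
    for n in range(1, l):
        p_prev, p = p, ((2 * n + 1) * x * p - n * p_prev) / (n + 1)
    return p

def global_steinhardt(vectors, l_values=(2, 4, 6, 8, 10), eps=1e-12):
    """Uniform-weight global Q_l for vectors measured from a common origin."""
    units = []
    for v in vectors:
        r = math.sqrt(sum(c * c for c in v))
        if r > eps:  # an atom at the origin has no defined direction
            units.append(tuple(c / r for c in v))
    n = len(units)
    if n == 0:
        return [0.0 for _ in l_values]
    result = []
    for l in l_values:
        total = 0.0
        for a in units:
            for b in units:
                cos_g = max(-1.0, min(1.0, sum(x * y for x, y in zip(a, b))))
                total += legendre(l, cos_g)
        # Q_l^2 = total / n^2; clamp tiny negative round-off before the root
        result.append(math.sqrt(max(total, 0.0)) / n)
    return result

# Six atoms at the vertices of a regular octahedron
octahedron = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
q2, q4, q6 = global_steinhardt(octahedron, l_values=(2, 4, 6))
```

For the octahedron this gives Q_2 = 0, Q_4 = √(7/12) and Q_6 = √(1/8), the familiar values for a simple-cubic coordination shell.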
Parameters:
atom_to_com_vecs (torch.Tensor, shape (B, N, 3)) - Vectors from center of mass
atom_mask (torch.BoolTensor, shape (B, N)) - Valid atom indicators
atom_weights (torch.Tensor or None, shape (B, N)) - Optional atomic weights
device (torch.device) - Computation device
l_values (List[int]) - Order parameter degrees (default [2, 4, 6, 8, 10])
eps (float) - Numerical stability parameter
Returns:
torch.Tensor, shape (B, len(l_values)) - Q_l values for each structure
Physical Interpretation:
Q_2 - Measures nematic ordering (molecular alignment)
Q_4 - Detects tetrahedral vs. other local symmetries
Q_6 - Sensitive to hexagonal/cubic ordering
Q_8, Q_10 - Higher-order structural correlations
Applications:
Phase transition characterization
Crystal quality assessment
Polymorphism detection
Disorder quantification
Coordinate Transformations
- geometry_utils.compute_atom_vectors_to_point_batch(atom_coords, atom_frac_coords, atom_mask, com_coords, com_frac_coords, device)[source]
Compute displacement vectors and Euclidean distances from each atom to a reference point.
Vector Computation to Reference Points
Calculates displacement vectors from atoms to specified reference points (centers of mass, centroids, etc.).
Parameters:
atom_coords (torch.Tensor, shape (B, N, 3)) - Atomic (Cartesian) coordinates
atom_frac_coords (torch.Tensor, shape (B, N, 3)) - Fractional coordinates
atom_mask (torch.BoolTensor, shape (B, N)) - Valid atom mask
com_coords (torch.Tensor, shape (B, 3)) - Reference point coordinates
com_frac_coords (torch.Tensor, shape (B, 3)) - Reference fractional coordinates
device (torch.device) - Computation device
Returns:
dists_cart (torch.Tensor, shape (B, N)) - Euclidean distances in Cartesian space
dists_frac (torch.Tensor, shape (B, N)) - Euclidean distances in fractional space
vecs_cart (torch.Tensor, shape (B, N, 3)) - Cartesian displacement vectors
vecs_frac (torch.Tensor, shape (B, N, 3)) - Fractional displacement vectors
- geometry_utils.compute_quaternions_from_rotation_matrices(R, device)[source]
Convert rotation matrices into unit quaternions [w, x, y, z] with w ≥ 0.
Rotation Matrix to Quaternion Conversion
Converts 3×3 rotation matrices to unit quaternions using robust numerical algorithms.
Parameters:
R (torch.Tensor, shape (B, 3, 3)) - Proper rotation matrices (R^T R = I, det = +1)
device (torch.device) - Computation device
Returns:
torch.Tensor, shape (B, 4) - Unit quaternions [w, x, y, z] with w ≥ 0
Algorithm Features:
Numerically stable - Handles near-singular cases
Consistent sign - Enforces the w ≥ 0 convention, since q and −q encode the same rotation
Batch optimized - Vectorized operations for efficiency
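One common way to obtain the numerical stability described above is to branch on the largest of the trace and the diagonal entries, so that the square root and the divisor never come close to zero. The single-matrix sketch below follows that standard scheme. It is an illustration under that assumption, not a statement of which algorithm the module uses, and the name `quaternion_from_matrix` is hypothetical.

```python
import math

def quaternion_from_matrix(R):
    """Convert a 3x3 proper rotation matrix (nested lists) to (w, x, y, z), w >= 0."""
    trace = R[0][0] + R[1][1] + R[2][2]
    if trace > 0.0:
        s = 2.0 * math.sqrt(1.0 + trace)                         # s = 4w
        q = (0.25 * s,
             (R[2][1] - R[1][2]) / s,
             (R[0][2] - R[2][0]) / s,
             (R[1][0] - R[0][1]) / s)
    elif R[0][0] >= R[1][1] and R[0][0] >= R[2][2]:
        s = 2.0 * math.sqrt(1.0 + R[0][0] - R[1][1] - R[2][2])   # s = 4x
        q = ((R[2][1] - R[1][2]) / s,
             0.25 * s,
             (R[0][1] + R[1][0]) / s,
             (R[0][2] + R[2][0]) / s)
    elif R[1][1] >= R[2][2]:
        s = 2.0 * math.sqrt(1.0 + R[1][1] - R[0][0] - R[2][2])   # s = 4y
        q = ((R[0][2] - R[2][0]) / s,
             (R[0][1] + R[1][0]) / s,
             0.25 * s,
             (R[1][2] + R[2][1]) / s)
    else:
        s = 2.0 * math.sqrt(1.0 + R[2][2] - R[0][0] - R[1][1])   # s = 4z
        q = ((R[1][0] - R[0][1]) / s,
             (R[0][2] + R[2][0]) / s,
             (R[1][2] + R[2][1]) / s,
             0.25 * s)
    # q and -q encode the same rotation; return the representative with w >= 0
    if q[0] < 0.0:
        q = tuple(-c for c in q)
    return q

# 90 degrees about z, and 180 degrees about x (trace = -1, so the second branch runs)
q_z90 = quaternion_from_matrix([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
q_x180 = quaternion_from_matrix([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
```

The 180° case is the near-singular one: a formula based on the trace alone would divide by zero there, which is why the branches exist.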
Applications:
Molecular orientation analysis
Crystal symmetry operations
Rigid body motion decomposition
Interpolation between orientations
Fragment Pairwise Analysis
- geometry_utils.compute_fragment_pairwise_vectors_and_distances_batch(coords, mask, heavy_mask, device)[source]
Compute pairwise displacement vectors and Euclidean distances between non-hydrogen atoms in each fragment.
Intra-Fragment Distance Matrix Computation
Calculates all pairwise distances within molecular fragments for shape analysis and internal geometry characterization.
Parameters:
coords (torch.Tensor, shape (F, A, 3)) - Cartesian atomic coordinates per fragment
mask (torch.BoolTensor, shape (F, A)) - Valid atom indicators
heavy_mask (torch.BoolTensor, shape (F, A)) - Heavy (non-H) atom indicators
device (torch.device) - Computation device
Returns:
distances (torch.Tensor, shape (F, A, A)) - Distance matrices for each fragment
vectors (torch.Tensor, shape (F, A, A, 3)) - Inter-atomic vectors within fragments
atom1_idx (torch.LongTensor, shape (P,)) - First-atom index of each unique heavy-atom pair (i<j)
atom2_idx (torch.LongTensor, shape (P,)) - Second-atom index of each unique heavy-atom pair (i<j)
Applications:
Molecular flexibility analysis
Conformational change detection
Internal coordinate validation
Shape descriptor computation
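As a rough illustration of what a "valid pair indicator" can mean here, the sketch below builds a same-fragment pair mask directly from atom_mask and fragment_atom_assignments. The helper name same_fragment_pair_mask, the simplified (B, N, N) output layout, and the exclusion of self-pairs are all assumptions made for this example; the module's actual output shapes are not stated above.

```python
import torch

def same_fragment_pair_mask(atom_mask, fragment_atom_assignments):
    """Hypothetical helper: True where atoms i != j are both valid and share a fragment."""
    n_atoms = atom_mask.shape[1]
    both_valid = atom_mask.unsqueeze(2) & atom_mask.unsqueeze(1)        # (B, N, N)
    same_fragment = (
        fragment_atom_assignments.unsqueeze(2)
        == fragment_atom_assignments.unsqueeze(1)
    )                                                                   # (B, N, N)
    # Exclude the diagonal so an atom is never paired with itself.
    not_self = ~torch.eye(n_atoms, dtype=torch.bool, device=atom_mask.device)
    return both_valid & same_fragment & not_self
```

Padding atoms never form a valid pair, even when their (arbitrary) fragment label happens to match a real atom's label.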
Performance Optimization
GPU Acceleration Guidelines
The geometry_utils module is optimized for GPU computation with PyTorch. Follow these best practices:
import torch
# Optimal device usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Batch size recommendations
if device.type == 'cuda':
    batch_size = 64  # Larger batches for GPU
else:
    batch_size = 16  # Smaller batches for CPU
# Memory management
if device.type == 'cuda':
    torch.cuda.empty_cache()  # Release cached blocks between large computations
# Data type optimization (coords is an existing coordinate tensor)
coords = coords.to(device, dtype=torch.float32)  # float32 is sufficient for most cases
Memory Usage Patterns
Linear scaling: Most functions scale O(B × N) with batch size and atom count
Quadratic scaling: Pairwise functions scale O(B × N²) in both time and memory, so reduce the batch size for structures with many atoms
GPU memory: Monitor usage with torch.cuda.memory_allocated()
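To make the quadratic term concrete, the following back-of-the-envelope sketch estimates the size of the two dense pairwise outputs, using the (F, A, A) and (F, A, A, 3) shapes documented for the fragment pairwise function. The helper name pairwise_memory_mib is hypothetical, and the estimate assumes float32 storage and ignores intermediate buffers.

```python
def pairwise_memory_mib(num_fragments, max_atoms, bytes_per_element=4):
    """Approximate size in MiB of dense pairwise distance and vector tensors."""
    distance_bytes = num_fragments * max_atoms * max_atoms * bytes_per_element
    vector_bytes = distance_bytes * 3          # three components per pair
    return (distance_bytes + vector_bytes) / 1024**2

# 64 fragments padded to 256 atoms: 16 MiB of distances plus 48 MiB of vectors.
print(pairwise_memory_mib(64, 256))    # 64.0
# Doubling the padded atom count quadruples the footprint.
print(pairwise_memory_mib(64, 512))    # 256.0
```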
Batch Size Optimization
def optimal_batch_size(total_structures, available_memory_gb):
    """Estimate a batch size from available GPU memory (rough heuristic)."""
    memory_per_structure_mb = 50  # Approximate for typical molecules
    structures_per_gb = 1000 / memory_per_structure_mb  # Treats 1 GB as 1000 MB
    max_batch_size = int(available_memory_gb * structures_per_gb * 0.8)  # Use 80% of memory, keeping 20% as a safety margin
    return min(max(max_batch_size, 1), total_structures)  # At least 1 when there is work to do
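Once a batch size has been chosen, the structures still have to be fed through in chunks. The generator below is one minimal way to do that; iter_batches is a hypothetical helper, not part of geometry_utils, and it only yields index ranges so that slicing the actual tensors stays with the caller.

```python
def iter_batches(total_structures, batch_size):
    """Yield (start, stop) index ranges covering all structures in order."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, total_structures, batch_size):
        yield start, min(start + batch_size, total_structures)

# The last range is shorter when the total is not a multiple of the batch size.
print(list(iter_batches(10, 4)))   # [(0, 4), (4, 8), (8, 10)]
```

Each range can then be used to slice the padded tensors, for example coords[start:stop] and mask[start:stop], before calling a batch function.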
Error Handling and Validation
Common Error Patterns
import torch
# Input validation
def validate_geometry_inputs(coords, mask):
    if coords.dim() != 3:
        raise ValueError(f"Expected 3D coordinates tensor, got {coords.dim()}D")
    if coords.shape[:2] != mask.shape:
        raise ValueError("Coordinate and mask shapes must match")
    if torch.any(torch.isnan(coords)):
        raise ValueError("NaN values detected in coordinates")
    if torch.any(torch.isinf(coords)):
        raise ValueError("Infinite values detected in coordinates")
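When a check like the one above fails, it is often useful to know which atoms are affected, and to ignore padding slots, whose contents may be arbitrary. The helper below is a hypothetical companion sketch, not part of the module; it assumes coords has shape (B, A, 3) and mask has shape (B, A).

```python
import torch

def find_invalid_atoms(coords, mask):
    """Return (batch, atom) index pairs of real atoms with NaN or infinite coordinates."""
    has_bad_value = ~torch.isfinite(coords).all(dim=-1)   # (B, A)
    return torch.nonzero(has_bad_value & mask)            # (K, 2)
```

An empty result (shape (0, 2)) means every real atom has finite coordinates, even if padding slots do not.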
Debugging Tools
# Diagnostic functions
def check_tensor_health(tensor, name):
    print(f"{name}: shape={tensor.shape}, dtype={tensor.dtype}")
    print(f" Range: [{tensor.min():.3f}, {tensor.max():.3f}]")
    print(f" NaN count: {torch.isnan(tensor).sum()}")
    print(f" Inf count: {torch.isinf(tensor).sum()}")
Integration Examples
Comprehensive Molecular Analysis Pipeline
import torch
from geometry_utils import *
def analyze_molecular_geometry(structures_batch):
    """Complete geometric analysis of molecular structures."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    results = {}
    # 1. Bond angle analysis
    angle_ids, angles, angle_mask, angle_indices = compute_bond_angles_batch(
        structures_batch['atom_labels'],
        structures_batch['coords'],
        structures_batch['atom_mask'],
        structures_batch['bond_atom1'],
        structures_batch['bond_atom2'],
        structures_batch['bond_mask'],
        device
    )
    results['bond_angles'] = angles
    # 2. Torsion angle analysis
    torsion_ids, torsions, torsion_mask, torsion_indices = compute_torsion_angles_batch(
        structures_batch['atom_labels'],
        structures_batch['coords'],
        structures_batch['atom_mask'],
        structures_batch['bond_atom1'],
        structures_batch['bond_atom2'],
        structures_batch['bond_mask'],
        device
    )
    results['torsion_angles'] = torsions
    # 3. Planarity analysis for aromatic rings
    ring_planarity = compute_planarity_metrics_batch(
        structures_batch['ring_coords'],
        structures_batch['ring_mask'],
        device
    )
    results['planarity'] = ring_planarity
    # 4. Order parameter analysis
    order_params = compute_global_steinhardt_order_parameters_batch(
        structures_batch['atom_to_com_vectors'],
        structures_batch['atom_mask'],
        structures_batch['atom_weights'],
        device
    )
    results['order_parameters'] = order_params
    return results
Crystallographic Structure Analysis
def analyze_crystal_packing(crystal_batch):
    """Analyze crystallographic packing features."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Fractional distances to the 26 special planes, shape (B, A, 26)
    plane_distances = compute_distances_to_crystallographic_planes_frac_batch(
        crystal_batch['frac_coords'],
        crystal_batch['atom_mask'],
        device
    )
    # Bond orientations relative to the 13 crystallographic plane normals, shape (B, M, 13)
    bond_plane_angles = compute_angles_between_bonds_and_crystallographic_planes_frac_batch(
        crystal_batch['frac_coords'],
        crystal_batch['bond_atom1'],
        crystal_batch['bond_atom2'],
        crystal_batch['bond_mask'],
        device
    )
    return {
        'plane_distances': plane_distances,
        'bond_orientations': bond_plane_angles
    }
Cross-References
Related CSA Modules:
fragment_utils module - Fragment identification and properties
contact_utils module - Intermolecular contact analysis
cell_utils module - Unit cell transformations
structure_post_extraction_processor module - Main processing pipeline
data_reader module - Input data handling
External Dependencies:
PyTorch - Tensor operations and GPU acceleration
NumPy - Array operations and mathematical functions
SciPy - Advanced mathematical algorithms
Mathematical References:
Steinhardt, P. J. et al. “Bond-orientational order in liquids and glasses” Physical Review B 28, 784 (1983)
Allen, M. P. & Tildesley, D. J. “Computer Simulation of Liquids” Oxford University Press (2017)
Giacovazzo, C. et al. “Fundamentals of Crystallography” Oxford University Press (2011)