structure_post_extraction_processor module
Module: structure_post_extraction_processor.py
Post-extraction processing for CSD structures.
Reads raw HDF5 outputs (from StructureDataExtractor), computes derived features (geometric, topological, contact-based) in GPU-accelerated batches, and writes both raw and computed datasets into a new “*_processed.h5” container. Designed for high throughput and minimal I/O overhead.
Dependencies
h5py, numpy, torch, data_reader, data_writer, dataset_initializer, dimension_scanner, cell_utils, contact_utils, fragment_utils, geometry_utils, symmetry_utils
- class structure_post_extraction_processor.CrystalParams(cell_lengths, cell_angles)[source]
Bases: NamedTuple
Container for basic crystal parameters.
- cell_lengths
Unit-cell lengths [a, b, c] for each structure in the batch.
- Type:
torch.Tensor, shape (B, 3)
- cell_angles
Unit-cell angles [α, β, γ] in degrees for each structure in the batch.
- Type:
torch.Tensor, shape (B, 3)
- cell_lengths: torch.Tensor
Alias for field number 0
- cell_angles: torch.Tensor
Alias for field number 1
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
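Because CrystalParams is a NamedTuple, each field can be read by name or by position, which is what the "Alias for field number" entries describe. A minimal sketch, using a hypothetical stand-in class CrystalParamsSketch with plain lists in place of torch tensors so that it runs without torch:

```python
from typing import List, NamedTuple


class CrystalParamsSketch(NamedTuple):
    """Hypothetical stand-in for CrystalParams; lists replace (B, 3) tensors."""
    cell_lengths: List[List[float]]  # [a, b, c] per structure
    cell_angles: List[List[float]]   # [alpha, beta, gamma] in degrees


params = CrystalParamsSketch(
    cell_lengths=[[5.0, 6.0, 7.0], [10.0, 10.0, 10.0]],
    cell_angles=[[90.0, 90.0, 90.0], [90.0, 120.0, 90.0]],
)

# Field 0 and the name cell_lengths refer to the same object.
same_object = params[0] is params.cell_lengths

# Tuple unpacking follows the declared field order.
lengths, angles = params

# index() returns the position of the first element equal to the given value.
angle_position = params.index(params.cell_angles)
```

The same access patterns apply to the other parameter containers, since they share the NamedTuple base.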
- class structure_post_extraction_processor.AtomParams(labels, symbols, coords, frac_coords, mask, weights, charges)[source]
Bases: NamedTuple
Container for atomic-level parameters.
- coords
Cartesian coordinates of each atom.
- Type:
torch.Tensor, shape (B, max_atoms, 3)
- frac_coords
Fractional coordinates of each atom.
- Type:
torch.Tensor, shape (B, max_atoms, 3)
- mask
Boolean mask indicating valid atom entries.
- Type:
torch.BoolTensor, shape (B, max_atoms)
- weights
Atomic weights.
- Type:
torch.Tensor, shape (B, max_atoms)
- charges
Partial charges per atom.
- Type:
torch.Tensor, shape (B, max_atoms)
- coords: torch.Tensor
Alias for field number 2
- frac_coords: torch.Tensor
Alias for field number 3
- mask: torch.BoolTensor
Alias for field number 4
- weights: torch.Tensor
Alias for field number 5
- charges: torch.Tensor
Alias for field number 6
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
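The (B, max_atoms) shapes imply that structures with fewer atoms are padded to a common width and that mask marks the real entries. A sketch of that layout with plain lists; the helper pad_with_mask and the fill value 0.0 are assumptions made for illustration and are not part of the module:

```python
from typing import List, Tuple


def pad_with_mask(
    rows: List[List[float]], fill: float = 0.0
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Pad variable-length rows to a common width and return a validity mask."""
    width = max((len(row) for row in rows), default=0)
    padded = [row + [fill] * (width - len(row)) for row in rows]
    mask = [[i < len(row) for i in range(width)] for row in rows]
    return padded, mask


# Atomic weights for two structures with 3 atoms and 1 atom respectively.
weights, mask = pad_with_mask([[12.011, 1.008, 15.999], [1.008]])
```

Any reduction over the atom axis then has to consult mask so that padded slots do not contribute.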
- class structure_post_extraction_processor.BondParams(atom1_idx, atom2_idx, bond_type, is_rotatable_raw, is_cyclic, mask)[source]
Bases: NamedTuple
Container for bond-level parameters.
- atom1_idx
Index of first atom in each bond.
- Type:
torch.LongTensor, shape (B, max_bonds)
- atom2_idx
Index of second atom in each bond.
- Type:
torch.LongTensor, shape (B, max_bonds)
- bond_type
Numeric or categorical encoding of bond types.
- Type:
torch.Tensor, shape (B, max_bonds)
- is_rotatable_raw
Initial mask for bond rotatability.
- Type:
torch.BoolTensor, shape (B, max_bonds)
- is_cyclic
Indicates if bond is part of a ring.
- Type:
torch.BoolTensor, shape (B, max_bonds)
- mask
Boolean mask indicating valid bond entries.
- Type:
torch.BoolTensor, shape (B, max_bonds)
- atom1_idx: torch.LongTensor
Alias for field number 0
- atom2_idx: torch.LongTensor
Alias for field number 1
- bond_type: torch.Tensor
Alias for field number 2
- is_rotatable_raw: torch.BoolTensor
Alias for field number 3
- is_cyclic: torch.BoolTensor
Alias for field number 4
- mask: torch.BoolTensor
Alias for field number 5
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- class structure_post_extraction_processor.InterCCParams(central_atom, contact_atom, central_atom_idx, contact_atom_idx, central_atom_frac_coords, contact_atom_frac_coords, lengths, strengths, in_los, symmetry_A, symmetry_T, symmetry_A_inv, symmetry_T_inv)[source]
Bases: NamedTuple
Container for intermolecular close-contact parameters.
- central_atom_idx
Indices of central atoms.
- Type:
torch.LongTensor, shape (B, C)
- contact_atom_idx
Indices of contact atoms.
- Type:
torch.LongTensor, shape (B, C)
- central_atom_frac_coords
Fractional coords of central atoms.
- Type:
torch.Tensor, shape (B, C, 3)
- contact_atom_frac_coords
Fractional coords of contact atoms.
- Type:
torch.Tensor, shape (B, C, 3)
- lengths
Contact distances.
- Type:
torch.Tensor, shape (B, C)
- strengths
Contact strength metrics.
- Type:
torch.Tensor, shape (B, C)
- in_los
Mask for line-of-sight contacts.
- Type:
torch.Tensor, shape (B, C)
- symmetry_A
Symmetry operation rotation matrices.
- Type:
torch.Tensor, shape (B, C, 3, 3)
- symmetry_T
Symmetry operation translation vectors.
- Type:
torch.Tensor, shape (B, C, 3)
- symmetry_A_inv
Inverse rotation matrices.
- Type:
torch.Tensor, shape (B, C, 3, 3)
- symmetry_T_inv
Inverse translation vectors.
- Type:
torch.Tensor, shape (B, C, 3)
- central_atom_idx: torch.LongTensor
Alias for field number 2
- contact_atom_idx: torch.LongTensor
Alias for field number 3
- central_atom_frac_coords: torch.Tensor
Alias for field number 4
- contact_atom_frac_coords: torch.Tensor
Alias for field number 5
- lengths: torch.Tensor
Alias for field number 6
- strengths: torch.Tensor
Alias for field number 7
- in_los: torch.Tensor
Alias for field number 8
- symmetry_A: torch.Tensor
Alias for field number 9
- symmetry_T: torch.Tensor
Alias for field number 10
- symmetry_A_inv: torch.Tensor
Alias for field number 11
- symmetry_T_inv: torch.Tensor
Alias for field number 12
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- class structure_post_extraction_processor.InterHBParams(central_atom, hydrogen_atom, contact_atom, central_atom_idx, hydrogen_atom_idx, contact_atom_idx, central_atom_frac_coords, hydrogen_atom_frac_coords, contact_atom_frac_coords, lengths, angles, in_los, symmetry_A, symmetry_T, symmetry_A_inv, symmetry_T_inv)[source]
Bases: NamedTuple
Container for intermolecular hydrogen-bond parameters.
- central_atom_idx
Indices of donor atoms.
- Type:
torch.LongTensor, shape (B, H)
- hydrogen_atom_idx
Indices of hydrogen atoms.
- Type:
torch.LongTensor, shape (B, H)
- contact_atom_idx
Indices of acceptor atoms.
- Type:
torch.LongTensor, shape (B, H)
- central_atom_frac_coords
Fractional coords of donor atoms.
- Type:
torch.Tensor, shape (B, H, 3)
- hydrogen_atom_frac_coords
Fractional coords of hydrogen atoms.
- Type:
torch.Tensor, shape (B, H, 3)
- contact_atom_frac_coords
Fractional coords of acceptor atoms.
- Type:
torch.Tensor, shape (B, H, 3)
- lengths
H-bond distances.
- Type:
torch.Tensor, shape (B, H)
- angles
H-bond angles.
- Type:
torch.Tensor, shape (B, H)
- in_los
Mask for line-of-sight H-bonds.
- Type:
torch.Tensor, shape (B, H)
- symmetry_A
Symmetry rotation matrices.
- Type:
torch.Tensor, shape (B, H, 3, 3)
- symmetry_T
Symmetry translation vectors.
- Type:
torch.Tensor, shape (B, H, 3)
- symmetry_A_inv
Inverse rotation matrices.
- Type:
torch.Tensor, shape (B, H, 3, 3)
- symmetry_T_inv
Inverse translation vectors.
- Type:
torch.Tensor, shape (B, H, 3)
- central_atom_idx: torch.LongTensor
Alias for field number 3
- hydrogen_atom_idx: torch.LongTensor
Alias for field number 4
- contact_atom_idx: torch.LongTensor
Alias for field number 5
- central_atom_frac_coords: torch.Tensor
Alias for field number 6
- hydrogen_atom_frac_coords: torch.Tensor
Alias for field number 7
- contact_atom_frac_coords: torch.Tensor
Alias for field number 8
- lengths: torch.Tensor
Alias for field number 9
- angles: torch.Tensor
Alias for field number 10
- in_los: torch.Tensor
Alias for field number 11
- symmetry_A: torch.Tensor
Alias for field number 12
- symmetry_T: torch.Tensor
Alias for field number 13
- symmetry_A_inv: torch.Tensor
Alias for field number 14
- symmetry_T_inv: torch.Tensor
Alias for field number 15
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- class structure_post_extraction_processor.StructurePostExtractionProcessor(hdf5_path, batch_size, device=None)[source]
Bases: object
Orchestrates post-extraction computation of derived structure features.
Reads raw data from an HDF5 file, processes structures in GPU-accelerated batches to compute geometric, topological, and contact-based features, and writes both raw and computed data to a new processed HDF5 file.
- __init__(hdf5_path, batch_size, device=None)[source]
Initialize the processor.
- Parameters:
hdf5_path (Path) – Path to the raw HDF5 file containing extracted structure data.
batch_size (int) – Number of structures to process per GPU batch.
device (str or torch.device, optional) – Device specifier (e.g., ‘cuda’, ‘cpu’); if None, selects CUDA if available.
Advanced Feature Engineering Pipeline
The structure_post_extraction_processor module implements GPU-accelerated computation of advanced structural descriptors and features from raw crystal structure data. This module forms the core of Stage 5 in the CSA pipeline.
Data Structure Classes
CrystalParams
Crystal-Level Parameter Container
NamedTuple container holding crystal properties and unit cell information for batch processing.
- Attributes:
identifiers (List[str]) - CSD refcodes for structures
space_groups (List[str]) - Space group symbols
z_values (torch.LongTensor) - Z values per structure
z_prime (torch.Tensor) - Z’ values per structure
cell_volumes (torch.Tensor) - Unit cell volumes (Å³)
cell_densities (torch.Tensor) - Crystal densities (g/cm³)
cell_lengths (torch.Tensor) - Unit cell parameters a, b, c
cell_angles (torch.Tensor) - Unit cell angles α, β, γ
AtomParams
Atomic-Level Parameter Container
NamedTuple container holding atomic coordinates, properties, and connectivity information.
- Attributes:
labels (List[List[str]]) - Atomic labels per structure
coords (torch.Tensor) - Cartesian coordinates (Å)
frac_coords (torch.Tensor) - Fractional coordinates
weights (torch.Tensor) - Atomic masses (Da)
numbers (torch.LongTensor) - Atomic numbers
charges (torch.Tensor) - Formal charges
symbols (List[List[str]]) - Element symbols
sybyl_types (List[List[str]]) - SYBYL atom types
mask (torch.BoolTensor) - Validity mask for atoms
BondParams
Bond Connectivity Parameter Container
NamedTuple container holding molecular bond information and connectivity graphs.
- Attributes:
atom1_labels (List[List[str]]) - First bond partner labels
atom2_labels (List[List[str]]) - Second bond partner labels
atom1_idx (torch.LongTensor) - First partner indices
atom2_idx (torch.LongTensor) - Second partner indices
orders (torch.Tensor) - Bond order values
mask (torch.BoolTensor) - Validity mask for bonds
ContactParams
HBondParams
StructurePostExtractionProcessor Class
GPU-Accelerated Feature Engineering Pipeline
Orchestrates the computation of advanced structural descriptors and features from raw crystal structure data using GPU acceleration for maximum performance.
Core Capabilities:
Rigid Fragment Analysis - Identifies molecular fragments and rigid groups
Geometric Descriptors - Computes shape descriptors, moment tensors, orientations
Contact Analysis - Analyzes intermolecular interactions and networks
Symmetry Operations - Applies crystallographic symmetry transformations
Feature Engineering - Derives machine learning-ready descriptors
GPU Acceleration - Leverages CUDA for batch tensor operations
Processing Pipeline:
Data Loading - Read raw data from HDF5 into GPU tensors
Fragment Identification - Detect rigid molecular fragments
Geometric Analysis - Compute centers of mass, inertia tensors, orientations
Contact Expansion - Apply symmetry operations to intermolecular contacts
Feature Computation - Calculate advanced descriptors and properties
Data Writing - Save computed features to processed HDF5 file
- Attributes:
hdf5_path (Path) - Input HDF5 file with raw data
batch_size (int) - Structures processed per GPU batch
device (torch.device) - GPU device for tensor operations
raw_reader (DataReader) - Raw data loading interface
raw_writer (DataWriter) - Raw data writing interface
computed_writer (DataWriter) - Computed data writing interface
Initialize Post-Extraction Processor
- Parameters:
hdf5_path (Path) - Path to raw HDF5 file
batch_size (int) - Number of structures per GPU batch
device (Optional[Union[str, torch.device]]) - GPU device specification
Device Configuration:
# Automatic device selection
if device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Manual device specification
device = torch.device("cuda:0")  # Specific GPU
device = torch.device("cpu")  # Force CPU processing
Memory Management:
# Configure GPU memory settings
if device.type == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.8)
File Initialization:
# Set up data readers and writers
self.raw_reader = DataReader(hdf5_path)
# Create processed data file; the name is built from the stem because
# Path.with_suffix() rejects a suffix that does not start with "."
processed_path = hdf5_path.with_name(hdf5_path.stem + "_processed.h5")
self.raw_writer = DataWriter(processed_path, "raw")
self.computed_writer = DataWriter(processed_path, "computed")
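One detail to watch when deriving the output name: pathlib's with_suffix() raises ValueError for a suffix that does not begin with a dot, so the "*_processed.h5" name is easier to build from the stem. A small sketch, with processed_path_for as a hypothetical helper name:

```python
from pathlib import Path


def processed_path_for(raw_path: Path) -> Path:
    """Return '<stem>_processed.h5' in the same directory as the raw file."""
    return raw_path.with_name(raw_path.stem + "_processed.h5")


raw = Path("analysis.h5")
out = processed_path_for(raw)

# with_suffix() only accepts suffixes that start with a dot.
try:
    raw.with_suffix("_processed.h5")
    with_suffix_rejected = False
except ValueError:
    with_suffix_rejected = True
```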
- run()[source]
Execute the full post-extraction processing pipeline.
Removes any existing processed file, reads raw data, initializes output datasets, processes structures in batches, and writes both raw and computed data to the output HDF5 file.
Execute Complete Feature Engineering Pipeline
Orchestrates the full post-extraction processing workflow from raw data loading to computed feature storage.
Processing Workflow:
Initialization - Set up output files and data structures
Batch Processing - Process structures in GPU-optimized batches
Feature Computation - Apply all feature engineering algorithms
Data Writing - Store results in structured HDF5 format
Validation - Verify data integrity and completeness
Batch Processing Loop:
for batch_start in range(0, total_structures, self.batch_size):
    batch_end = min(batch_start + self.batch_size, total_structures)
    # Load raw data batch
    raw_data = self.raw_reader.load_batch(batch_start, batch_end)
    # Process on the selected device; torch.cuda.device() only accepts CUDA devices
    if self.device.type == "cuda":
        with torch.cuda.device(self.device):
            computed_data = self._process_batch(raw_data)
    else:
        computed_data = self._process_batch(raw_data)
    # Write results
    self._write_batch_results(batch_start, raw_data, computed_data)
    # Memory cleanup
    if self.device.type == "cuda":
        torch.cuda.empty_cache()
Progress Monitoring:
INFO - Starting post-extraction processing...
INFO - Found 8234 structures to process
INFO - Processing batch 1/258 (32 structures)
INFO - Computed 1024 rigid fragments
INFO - Processed 15678 intermolecular contacts
INFO - Processing batch 2/258 (32 structures)
...
INFO - Post-extraction processing complete
INFO - Processed file saved: analysis_processed.h5
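The batch count in the log follows from the loop bounds: 8234 structures in batches of 32 give 258 batches, the last of which is partial. A stdlib sketch of the index arithmetic (batch_ranges is a hypothetical helper, not a module function):

```python
import math
from typing import Iterator, Tuple


def batch_ranges(total: int, batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) index pairs covering range(total) in order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, total, batch_size):
        yield start, min(start + batch_size, total)


ranges = list(batch_ranges(8234, 32))
n_batches = len(ranges)            # same value as math.ceil(8234 / 32)
last_start, last_end = ranges[-1]  # the final batch is shorter than 32
```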
GPU Memory Management:
import logging
import torch
def monitor_gpu_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        cached = torch.cuda.memory_reserved() / 1024**3
        logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB")
Error Handling:
GPU OOM - Automatically reduces batch size and retries
Data Corruption - Skips problematic structures with logging
Device Failures - Falls back to CPU processing if needed
File I/O Issues - Implements robust error recovery
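The "reduce batch size and retry" behaviour listed for GPU OOM can be pictured as a back-off loop. The sketch below is an assumption about the general shape only: the halving policy, the helper name run_with_backoff, and the use of MemoryError as a stand-in for a CUDA out-of-memory error raised by PyTorch are all illustrative choices, not the module's confirmed implementation.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def run_with_backoff(
    process: Callable[[int], T],
    batch_size: int,
    min_batch_size: int = 1,
) -> T:
    """Call process(batch_size), halving batch_size after each MemoryError."""
    size = batch_size
    while True:
        try:
            return process(size)
        except MemoryError:
            if size <= min_batch_size:
                raise  # nothing smaller left to try
            size = max(min_batch_size, size // 2)


def fake_process(size: int) -> int:
    """Stand-in workload that only fits in memory for batches of 8 or fewer."""
    if size > 8:
        raise MemoryError(f"batch of {size} does not fit")
    return size


used_size = run_with_backoff(fake_process, batch_size=32)
```

Re-raising once the minimum size is reached keeps a genuinely oversized structure from looping forever.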
- _process_batch(start, batch, dims, h5_in, h5_out)[source]
Process and write raw plus computed data for a batch of structures.
- Parameters:
start (int) – Index offset in the global refcode list corresponding to this batch.
batch (List[str]) – Refcode identifiers for this batch.
dims (Dict[str, int]) – Maximum-dimension dict from scan_max_dimensions.
h5_in (h5py.File) – Open HDF5 file handle for raw input data.
h5_out (h5py.File) – Open HDF5 file handle for processed output data.
- Return type:
None
Process Single Batch with GPU Acceleration
Core batch processing function that applies all feature engineering algorithms to a batch of structures.
- Parameters:
start (int) - Starting structure index in batch
crystal (CrystalParams) - Crystal properties
atom (AtomParams) - Atomic data
bond (BondParams) - Bond connectivity
intra_cc (ContactParams) - Intramolecular contacts
intra_hb (HBondParams) - Intramolecular H-bonds
inter_cc (ContactParams) - Intermolecular contacts
inter_hb (HBondParams) - Intermolecular H-bonds
Feature Engineering Pipeline:
# 1. Compute unit cell transformation matrices
cell_matrices = self._compute_cell_matrices(
    crystal.cell_lengths, crystal.cell_angles
)
# 2. Identify rotatable bonds
rotatable_bonds = self._compute_rotatable_bonds(
    atom.symbols, bond.atom1_idx, bond.atom2_idx, bond.orders
)
# 3. Expand intermolecular contacts with symmetry
expanded_contacts = self._expand_inter_contacts(
    inter_cc, cell_matrices
)
# 4. Identify rigid molecular fragments
fragment_ids = self._compute_rigid_fragments(
    atom.mask, bond.atom1_idx, bond.atom2_idx, rotatable_bonds
)
# 5. Compute fragment properties
fragment_properties = self._compute_fragment_properties(
    fragment_ids, atom.coords, atom.weights, atom.charges
)
# 6. Map contacts to fragments
contact_fragments = self._identify_contact_fragments(
    expanded_contacts, fragment_ids
)
Fragment Analysis Methods
- StructurePostExtractionProcessor._compute_rigid_fragments(atom_mask, atom1_idx, atom2_idx, bond_is_rotatable)[source]
Identify rigid fragments by grouping non-rotatable bonds.
- Parameters:
atom_mask (torch.BoolTensor, shape (B, N)) – Mask indicating valid atoms.
atom1_idx (torch.LongTensor, shape (B, M)) – First-atom indices per bond.
atom2_idx (torch.LongTensor, shape (B, M)) – Second-atom indices per bond.
bond_is_rotatable (torch.BoolTensor, shape (B, M)) – Mask for bonds that are rotatable.
- Returns:
Fragment ID assigned to each atom.
- Return type:
torch.LongTensor, shape (B, N)
Identify Rigid Molecular Fragments
Analyzes molecular connectivity to identify rigid groups of atoms that move as single units.
- Parameters:
atom_mask (torch.BoolTensor) - Valid atom flags
atom1_idx (torch.LongTensor) - First bond partner indices
atom2_idx (torch.LongTensor) - Second bond partner indices
bond_is_rotatable (torch.BoolTensor) - Rotatable bond flags
Algorithm:
Graph Construction - Build molecular connectivity graph
Bond Classification - Identify rotatable vs. rigid bonds
Component Analysis - Find connected components after removing rotatable bonds
Fragment Assignment - Assign unique IDs to each rigid fragment
Fragment Types Identified:
Aromatic Rings - Benzene, pyridine, etc.
Aliphatic Rings - Cyclohexane, cyclopentane, etc.
Rigid Chains - Double/triple bonded segments
Individual Atoms - Isolated atoms or small groups
- Returns:
torch.LongTensor with fragment IDs for each atom
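The component-analysis step can be sketched for a single structure with a union-find over the non-rotatable bonds. This is a plain-Python illustration under stated assumptions: rigid_fragments is a hypothetical helper, it ignores the batch dimension and atom mask, and the module's GPU implementation may work differently.

```python
from typing import Dict, List, Sequence, Tuple


def rigid_fragments(
    n_atoms: int,
    bonds: Sequence[Tuple[int, int]],
    rotatable: Sequence[bool],
) -> List[int]:
    """Label atoms by connected component after dropping rotatable bonds."""
    parent = list(range(n_atoms))

    def find(i: int) -> int:
        # Follow parent links to the root, shortening the path on the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for (a, b), is_rotatable in zip(bonds, rotatable):
        if not is_rotatable:
            root_a, root_b = find(a), find(b)
            if root_a != root_b:
                parent[root_b] = root_a  # merge the two rigid groups

    # Renumber roots to consecutive fragment IDs in order of first appearance.
    ids: Dict[int, int] = {}
    return [ids.setdefault(find(i), len(ids)) for i in range(n_atoms)]


# Four-atom chain 0-1-2-3 whose middle bond (1-2) is rotatable.
fragment_ids = rigid_fragments(4, [(0, 1), (1, 2), (2, 3)], [False, True, False])
```

Cutting the single rotatable bond splits the chain into two rigid fragments, while a ring with no rotatable bonds stays in one piece.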
Geometric Computation Methods
- StructurePostExtractionProcessor._compute_fragment_com_centroid(frag_coords, frag_frac_coords, frag_weight, frag_mask)[source]
Compute each fragment’s center of mass and centroid.
- Parameters:
frag_coords (torch.Tensor, shape (B, F, Nf, 3)) – Cartesian fragment atom coordinates.
frag_frac_coords (torch.Tensor, shape (B, F, Nf, 3)) – Fractional fragment atom coordinates.
frag_weight (torch.Tensor, shape (B, F, Nf)) – Atomic weights per fragment.
frag_mask (torch.BoolTensor, shape (B, F, Nf)) – Validity mask for fragment atoms.
- Returns:
{
'fragment_com_coords': torch.Tensor,
'fragment_com_frac_coords': torch.Tensor,
'fragment_centroid_coords': torch.Tensor,
'fragment_centroid_frac_coords': torch.Tensor
}
- Return type:
Dict[str, torch.Tensor]
Compute Fragment Centers of Mass and Centroids
Calculates both mass-weighted and geometric centers for fragments.
- Parameters:
fragment_coords (torch.Tensor) - Atomic coordinates per fragment
fragment_frac_coords (torch.Tensor) - Fractional coordinates
fragment_weights (torch.Tensor) - Atomic masses
fragment_mask (torch.BoolTensor) - Valid atom flags
Center of Mass Calculation:
# Mass-weighted center calculation
total_mass = torch.sum(fragment_weights * fragment_mask, dim=-1)
weighted_coords = fragment_coords * fragment_weights.unsqueeze(-1)
# Clamp the denominator so fully padded fragments give zeros instead of NaN
com_coords = torch.sum(
    weighted_coords * fragment_mask.unsqueeze(-1), dim=-2
) / torch.clamp(total_mass, min=1e-12).unsqueeze(-1)
Geometric Centroid:
# Unweighted geometric center
n_atoms = torch.sum(fragment_mask, dim=-1)
# Clamp the count so fully padded fragments do not divide by zero
centroid_coords = torch.sum(
    fragment_coords * fragment_mask.unsqueeze(-1), dim=-2
) / torch.clamp(n_atoms, min=1).unsqueeze(-1)
- Returns:
Dictionary with computed centers in both Cartesian and fractional coordinates
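A worked example makes the difference between the two centers concrete. The sketch below handles one fragment with plain Python lists; com_and_centroid is a hypothetical helper, not a module function.

```python
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def com_and_centroid(
    coords: Sequence[Vec3], weights: Sequence[float], mask: Sequence[bool]
) -> Tuple[List[float], List[float]]:
    """Return (center of mass, centroid) over the atoms where mask is True."""
    valid = [(xyz, w) for xyz, w, keep in zip(coords, weights, mask) if keep]
    if not valid:
        raise ValueError("fragment has no valid atoms")
    total_mass = sum(w for _, w in valid)
    com = [sum(xyz[k] * w for xyz, w in valid) / total_mass for k in range(3)]
    centroid = [sum(xyz[k] for xyz, _ in valid) / len(valid) for k in range(3)]
    return com, centroid


# Two real atoms (masses 1 and 3) plus one padded slot that must be ignored.
com, centroid = com_and_centroid(
    coords=[(0.0, 0.0, 0.0), (4.0, 0.0, 0.0), (99.0, 99.0, 99.0)],
    weights=[1.0, 3.0, 0.0],
    mask=[True, True, False],
)
```

The center of mass is pulled toward the heavier atom, the centroid sits midway between the two valid atoms, and the padded slot affects neither.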
Contact Analysis Methods
- StructurePostExtractionProcessor._expand_inter_contacts(inter_cc, cell_matrix)[source]
Expand all close contacts to include symmetry-equivalent images.
- Parameters:
inter_cc (InterCCParams) – Intermolecular contact NamedTuple.
cell_matrix (torch.Tensor, shape (B, 3, 3)) – Real-space cell matrices.
- Returns:
Expanded contact data with keys like ‘inter_cc_central_atom_coords’, ‘inter_cc_length’, etc.
- Return type:
Dict[str, torch.Tensor]
Expand Intermolecular Contacts with Symmetry Operations
Applies crystallographic symmetry operations to generate the complete intermolecular contact network.
- Parameters:
contacts (ContactParams) - Raw intermolecular contacts
cell_matrices (torch.Tensor) - Unit cell transformation matrices
Symmetry Expansion Process:
# Parse symmetry operators from contact data
symmetry_ops = parse_symmetry_operators(contact_symmetry_strings)
# Apply rotation matrices
rotation_matrices = symmetry_ops['rotation']  # Shape: (B, C, 3, 3)
translation_vectors = symmetry_ops['translation']  # Shape: (B, C, 3)
# Transform contact atom coordinates
contact_coords_transformed = torch.matmul(
    rotation_matrices, contact_coords.unsqueeze(-1)
).squeeze(-1) + translation_vectors
# Convert to Cartesian coordinates
contact_coords_cartesian = torch.matmul(
    contact_coords_transformed, cell_matrices
)
Distance Recalculation:
# Compute actual intermolecular distances
distance_vectors = contact_coords_cartesian - central_coords
distances = torch.norm(distance_vectors, dim=-1)
# Verify contact validity (within cutoff distance)
valid_contacts = distances < contact_cutoff_distance
Contact Classification:
van der Waals Contacts - Within vdW radii sum + tolerance
Close Contacts - Shorter than vdW sum (potential strain)
Hydrogen Bonds - Specific geometric criteria
π-π Interactions - Aromatic ring stacking
Electrostatic Contacts - Charge-charge interactions
- Returns:
Dictionary with expanded contact coordinates, distances, and classifications
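The expansion can be followed end to end on a single contact: build the cell matrix from lengths and angles, apply a symmetry operation in fractional space, convert to Cartesian, and measure the distance. The sketch uses only the standard library; the helper names, the row-vector convention (cart = frac @ cell) and the choice of placing lattice vector a along x are assumptions for illustration, not confirmed details of the module.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def cell_matrix(lengths: Sequence[float], angles_deg: Sequence[float]) -> Matrix:
    """Rows are the lattice vectors a, b, c in Cartesian coordinates."""
    a, b, c = lengths
    alpha, beta, gamma = (math.radians(x) for x in angles_deg)
    cos_a, cos_b, cos_g = math.cos(alpha), math.cos(beta), math.cos(gamma)
    sin_g = math.sin(gamma)
    cy = (cos_a - cos_b * cos_g) / sin_g
    cz = math.sqrt(max(0.0, 1.0 - cos_b**2 - cy**2))
    return [
        [a, 0.0, 0.0],
        [b * cos_g, b * sin_g, 0.0],
        [c * cos_b, c * cy, c * cz],
    ]


def apply_symmetry(rot: Matrix, trans: Sequence[float], frac: Sequence[float]) -> List[float]:
    """Return rot @ frac + trans for one fractional coordinate."""
    return [sum(rot[i][j] * frac[j] for j in range(3)) + trans[i] for i in range(3)]


def frac_to_cart(frac: Sequence[float], cell: Matrix) -> List[float]:
    """Row-vector convention: cart = frac @ cell."""
    return [sum(frac[i] * cell[i][k] for i in range(3)) for k in range(3)]


cell = cell_matrix([10.0, 10.0, 10.0], [90.0, 90.0, 90.0])
central = [0.1, 0.2, 0.3]
# Inversion through the origin followed by a shift of one cell along a.
image = apply_symmetry([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], [1.0, 0.0, 0.0], central)
distance = math.dist(frac_to_cart(image, cell), frac_to_cart(central, cell))
```

For this cubic cell the image lies at fractional (0.9, -0.2, -0.3), so the Cartesian separation from the central atom is the length of the vector (8, -4, -6) Å.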
- StructurePostExtractionProcessor._flag_hbond_contacts(cc_central_idx, cc_contact_idx, cc_mask, hb_central_idx, hb_hydrogen_idx, hb_contact_idx, hb_mask)[source]
Flag which close contacts correspond to hydrogen bonds.
- Parameters:
cc_central_idx (torch.LongTensor, shape (B, C)) – Central-atom indices for contacts.
cc_contact_idx (torch.LongTensor, shape (B, C)) – Contact-atom indices for contacts.
cc_mask (torch.BoolTensor, shape (B, C)) – Validity mask for contacts.
hb_central_idx (torch.LongTensor, shape (B, H)) – Donor-atom indices for H-bonds.
hb_hydrogen_idx (torch.LongTensor, shape (B, H)) – Hydrogen-atom indices for H-bonds.
hb_contact_idx (torch.LongTensor, shape (B, H)) – Acceptor-atom indices for H-bonds.
hb_mask (torch.BoolTensor, shape (B, H)) – Validity mask for H-bonds.
- Returns:
Mask indicating which contacts are H-bonds.
- Return type:
torch.BoolTensor, shape (B, C)
Identify Hydrogen Bond Contacts
Determines which close contacts are actually part of hydrogen bonds based on geometric criteria.
- Parameters:
cc_central_idx (torch.LongTensor) - Contact central atom indices
cc_contact_idx (torch.LongTensor) - Contact contact atom indices
cc_mask (torch.BoolTensor) - Contact validity mask
hb_central_idx (torch.LongTensor) - H-bond donor indices
hb_hydrogen_idx (torch.LongTensor) - H-bond hydrogen indices
hb_contact_idx (torch.LongTensor) - H-bond acceptor indices
hb_mask (torch.BoolTensor) - H-bond validity mask
Classification Algorithm:
# Match contacts to hydrogen bonds
batch_size, max_contacts = cc_mask.shape
max_hbonds = hb_mask.shape[1]
hbond_flags = torch.zeros_like(cc_mask, dtype=torch.bool)
for b in range(batch_size):
    for c in range(max_contacts):
        if not cc_mask[b, c]:
            continue
        central_atom = cc_central_idx[b, c]
        contact_atom = cc_contact_idx[b, c]
        # Check if this contact matches any H-bond
        for h in range(max_hbonds):
            if not hb_mask[b, h]:
                continue
            # Match donor-acceptor pair
            if (hb_central_idx[b, h] == central_atom
                    and hb_contact_idx[b, h] == contact_atom):
                hbond_flags[b, c] = True
                break
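The nested loops compare every contact with every H-bond. For a single structure the same matching can be expressed as a set lookup on (donor, acceptor) pairs; the sketch below is an illustrative plain-Python variant under that assumption (flag_hbond_contacts is a hypothetical standalone helper, not the module's method):

```python
from typing import List, Sequence


def flag_hbond_contacts(
    cc_central: Sequence[int],
    cc_contact: Sequence[int],
    cc_mask: Sequence[bool],
    hb_central: Sequence[int],
    hb_contact: Sequence[int],
    hb_mask: Sequence[bool],
) -> List[bool]:
    """Flag contacts whose (central, contact) pair is a valid H-bond pair."""
    hbond_pairs = {
        (d, a) for d, a, ok in zip(hb_central, hb_contact, hb_mask) if ok
    }
    return [
        ok and (c, x) in hbond_pairs
        for c, x, ok in zip(cc_central, cc_contact, cc_mask)
    ]


flags = flag_hbond_contacts(
    cc_central=[0, 2, 5, 0],
    cc_contact=[7, 9, 1, 0],
    cc_mask=[True, True, True, False],  # last slot is padding
    hb_central=[2, 0],
    hb_contact=[9, 0],
    hb_mask=[True, False],              # second H-bond slot is padding
)
```

Masked slots on either side never produce a match, even when their padded indices happen to coincide.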
Geometric Criteria:
Distance Cutoff - D-A distance < 3.5 Å
Angular Cutoff - D-H-A angle > 120°
Linearity - Preference for linear arrangements
Chemical Validation - Appropriate donor/acceptor atoms
- Returns:
torch.BoolTensor indicating which contacts are H-bonds
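The distance and angle cutoffs from the geometric criteria (D-A distance below 3.5 Å, D-H-A angle above 120°) can be checked for one donor, hydrogen, acceptor triple as follows. This is a sketch on Cartesian coordinates; the function names are hypothetical, and the linearity preference and chemical validation are not modelled.

```python
import math
from typing import Sequence


def dha_angle_deg(donor: Sequence[float], hydrogen: Sequence[float], acceptor: Sequence[float]) -> float:
    """Angle D-H-A in degrees, measured at the hydrogen atom."""
    v1 = [d - h for d, h in zip(donor, hydrogen)]
    v2 = [a - h for a, h in zip(acceptor, hydrogen)]
    cos_angle = sum(p * q for p, q in zip(v1, v2)) / (math.hypot(*v1) * math.hypot(*v2))
    # Clamp against rounding error before taking the arc cosine.
    return math.degrees(math.acos(max(-1.0, min(1.0, cos_angle))))


def passes_hbond_geometry(donor, hydrogen, acceptor, max_da: float = 3.5, min_angle: float = 120.0) -> bool:
    """Apply the D-A distance (Å) and D-H-A angle (degrees) cutoffs."""
    return (
        math.dist(donor, acceptor) < max_da
        and dha_angle_deg(donor, hydrogen, acceptor) > min_angle
    )


# Linear arrangement along x: D-A distance 2.8 Å, D-H-A angle 180 degrees.
linear_ok = passes_hbond_geometry((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.8, 0.0, 0.0))
# Bent arrangement: the D-H-A angle is 90 degrees, below the cutoff.
bent_ok = passes_hbond_geometry((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (1.0, 1.8, 0.0))
```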
- StructurePostExtractionProcessor._identify_contact_fragments(cc_central_idx, cc_contact_idx, atom_fragment_id)[source]
Map each contact’s atoms back to their rigid-fragment IDs.
- Parameters:
cc_central_idx (torch.LongTensor, shape (B, C)) – Central-atom indices for contacts.
cc_contact_idx (torch.LongTensor, shape (B, C)) – Contact-atom indices for contacts.
atom_fragment_id (torch.LongTensor, shape (B, N)) – Fragment ID per atom.
- Returns:
{
'inter_cc_central_atom_fragment_idx': torch.LongTensor,
'inter_cc_contact_atom_fragment_idx': torch.LongTensor
}
- Return type:
Dict[str, torch.LongTensor]
Map Contacts to Fragment Pairs
Associates each intermolecular contact with the rigid fragments involved.
- Parameters:
cc_central_idx (torch.LongTensor) - Central atom indices
cc_contact_idx (torch.LongTensor) - Contact atom indices
atom_fragment_id (torch.LongTensor) - Fragment ID per atom
Fragment Mapping:
# Map contact atoms to their fragments
central_fragment_ids = torch.gather(
    atom_fragment_id, dim=-1, index=cc_central_idx
)
contact_fragment_ids = torch.gather(
    atom_fragment_id, dim=-1, index=cc_contact_idx
)
# Create fragment pair identifiers
fragment_pairs = torch.stack([
    central_fragment_ids, contact_fragment_ids
], dim=-1)
Contact Statistics:
Contacts per Fragment - Number of intermolecular contacts
Fragment Coordination - Number of neighboring fragments
Contact Directionality - Preferred contact directions
Fragment Accessibility - Solvent-accessible surface contacts
- Returns:
Dictionary with fragment indices and contact-fragment mappings
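For one structure, the gather-based mapping amounts to an index lookup, and the "Contacts per Fragment" statistic is a count over the result. A plain-Python sketch; contact_fragment_pairs is a hypothetical helper, and counting by the central atom's fragment is an assumption about how that statistic is defined:

```python
from collections import Counter
from typing import List, Sequence, Tuple


def contact_fragment_pairs(
    cc_central_idx: Sequence[int],
    cc_contact_idx: Sequence[int],
    atom_fragment_id: Sequence[int],
) -> List[Tuple[int, int]]:
    """Look up the fragment of each contact's central and contact atom."""
    return [
        (atom_fragment_id[c], atom_fragment_id[x])
        for c, x in zip(cc_central_idx, cc_contact_idx)
    ]


atom_fragment_id = [0, 0, 1, 1, 2]  # fragment ID per atom
pairs = contact_fragment_pairs([0, 1, 4], [2, 3, 0], atom_fragment_id)

# Count contacts by the fragment that owns the central atom.
contacts_per_fragment = Counter(central for central, _ in pairs)
```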
Advanced Feature Computation
- StructurePostExtractionProcessor._compute_contact_com_vectors(cc_coords, cc_frac_coords, cc_fragment_idx, frag_com_coords, frag_com_frac_coords, frag_structure_id, frag_local_ids)[source]
Compute vectors & distances from contact atoms to central-fragment COM.
- Parameters:
cc_coords (torch.Tensor, shape (B, C, 3)) – Cartesian coords of contact atoms.
cc_frac_coords (torch.Tensor, shape (B, C, 3)) – Fractional coords of contact atoms.
cc_fragment_idx (torch.LongTensor, shape (B, C)) – Fragment indices for contact atoms.
frag_com_coords (torch.Tensor, shape (B, F, 3)) – Cartesian COM coords for each fragment.
frag_com_frac_coords (torch.Tensor, shape (B, F, 3)) – Fractional COM coords for each fragment.
frag_structure_id (torch.LongTensor, shape (B, F)) – Structure IDs for each fragment.
frag_local_ids (torch.LongTensor, shape (B, F)) – Local fragment indices.
- Returns:
{
'inter_cc_atom_to_central_com_vec': torch.Tensor,
'inter_cc_atom_to_central_com_dist': torch.Tensor
}
- Return type:
Dict[str, torch.Tensor]
Compute Contact-to-COM Vectors and Distances
Calculates vectors and distances from contact atoms to fragment centers of mass.
- Parameters:
cc_coords (torch.Tensor) - Contact atom coordinates
cc_frac_coords (torch.Tensor) - Contact fractional coordinates
cc_fragment_idx (torch.LongTensor) - Contact fragment indices
frag_com_coords (torch.Tensor) - Fragment COM coordinates
frag_com_frac_coords (torch.Tensor) - Fragment COM fractional coordinates
frag_structure_id (torch.LongTensor) - Fragment structure IDs
frag_local_ids (torch.LongTensor) - Fragment local IDs
Vector Calculations:
# Map contacts to fragment COM positions
contact_com_coords = torch.gather(
    frag_com_coords, dim=-2,
    index=cc_fragment_idx.unsqueeze(-1).expand(-1, -1, 3)
)
# Compute contact-to-COM vectors
com_vectors = cc_coords - contact_com_coords
com_distances = torch.norm(com_vectors, dim=-1)
# Normalize for directional analysis
com_unit_vectors = com_vectors / com_distances.unsqueeze(-1)
Geometric Descriptors:
Contact Distance - Direct atom-atom distance
COM Distance - Distance from contact to fragment center
Contact Vector - Direction from COM to contact point
Surface Normal - Estimated surface normal at contact
Contact Accessibility - Geometric accessibility measure
Applications:
Packing Analysis - How fragments pack together
Surface Properties - Contact surface characterization
Intermolecular Forces - Direction and magnitude analysis
Crystal Engineering - Design principles extraction
- Returns:
Dictionary with contact vectors, distances, and geometric descriptors
Data Writing and I/O Methods
Usage Examples
Basic Feature Engineering
from structure_post_extraction_processor import StructurePostExtractionProcessor
import torch
from pathlib import Path
# Initialize processor
processor = StructurePostExtractionProcessor(
    hdf5_path=Path("./analysis.h5"),
    batch_size=32,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
# Execute feature engineering
processor.run()
GPU Memory Optimization
# Configure for large datasets with limited GPU memory
import torch
from pathlib import Path
# Cap this process at 70% of the GPU's memory
torch.cuda.set_per_process_memory_fraction(0.7)
# Use smaller batch size
memory_optimized_processor = StructurePostExtractionProcessor(
    hdf5_path=Path("./large_dataset.h5"),
    batch_size=16,  # Reduced batch size
    device=torch.device("cuda")
)
# Monitor memory usage
def memory_monitor():
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
# Run with monitoring
memory_monitor()
memory_optimized_processor.run()
memory_monitor()
Custom Feature Selection
# Subclass for custom feature computation
class CustomFeatureProcessor(StructurePostExtractionProcessor):
    # The _write_custom_features, _compute_molecular_surface_area and
    # _compute_ring_strain_energies helpers are placeholders to implement in the subclass.
    def _process_batch(self, start, crystal, atom, bond, intra_cc, intra_hb, inter_cc, inter_hb):
        """Override to add custom features."""
        # Call parent processing
        super()._process_batch(start, crystal, atom, bond, intra_cc, intra_hb, inter_cc, inter_hb)
        # Add custom features
        custom_features = self._compute_custom_descriptors(
            atom.coords, atom.symbols, bond.atom1_idx, bond.atom2_idx
        )
        # Write custom features
        self._write_custom_features(start, custom_features)
    def _compute_custom_descriptors(self, coords, symbols, bond1, bond2):
        """Compute application-specific descriptors."""
        # Example: Compute molecular surface area
        surface_areas = self._compute_molecular_surface_area(coords, symbols)
        # Example: Compute ring strain energies
        ring_strains = self._compute_ring_strain_energies(coords, bond1, bond2)
        return {
            'surface_areas': surface_areas,
            'ring_strains': ring_strains
        }
High-Throughput Processing
# Configure for maximum throughput
high_throughput_processor = StructurePostExtractionProcessor(
    hdf5_path=Path("./massive_dataset.h5"),
    batch_size=128,  # Large batches for throughput
    device=torch.device("cuda")
)
# Use multiple GPUs if available
if torch.cuda.device_count() > 1:
    # Implement data parallel processing.
    # Note: nn.DataParallel wraps nn.Module instances, so this pattern applies only if the
    # processor is an nn.Module; _process_batch_parallel is a hook you would need to provide.
    import torch.nn as nn
    class ParallelProcessor(nn.DataParallel):
        def __init__(self, processor):
            super().__init__(processor)
        def forward(self, batch_data):
            return self.module._process_batch_parallel(batch_data)
    parallel_processor = ParallelProcessor(high_throughput_processor)
Quality Control and Validation
# Implement comprehensive validation
import logging
import torch
class ValidatedProcessor(StructurePostExtractionProcessor):
    def _validate_batch_results(self, computed_data):
        """Validate computed features for quality."""
        # Check for NaN values
        for key, tensor in computed_data.items():
            if torch.isnan(tensor).any():
                logging.warning(f"NaN values detected in {key}")
        # Check physical constraints
        com_distances = computed_data.get('com_distances')
        if com_distances is not None:
            if (com_distances < 0).any():
                logging.error("Negative distances detected")
        # Check fragment validity: IDs are expected to run 0..N-1 without gaps
        fragment_ids = computed_data.get('fragment_ids')
        if fragment_ids is not None:
            max_fragment_id = fragment_ids.max()
            expected_max = torch.unique(fragment_ids).numel() - 1
            if max_fragment_id != expected_max:
                logging.warning("Fragment ID inconsistency detected")
    def _process_batch(self, *args, **kwargs):
        # Process normally (assumes the parent returns the dict of computed tensors)
        result = super()._process_batch(*args, **kwargs)
        # Validate results
        self._validate_batch_results(result)
        return result
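The same checks can also be written as a pure function that returns a list of problems instead of logging them, which makes them straightforward to unit-test. The following is a sketch in NumPy rather than torch, under stated assumptions: the function name `find_feature_issues`, the `pad_id=-1` padding convention, and the key names `com_distances` and `fragment_ids` (taken from the example above) are illustrative, not part of the module's API.

```python
import numpy as np

def find_feature_issues(computed, pad_id=-1):
    """Return human-readable descriptions of suspicious values in a batch."""
    issues = []
    # NaN or infinity in any floating-point feature.
    for key, values in computed.items():
        arr = np.asarray(values)
        if np.issubdtype(arr.dtype, np.floating) and not np.isfinite(arr).all():
            issues.append(f"{key}: non-finite values")
    # Distances are non-negative by definition.
    dist = computed.get("com_distances")
    if dist is not None and (np.asarray(dist) < 0).any():
        issues.append("com_distances: negative values")
    # Fragment IDs should be 0..N-1 with no gaps once padding entries are dropped.
    frag = computed.get("fragment_ids")
    if frag is not None:
        ids = np.unique(np.asarray(frag))
        ids = ids[ids != pad_id]
        if ids.size and not np.array_equal(ids, np.arange(ids.size)):
            issues.append("fragment_ids: not contiguous from 0")
    return issues
```

Returning the issues leaves the policy to the caller: a batch run might log them and continue, while a test suite can assert that the list is empty.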
Performance Optimization
GPU Utilization
# Optimize GPU kernel launches
torch.backends.cudnn.benchmark = True  # cuDNN autotuning; helps when input sizes are fixed
torch.backends.cudnn.deterministic = False  # Allow non-deterministic algorithms for speed
# Use mixed precision for memory efficiency
class MixedPrecisionProcessor(StructurePostExtractionProcessor):
    def _process_batch(self, *args, **kwargs):
        # No GradScaler is needed: loss scaling only matters when gradients are back-propagated.
        # float16 lowers precision, so check distance-sensitive features against a float32 run.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            return super()._process_batch(*args, **kwargs)
Memory Management
# Implement gradient checkpointing for memory efficiency.
# Checkpointing saves memory only while autograd is recording; under torch.no_grad()
# there are no stored activations to trade away.
import torch.utils.checkpoint as checkpoint
# compute_inertia_tensor is assumed to be in scope (for example, imported from geometry_utils)
def memory_efficient_feature_computation(coords, weights):
    """Use checkpointing for memory-intensive computations."""
    def compute_inertia_checkpoint(coords_chunk, weights_chunk):
        return compute_inertia_tensor(coords_chunk, weights_chunk)
    # Use checkpointing to trade compute for memory
    inertia_tensors = checkpoint.checkpoint(
        compute_inertia_checkpoint,
        coords,
        weights,
        use_reentrant=False
    )
    return inertia_tensors
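When no gradients are tracked, peak memory is usually bounded by processing the batch in chunks rather than by checkpointing. The sketch below shows that alternative technique in NumPy. It assumes the standard mass-weighted inertia tensor about each structure's centre of mass, I = Σ mᵢ (|dᵢ|² 𝟙 − dᵢ dᵢᵀ), which may differ from what `compute_inertia_tensor` actually computes; the function name and `chunk_size` parameter are hypothetical.

```python
import numpy as np

def inertia_tensors_chunked(coords, weights, chunk_size=1024):
    """Inertia tensors for N structures, computed chunk_size structures at a time.

    coords:  (N, A, 3) atom coordinates
    weights: (N, A) atomic weights; zero for padded atoms (assumed convention)
    returns: (N, 3, 3)
    """
    n_structures = coords.shape[0]
    out = np.empty((n_structures, 3, 3), dtype=np.float64)
    eye = np.eye(3)
    for start in range(0, n_structures, chunk_size):
        sl = slice(start, start + chunk_size)
        r, m = coords[sl], weights[sl]
        # Guard against all-padding structures whose total weight is zero.
        total = np.maximum(m.sum(axis=1, keepdims=True), 1e-12)   # (n, 1)
        com = (m[..., None] * r).sum(axis=1) / total              # (n, 3)
        d = r - com[:, None, :]                                   # (n, A, 3)
        r2 = (d * d).sum(axis=-1)                                 # (n, A)
        # I = sum_i m_i * (|d_i|^2 * Identity - d_i d_i^T)
        trace_term = np.einsum("na,na->n", m, r2)[:, None, None] * eye
        outer_term = np.einsum("na,nai,naj->nij", m, d, d)
        out[sl] = trace_term - outer_term
    return out
```

Only one chunk's intermediates exist at a time, so peak memory scales with `chunk_size` rather than with the full batch, and the result does not depend on the chunk size chosen.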
Batch Size Optimization
def find_optimal_batch_size(processor, max_batch_size=256):
    """Find maximum feasible batch size through binary search."""
    low, high = 1, max_batch_size
    optimal_size = 1
    while low <= high:
        mid = (low + high) // 2
        try:
            # Test with current batch size (_create_test_batch is a helper you would need to provide)
            processor.batch_size = mid
            test_batch = processor._create_test_batch()
            processor._process_batch(*test_batch)
            # Success - try larger
            optimal_size = mid
            low = mid + 1
        except RuntimeError as e:
            if "out of memory" in str(e):
                # OOM - try smaller
                high = mid - 1
                torch.cuda.empty_cache()
            else:
                raise
    # Leave the processor configured with the size that worked, not the last size tried
    processor.batch_size = optimal_size
    return optimal_size
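The search logic itself can be separated from the GPU trial run, which makes it testable without hardware. This sketch assumes feasibility is monotonic (if a batch size fits, every smaller one fits too). `largest_passing_size` and `fake_fits` are hypothetical names, and the 48-item limit is an arbitrary stand-in for a real out-of-memory threshold.

```python
from typing import Callable

def largest_passing_size(fits: Callable[[int], bool], max_size: int = 256) -> int:
    """Return the largest n in [1, max_size] for which fits(n) is True, or 0 if none.

    Assumes monotonicity: if size n fits, every smaller size fits as well.
    """
    low, high, best = 1, max_size, 0
    while low <= high:
        mid = (low + high) // 2
        if fits(mid):
            best, low = mid, mid + 1   # success: remember it and try larger
        else:
            high = mid - 1             # failure: try smaller
    return best

# Hypothetical stand-in for a real trial run: pretend anything above 48 runs out of memory.
def fake_fits(n: int) -> bool:
    return n <= 48

print(largest_passing_size(fake_fits))  # prints 48
```

Unlike the function above, this version returns 0 when even a batch of one fails, so the caller can tell "nothing fits" apart from "only a single structure fits". In real use, `fits` would wrap the trial `_process_batch` call and translate an out-of-memory error into `False`.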
See Also
crystal_analyzer module : Pipeline orchestration
structure_data_extractor module : Raw data extraction
fragment_utils module : Fragment analysis utilities
geometry_utils module : Geometric computation utilities
data_writer module : HDF5 data writing utilities