structure_post_extraction_processor module

Module: structure_post_extraction_processor.py

Post-extraction processing for CSD structures.

Reads raw HDF5 outputs (from StructureDataExtractor), computes derived features (geometric, topological, contact-based) in GPU-accelerated batches, and writes both raw and computed datasets into a new “*_processed.h5” container. Designed for high throughput and minimal I/O overhead.

Dependencies

h5py, numpy, torch, data_reader, data_writer, dataset_initializer, dimension_scanner, cell_utils, contact_utils, fragment_utils, geometry_utils, symmetry_utils

class structure_post_extraction_processor.CrystalParams(cell_lengths, cell_angles)[source]

Bases: NamedTuple

Container for basic crystal parameters.

cell_lengths

Unit-cell lengths [a, b, c] for each structure in the batch.

Type:

torch.Tensor, shape (B, 3)

cell_angles

Unit-cell angles [α, β, γ] in degrees for each structure in the batch.

Type:

torch.Tensor, shape (B, 3)

cell_lengths: torch.Tensor

Alias for field number 0

cell_angles: torch.Tensor

Alias for field number 1

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.
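The cell_lengths and cell_angles fields are enough to build the real-space cell matrix that later steps (for example _expand_inter_contacts) consume. The following is a minimal sketch. It assumes the common convention that a lies along x and b lies in the xy-plane, with the lattice vectors stored as matrix rows so that Cartesian = fractional @ matrix. The helper cell_matrix_from_params is hypothetical and is not the module's own _compute_cell_matrices.

```python
from typing import NamedTuple

import torch


class CrystalParams(NamedTuple):
    """Local mirror of the documented container, same field order."""
    cell_lengths: torch.Tensor  # (B, 3): a, b, c
    cell_angles: torch.Tensor   # (B, 3): alpha, beta, gamma in degrees


def cell_matrix_from_params(crystal: CrystalParams) -> torch.Tensor:
    """Hypothetical helper: (B, 3, 3) matrices whose rows are lattice vectors a, b, c."""
    a, b, c = crystal.cell_lengths.unbind(-1)
    alpha, beta, gamma = torch.deg2rad(crystal.cell_angles).unbind(-1)
    cos_a, cos_b, cos_g = torch.cos(alpha), torch.cos(beta), torch.cos(gamma)
    sin_g = torch.sin(gamma)
    # Components of c chosen so that a.c = ac*cos(beta) and b.c = bc*cos(alpha)
    cy = (cos_a - cos_b * cos_g) / sin_g
    cz = torch.sqrt(torch.clamp(1.0 - cos_b**2 - cy**2, min=0.0))
    zero = torch.zeros_like(a)
    a_vec = torch.stack([a, zero, zero], dim=-1)
    b_vec = torch.stack([b * cos_g, b * sin_g, zero], dim=-1)
    c_vec = torch.stack([c * cos_b, c * cy, c * cz], dim=-1)
    return torch.stack([a_vec, b_vec, c_vec], dim=-2)


# A cubic cell and a monoclinic cell in one batch
crystal = CrystalParams(
    cell_lengths=torch.tensor([[2.0, 2.0, 2.0], [5.0, 6.0, 7.0]], dtype=torch.float64),
    cell_angles=torch.tensor([[90.0, 90.0, 90.0], [90.0, 100.0, 90.0]], dtype=torch.float64),
)
M = cell_matrix_from_params(crystal)  # (2, 3, 3)
frac = torch.tensor([[[0.5, 0.5, 0.5]], [[0.5, 0.5, 0.5]]], dtype=torch.float64)
cart = frac @ M                       # (2, 1, 3) Cartesian coordinates
```

Under this convention, the determinant of each matrix equals the unit-cell volume. This gives a quick sanity check against stored volumes.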

class structure_post_extraction_processor.AtomParams(labels, symbols, coords, frac_coords, mask, weights, charges)[source]

Bases: NamedTuple

Container for atomic-level parameters.

labels

Atom labels for each structure, padded to max_atoms.

Type:

List[List[str]]

symbols

Atomic symbols for each structure, padded to max_atoms.

Type:

List[List[str]]

coords

Cartesian coordinates of each atom.

Type:

torch.Tensor, shape (B, max_atoms, 3)

frac_coords

Fractional coordinates of each atom.

Type:

torch.Tensor, shape (B, max_atoms, 3)

mask

Boolean mask indicating valid atom entries.

Type:

torch.BoolTensor, shape (B, max_atoms)

weights

Atomic weights.

Type:

torch.Tensor, shape (B, max_atoms)

charges

Partial charges per atom.

Type:

torch.Tensor, shape (B, max_atoms)

labels: List[List[str]]

Alias for field number 0

symbols: List[List[str]]

Alias for field number 1

coords: torch.Tensor

Alias for field number 2

frac_coords: torch.Tensor

Alias for field number 3

mask: torch.BoolTensor

Alias for field number 4

weights: torch.Tensor

Alias for field number 5

charges: torch.Tensor

Alias for field number 6

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.
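The "padded to max_atoms" convention pairs every dense tensor with mask. The following is a minimal sketch of producing that layout from ragged per-structure coordinates, and of using the mask so that padding does not leak into a reduction. The helpers pad_atoms and masked_centroid are hypothetical and are not part of the module.

```python
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_atoms(per_structure: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stack ragged (N_i, 3) tensors into (B, max_atoms, 3) plus a (B, max_atoms) mask."""
    coords = pad_sequence(per_structure, batch_first=True)  # zero-padded
    lengths = torch.tensor([t.shape[0] for t in per_structure])
    # True for real atoms, False for padding slots
    mask = torch.arange(coords.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)
    return coords, mask


def masked_centroid(coords: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-structure mean position that ignores padded atoms: (B, 3)."""
    m = mask.unsqueeze(-1).to(coords.dtype)
    return (coords * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)


coords, mask = pad_atoms([
    torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]),
    torch.tensor([[1.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 2.0, 4.0]]),
])
centroids = masked_centroid(coords, mask)
```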

class structure_post_extraction_processor.BondParams(atom1_idx, atom2_idx, bond_type, is_rotatable_raw, is_cyclic, mask)[source]

Bases: NamedTuple

Container for bond-level parameters.

atom1_idx

Index of first atom in each bond.

Type:

torch.LongTensor, shape (B, max_bonds)

atom2_idx

Index of second atom in each bond.

Type:

torch.LongTensor, shape (B, max_bonds)

bond_type

Numeric or categorical encoding of bond types.

Type:

torch.Tensor, shape (B, max_bonds)

is_rotatable_raw

Initial mask for bond rotatability.

Type:

torch.BoolTensor, shape (B, max_bonds)

is_cyclic

Indicates if bond is part of a ring.

Type:

torch.BoolTensor, shape (B, max_bonds)

mask

Boolean mask indicating valid bond entries.

Type:

torch.BoolTensor, shape (B, max_bonds)

atom1_idx: torch.LongTensor

Alias for field number 0

atom2_idx: torch.LongTensor

Alias for field number 1

bond_type: torch.Tensor

Alias for field number 2

is_rotatable_raw: torch.BoolTensor

Alias for field number 3

is_cyclic: torch.BoolTensor

Alias for field number 4

mask: torch.BoolTensor

Alias for field number 5

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.
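Because atom1_idx and atom2_idx are padded to max_bonds, every use of them should be gated by mask. The following is a minimal sketch that turns a BondParams-style batch into a symmetric dense adjacency matrix, optionally dropping a subset of bonds (for example, rotatable ones). The helper bond_adjacency is hypothetical. The dense (B, N, N) layout is an assumption that suits small molecules.

```python
from typing import Optional

import torch


def bond_adjacency(atom1_idx: torch.Tensor, atom2_idx: torch.Tensor,
                   bond_mask: torch.Tensor, max_atoms: int,
                   exclude: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return a symmetric (B, max_atoms, max_atoms) boolean adjacency matrix."""
    B, M = atom1_idx.shape
    keep = bond_mask.clone()
    if exclude is not None:
        keep &= ~exclude  # e.g. drop bonds flagged as rotatable
    adj = torch.zeros(B, max_atoms, max_atoms, dtype=torch.bool, device=atom1_idx.device)
    batch_idx = torch.arange(B, device=atom1_idx.device).unsqueeze(1).expand(B, M)
    # Boolean indexing keeps only valid bonds, so padded slots never write
    bi, i, j = batch_idx[keep], atom1_idx[keep], atom2_idx[keep]
    adj[bi, i, j] = True
    adj[bi, j, i] = True
    return adj


adj = bond_adjacency(
    atom1_idx=torch.tensor([[0, 1, 0]]),
    atom2_idx=torch.tensor([[1, 2, 0]]),
    bond_mask=torch.tensor([[True, True, False]]),  # third slot is padding
    max_atoms=3,
)
```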

class structure_post_extraction_processor.InterCCParams(central_atom, contact_atom, central_atom_idx, contact_atom_idx, central_atom_frac_coords, contact_atom_frac_coords, lengths, strengths, in_los, symmetry_A, symmetry_T, symmetry_A_inv, symmetry_T_inv)[source]

Bases: NamedTuple

Container for intermolecular close-contact parameters.

central_atom

Labels of central atoms in each contact.

Type:

List[List[str]]

contact_atom

Labels of contact atoms.

Type:

List[List[str]]

central_atom_idx

Indices of central atoms.

Type:

torch.LongTensor, shape (B, C)

contact_atom_idx

Indices of contact atoms.

Type:

torch.LongTensor, shape (B, C)

central_atom_frac_coords

Fractional coords of central atoms.

Type:

torch.Tensor, shape (B, C, 3)

contact_atom_frac_coords

Fractional coords of contact atoms.

Type:

torch.Tensor, shape (B, C, 3)

lengths

Contact distances.

Type:

torch.Tensor, shape (B, C)

strengths

Contact strength metrics.

Type:

torch.Tensor, shape (B, C)

in_los

Mask for line-of-sight contacts.

Type:

torch.Tensor, shape (B, C)

symmetry_A

Symmetry operation rotation matrices.

Type:

torch.Tensor, shape (B, C, 3, 3)

symmetry_T

Symmetry operation translation vectors.

Type:

torch.Tensor, shape (B, C, 3)

symmetry_A_inv

Inverse rotation matrices.

Type:

torch.Tensor, shape (B, C, 3, 3)

symmetry_T_inv

Inverse translation vectors.

Type:

torch.Tensor, shape (B, C, 3)

central_atom: List[List[str]]

Alias for field number 0

contact_atom: List[List[str]]

Alias for field number 1

central_atom_idx: torch.LongTensor

Alias for field number 2

contact_atom_idx: torch.LongTensor

Alias for field number 3

central_atom_frac_coords: torch.Tensor

Alias for field number 4

contact_atom_frac_coords: torch.Tensor

Alias for field number 5

lengths: torch.Tensor

Alias for field number 6

strengths: torch.Tensor

Alias for field number 7

in_los: torch.Tensor

Alias for field number 8

symmetry_A: torch.Tensor

Alias for field number 9

symmetry_T: torch.Tensor

Alias for field number 10

symmetry_A_inv: torch.Tensor

Alias for field number 11

symmetry_T_inv: torch.Tensor

Alias for field number 12

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.
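The symmetry_A/symmetry_T pairs and their _inv counterparts are related by ordinary affine inversion. Assume that each operation acts on fractional coordinates as x' = A x + T, which is the usual crystallographic convention; the module's exact convention is not stated here. Under that assumption, the inverse operation is A_inv = A⁻¹ and T_inv = -A⁻¹ T. The following is a minimal sketch with hypothetical helper names.

```python
import torch


def apply_symmetry(A: torch.Tensor, T: torch.Tensor, frac: torch.Tensor) -> torch.Tensor:
    """x' = A @ x + T for (..., 3, 3) rotations, (..., 3) translations, (..., 3) coords."""
    return (A @ frac.unsqueeze(-1)).squeeze(-1) + T


def invert_symmetry(A: torch.Tensor, T: torch.Tensor):
    """Return (A_inv, T_inv) such that x = A_inv @ x' + T_inv undoes apply_symmetry."""
    A_inv = torch.linalg.inv(A)
    T_inv = -(A_inv @ T.unsqueeze(-1)).squeeze(-1)
    return A_inv, T_inv


# One structure (B = 1), one contact (C = 1): the operation (-x, y+1/2, -z+1/2)
A = torch.diag(torch.tensor([-1.0, 1.0, -1.0], dtype=torch.float64)).reshape(1, 1, 3, 3)
T = torch.tensor([[[0.0, 0.5, 0.5]]], dtype=torch.float64)
x = torch.tensor([[[0.1, 0.2, 0.3]]], dtype=torch.float64)

x_image = apply_symmetry(A, T, x)                # (-0.1, 0.7, 0.2)
A_inv, T_inv = invert_symmetry(A, T)
x_back = apply_symmetry(A_inv, T_inv, x_image)   # recovers x
```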

class structure_post_extraction_processor.InterHBParams(central_atom, hydrogen_atom, contact_atom, central_atom_idx, hydrogen_atom_idx, contact_atom_idx, central_atom_frac_coords, hydrogen_atom_frac_coords, contact_atom_frac_coords, lengths, angles, in_los, symmetry_A, symmetry_T, symmetry_A_inv, symmetry_T_inv)[source]

Bases: NamedTuple

Container for intermolecular hydrogen-bond parameters.

central_atom

Labels of hydrogen-bond donor atoms.

Type:

List[List[str]]

hydrogen_atom

Labels of hydrogen atoms.

Type:

List[List[str]]

contact_atom

Labels of acceptor atoms.

Type:

List[List[str]]

central_atom_idx

Indices of donor atoms.

Type:

torch.LongTensor, shape (B, H)

hydrogen_atom_idx

Indices of hydrogen atoms.

Type:

torch.LongTensor, shape (B, H)

contact_atom_idx

Indices of acceptor atoms.

Type:

torch.LongTensor, shape (B, H)

central_atom_frac_coords

Fractional coords of donor atoms.

Type:

torch.Tensor, shape (B, H, 3)

hydrogen_atom_frac_coords

Fractional coords of hydrogen atoms.

Type:

torch.Tensor, shape (B, H, 3)

contact_atom_frac_coords

Fractional coords of acceptor atoms.

Type:

torch.Tensor, shape (B, H, 3)

lengths

H-bond distances.

Type:

torch.Tensor, shape (B, H)

angles

H-bond angles.

Type:

torch.Tensor, shape (B, H)

in_los

Mask for line-of-sight H-bonds.

Type:

torch.Tensor, shape (B, H)

symmetry_A

Symmetry rotation matrices.

Type:

torch.Tensor, shape (B, H, 3, 3)

symmetry_T

Symmetry translation vectors.

Type:

torch.Tensor, shape (B, H, 3)

symmetry_A_inv

Inverse rotation matrices.

Type:

torch.Tensor, shape (B, H, 3, 3)

symmetry_T_inv

Inverse translation vectors.

Type:

torch.Tensor, shape (B, H, 3)

central_atom: List[List[str]]

Alias for field number 0

hydrogen_atom: List[List[str]]

Alias for field number 1

contact_atom: List[List[str]]

Alias for field number 2

central_atom_idx: torch.LongTensor

Alias for field number 3

hydrogen_atom_idx: torch.LongTensor

Alias for field number 4

contact_atom_idx: torch.LongTensor

Alias for field number 5

central_atom_frac_coords: torch.Tensor

Alias for field number 6

hydrogen_atom_frac_coords: torch.Tensor

Alias for field number 7

contact_atom_frac_coords: torch.Tensor

Alias for field number 8

lengths: torch.Tensor

Alias for field number 9

angles: torch.Tensor

Alias for field number 10

in_los: torch.Tensor

Alias for field number 11

symmetry_A: torch.Tensor

Alias for field number 12

symmetry_T: torch.Tensor

Alias for field number 13

symmetry_A_inv: torch.Tensor

Alias for field number 14

symmetry_T_inv: torch.Tensor

Alias for field number 15

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.
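The lengths and angles fields can be recomputed from the three fractional-coordinate fields once a cell matrix is available. The following is a minimal sketch that rests on three assumptions:

- The cell matrix stores the lattice vectors as rows, so Cartesian = fractional @ matrix.
- The donor, hydrogen and acceptor coordinates already refer to the same symmetry image.
- The reported length is H...A. Whether the module stores H...A or D...A is not stated here.

The name hbond_geometry is hypothetical.

```python
import torch


def hbond_geometry(donor_frac: torch.Tensor, hydrogen_frac: torch.Tensor,
                   acceptor_frac: torch.Tensor, cell_matrix: torch.Tensor):
    """Return H...A distances and D-H...A angles (degrees), each of shape (B, H).

    donor_frac, hydrogen_frac, acceptor_frac: (B, H, 3); cell_matrix: (B, 3, 3).
    """
    d = donor_frac @ cell_matrix       # (B, H, 3) Cartesian
    h = hydrogen_frac @ cell_matrix
    a = acceptor_frac @ cell_matrix
    h_to_d = d - h
    h_to_a = a - h
    dist = torch.linalg.norm(h_to_a, dim=-1)
    cos = (h_to_d * h_to_a).sum(dim=-1) / (torch.linalg.norm(h_to_d, dim=-1) * dist)
    angle = torch.rad2deg(torch.acos(cos.clamp(-1.0, 1.0)))  # clamp guards rounding
    return dist, angle


cell = (10.0 * torch.eye(3, dtype=torch.float64)).unsqueeze(0)  # cubic, a = 10 Å
donor = torch.tensor([[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], dtype=torch.float64)
hydrogen = torch.tensor([[[0.1, 0.0, 0.0], [0.1, 0.0, 0.0]]], dtype=torch.float64)
acceptor = torch.tensor([[[0.3, 0.0, 0.0], [0.1, 0.2, 0.0]]], dtype=torch.float64)
dist, angle = hbond_geometry(donor, hydrogen, acceptor, cell)  # linear and 90 degree cases
```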

class structure_post_extraction_processor.StructurePostExtractionProcessor(hdf5_path, batch_size, device=None)[source]

Bases: object

Orchestrates post-extraction computation of derived structure features.

Reads raw data from an HDF5 file, processes structures in GPU-accelerated batches to compute geometric, topological, and contact-based features, and writes both raw and computed data to a new processed HDF5 file.

__init__(hdf5_path, batch_size, device=None)[source]

Initialize the processor.

Parameters:
  • hdf5_path (Path) – Path to the raw HDF5 file containing extracted structure data.

  • batch_size (int) – Number of structures to process per GPU batch.

  • device (str or torch.device, optional) – Device specifier (e.g., ‘cuda’, ‘cpu’); if None, selects CUDA if available.

run()[source]

Execute the full post-extraction processing pipeline.

Removes any existing processed file, reads raw data, initializes output datasets, processes structures in batches, and writes both raw and computed data to the output HDF5 file.

Advanced Feature Engineering Pipeline

The structure_post_extraction_processor module implements GPU-accelerated computation of advanced structural descriptors and features from raw crystal structure data. This module forms the core of Stage 5 in the CSA pipeline.

Data Structure Classes

CrystalParams

Crystal-Level Parameter Container

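Container holding crystal properties and unit-cell information for batch processing. The CrystalParams NamedTuple itself has only two fields, cell_lengths and cell_angles; the other attributes below are not among its fields.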

Attributes:
  • identifiers (List[str]) - CSD refcodes for structures

  • space_groups (List[str]) - Space group symbols

  • z_values (torch.LongTensor) - Z values per structure

  • z_prime (torch.Tensor) - Z’ values per structure

  • cell_volumes (torch.Tensor) - Unit cell volumes (Å³)

  • cell_densities (torch.Tensor) - Crystal densities (g/cm³)

  • cell_lengths (torch.Tensor) - Unit cell lengths a, b, c

  • cell_angles (torch.Tensor) - Unit cell angles α, β, γ (degrees)

AtomParams

Atomic-Level Parameter Container

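Container holding per-atom coordinates and properties. The fields of the AtomParams NamedTuple are labels, symbols, coords, frac_coords, mask, weights and charges; numbers and sybyl_types below are not among them.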

Attributes:
  • labels (List[List[str]]) - Atomic labels per structure

  • coords (torch.Tensor) - Cartesian coordinates (Å)

  • frac_coords (torch.Tensor) - Fractional coordinates

  • weights (torch.Tensor) - Atomic masses (Da)

  • numbers (torch.LongTensor) - Atomic numbers

  • charges (torch.Tensor) - Formal charges

  • symbols (List[List[str]]) - Element symbols

  • sybyl_types (List[List[str]]) - SYBYL atom types

  • mask (torch.BoolTensor) - Validity mask for atoms

BondParams

Bond Connectivity Parameter Container

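Container holding molecular bond information. The fields of the BondParams NamedTuple are atom1_idx, atom2_idx, bond_type, is_rotatable_raw, is_cyclic and mask; atom1_labels, atom2_labels and orders below are not among them.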

Attributes:
  • atom1_labels (List[List[str]]) - First bond partner labels

  • atom2_labels (List[List[str]]) - Second bond partner labels

  • atom1_idx (torch.LongTensor) - First partner indices

  • atom2_idx (torch.LongTensor) - Second partner indices

  • orders (torch.Tensor) - Bond order values

  • mask (torch.BoolTensor) - Validity mask for bonds

ContactParams

HBondParams

StructurePostExtractionProcessor Class

GPU-Accelerated Feature Engineering Pipeline

Orchestrates the computation of advanced structural descriptors and features from raw crystal structure data using GPU acceleration for maximum performance.

Core Capabilities:

  • Rigid Fragment Analysis - Identifies molecular fragments and rigid groups

  • Geometric Descriptors - Computes shape descriptors, moment tensors, orientations

  • Contact Analysis - Analyzes intermolecular interactions and networks

  • Symmetry Operations - Applies crystallographic symmetry transformations

  • Feature Engineering - Derives machine learning-ready descriptors

  • GPU Acceleration - Leverages CUDA for batch tensor operations

Processing Pipeline:

  1. Data Loading - Read raw data from HDF5 into GPU tensors

  2. Fragment Identification - Detect rigid molecular fragments

  3. Geometric Analysis - Compute centers of mass, inertia tensors, orientations

  4. Contact Expansion - Apply symmetry operations to intermolecular contacts

  5. Feature Computation - Calculate advanced descriptors and properties

  6. Data Writing - Save computed features to processed HDF5 file

Attributes:
  • hdf5_path (Path) - Input HDF5 file with raw data

  • batch_size (int) - Structures processed per GPU batch

  • device (torch.device) - Device for tensor operations (CUDA when available, otherwise CPU)

  • raw_reader (DataReader) - Raw data loading interface

  • raw_writer (DataWriter) - Raw data writing interface

  • computed_writer (DataWriter) - Computed data writing interface

__init__(hdf5_path, batch_size, device=None)[source]

Initialize Post-Extraction Processor

Parameters:
  • hdf5_path (Path) - Path to raw HDF5 file

  • batch_size (int) - Number of structures per GPU batch

  • device (Optional[Union[str, torch.device]]) - Target device; if None, CUDA is used when available, otherwise CPU

Device Configuration:

# Automatic device selection
if device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Manual device specification
device = torch.device("cuda:0")  # Specific GPU
device = torch.device("cpu")     # Force CPU processing

Memory Management:

# Configure GPU memory settings
if device.type == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.8)

File Initialization:

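# Set up data readers and writers
self.raw_reader = DataReader(hdf5_path)

# Create processed data file next to the input, e.g. analysis.h5 -> analysis_processed.h5
# (Path.with_suffix would raise ValueError: "_processed.h5" does not start with ".")
processed_path = hdf5_path.with_name(f"{hdf5_path.stem}_processed.h5")
self.raw_writer = DataWriter(processed_path, "raw")
self.computed_writer = DataWriter(processed_path, "computed")

run()[source]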

Execute Complete Feature Engineering Pipeline

Orchestrates the full post-extraction processing workflow from raw data loading to computed feature storage.

Processing Workflow:

  1. Initialization - Set up output files and data structures

  2. Batch Processing - Process structures in GPU-optimized batches

  3. Feature Computation - Apply all feature engineering algorithms

  4. Data Writing - Store results in structured HDF5 format

  5. Validation - Verify data integrity and completeness

Batch Processing Loop:

# refcodes, dims, h5_in and h5_out are assumed to be set up earlier in run()
for start in range(0, len(refcodes), self.batch_size):
    batch = refcodes[start:start + self.batch_size]

    # Read, compute, and write this batch (_process_batch returns None)
    self._process_batch(start, batch, dims, h5_in, h5_out)

    # Release cached GPU memory between batches (skipped on CPU)
    if self.device.type == "cuda":
        torch.cuda.empty_cache()

Progress Monitoring:

INFO - Starting post-extraction processing...
INFO - Found 8234 structures to process
INFO - Processing batch 1/258 (32 structures)
INFO - Computed 1024 rigid fragments
INFO - Processed 15678 intermolecular contacts
INFO - Processing batch 2/258 (32 structures)
...
INFO - Post-extraction processing complete
INFO - Processed file saved: analysis_processed.h5

GPU Memory Management:

import logging

import torch


def monitor_gpu_memory():
    """Log current CUDA memory usage in GiB (no-op without CUDA)."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        logging.info(f"GPU Memory - Allocated: {allocated:.2f}GiB, Reserved: {reserved:.2f}GiB")

Error Handling:

  • GPU OOM - Automatically reduces batch size and retries

  • Data Corruption - Skips problematic structures with logging

  • Device Failures - Falls back to CPU processing if needed

  • File I/O Issues - Implements robust error recovery
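The first bullet (reduce the batch size on GPU OOM and retry) can be implemented as a halving back-off around the per-batch call. The following is a minimal sketch, not the module's own recovery code:

- run_with_oom_backoff is hypothetical.
- process stands in for a per-batch callable, such as a wrapper around _process_batch.
- torch.cuda.OutOfMemoryError requires PyTorch 1.13 or later.

```python
from typing import Callable, Sequence

import torch


def run_with_oom_backoff(items: Sequence, batch_size: int,
                         process: Callable[[int, Sequence], None],
                         min_batch_size: int = 1) -> int:
    """Process items in batches, halving batch_size after each CUDA OOM.

    Returns the batch size in effect when all items were processed.
    """
    start = 0
    while start < len(items):
        batch = items[start:start + batch_size]
        try:
            process(start, batch)
        except torch.cuda.OutOfMemoryError:
            if batch_size <= min_batch_size:
                raise  # cannot shrink further, surface the error
            batch_size = max(min_batch_size, batch_size // 2)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # return cached blocks before retrying
            continue  # retry the same start offset with the smaller batch
        start += len(batch)
    return batch_size


# Usage with a stand-in that "runs out of memory" above 4 items per batch
processed = []


def fake_process(start, batch):
    if len(batch) > 4:
        raise torch.cuda.OutOfMemoryError("simulated OOM")
    processed.extend(batch)


final_size = run_with_oom_backoff(list(range(10)), batch_size=16, process=fake_process)
```

Retrying from the same start offset keeps the output indices aligned. However, if the callable writes partial results before failing, those rows are rewritten on the retry.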

_process_batch(start, batch, dims, h5_in, h5_out)[source]

Process and write raw plus computed data for a batch of structures.

Parameters:
  • start (int) – Index offset in the global refcode list corresponding to this batch.

  • batch (List[str]) – Refcode identifiers for this batch.

  • dims (Dict[str, int]) – Maximum-dimension dict from scan_max_dimensions.

  • h5_in (h5py.File) – Open HDF5 file handle for raw input data.

  • h5_out (h5py.File) – Open HDF5 file handle for processed output data.

Return type:

None

Process Single Batch with GPU Acceleration

Core batch processing function that applies all feature engineering algorithms to a batch of structures.

Batch data containers (populated from h5_in for the refcodes in batch):
  • crystal (CrystalParams) - Crystal properties

  • atom (AtomParams) - Atomic data

  • bond (BondParams) - Bond connectivity

  • intra_cc (ContactParams) - Intramolecular contacts

  • intra_hb (HBondParams) - Intramolecular H-bonds

  • inter_cc (ContactParams) - Intermolecular contacts

  • inter_hb (HBondParams) - Intermolecular H-bonds

Feature Engineering Pipeline:

# 1. Compute unit cell transformation matrices
cell_matrices = self._compute_cell_matrices(
    crystal.cell_lengths, crystal.cell_angles
)

# 2. Identify rotatable bonds (BondParams has no orders field; bond types are in bond_type)
rotatable_bonds = self._compute_rotatable_bonds(
    atom.symbols, bond.atom1_idx, bond.atom2_idx, bond.bond_type
)

# 3. Expand intermolecular contacts with symmetry
expanded_contacts = self._expand_inter_contacts(
    inter_cc, cell_matrices
)

# 4. Identify rigid molecular fragments
fragment_ids = self._compute_rigid_fragments(
    atom.mask, bond.atom1_idx, bond.atom2_idx, rotatable_bonds
)

# 5. Compute fragment properties
fragment_properties = self._compute_fragment_properties(
    fragment_ids, atom.coords, atom.weights, atom.charges
)

# 6. Map contacts to fragments (per-atom fragment IDs looked up by contact atom indices)
contact_fragments = self._identify_contact_fragments(
    inter_cc.central_atom_idx, inter_cc.contact_atom_idx, fragment_ids
)

Fragment Analysis Methods

StructurePostExtractionProcessor._compute_rigid_fragments(atom_mask, atom1_idx, atom2_idx, bond_is_rotatable)[source]

Identify rigid fragments by grouping non-rotatable bonds.

Parameters:
  • atom_mask (torch.BoolTensor, shape (B, N)) – Mask indicating valid atoms.

  • atom1_idx (torch.LongTensor, shape (B, M)) – First-atom indices per bond.

  • atom2_idx (torch.LongTensor, shape (B, M)) – Second-atom indices per bond.

  • bond_is_rotatable (torch.BoolTensor, shape (B, M)) – Mask for bonds that are rotatable.

Returns:

Fragment ID assigned to each atom.

Return type:

torch.LongTensor, shape (B, N)

Identify Rigid Molecular Fragments

Analyzes molecular connectivity to identify rigid groups of atoms that move as single units.

Parameters:
  • atom_mask (torch.BoolTensor) - Valid atom flags

  • atom1_idx (torch.LongTensor) - First bond partner indices

  • atom2_idx (torch.LongTensor) - Second bond partner indices

  • bond_is_rotatable (torch.BoolTensor) - Rotatable bond flags

Algorithm:

  1. Graph Construction - Build molecular connectivity graph

  2. Bond Classification - Identify rotatable vs. rigid bonds

  3. Component Analysis - Find connected components after removing rotatable bonds

  4. Fragment Assignment - Assign unique IDs to each rigid fragment

Fragment Types Identified:

  • Aromatic Rings - Benzene, pyridine, etc.

  • Aliphatic Rings - Cyclohexane, cyclopentane, etc.

  • Rigid Chains - Double/triple bonded segments

  • Individual Atoms - Isolated atoms or small groups

Returns:

torch.LongTensor with fragment IDs for each atom
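Steps 1 to 4 amount to finding the connected components of the bond graph once rotatable bonds are removed. The following is a minimal batched sketch using min-label propagation over a dense adjacency matrix. This is one of several ways to compute components; the module's actual algorithm is not shown here. Note the following assumptions and trade-offs:

- The extra bond_mask argument is an assumption. The documented signature has no bond-validity mask, so padded bonds must be excluded some other way there.
- Labels are the smallest atom index in each fragment rather than consecutive IDs.
- The dense (B, N, N) matrix costs O(B·N²) memory.

```python
import torch


def rigid_fragment_ids(atom_mask: torch.Tensor, atom1_idx: torch.Tensor,
                       atom2_idx: torch.Tensor, bond_mask: torch.Tensor,
                       bond_is_rotatable: torch.Tensor) -> torch.Tensor:
    """Label atoms by rigid fragment: (B, N) long tensor, -1 for padded atoms."""
    B, N = atom_mask.shape
    M = atom1_idx.shape[1]
    device = atom_mask.device

    # Steps 1-3: graph of valid, non-rotatable bonds plus self-loops
    keep = bond_mask & ~bond_is_rotatable
    batch_idx = torch.arange(B, device=device).unsqueeze(1).expand(B, M)
    adj = torch.eye(N, dtype=torch.bool, device=device).repeat(B, 1, 1)
    adj[batch_idx[keep], atom1_idx[keep], atom2_idx[keep]] = True
    adj[batch_idx[keep], atom2_idx[keep], atom1_idx[keep]] = True

    # Step 4: propagate the minimum label through each component until stable
    sentinel = N  # larger than any real atom index
    labels = torch.arange(N, device=device).repeat(B, 1)
    labels[~atom_mask] = sentinel
    while True:
        # entry [b, i, j] is labels[b, j] if i and j are bonded, else sentinel
        neighbour_labels = labels.unsqueeze(1).expand(B, N, N).masked_fill(~adj, sentinel)
        new_labels = neighbour_labels.min(dim=-1).values
        if torch.equal(new_labels, labels):
            break
        labels = new_labels
    return labels.masked_fill(~atom_mask, -1)


# Chain 0-1-2-3 with the 1-2 bond rotatable; atom 4 and bond slot 3 are padding
ids = rigid_fragment_ids(
    atom_mask=torch.tensor([[True, True, True, True, False]]),
    atom1_idx=torch.tensor([[0, 1, 2, 0]]),
    atom2_idx=torch.tensor([[1, 2, 3, 0]]),
    bond_mask=torch.tensor([[True, True, True, False]]),
    bond_is_rotatable=torch.tensor([[False, True, False, False]]),
)
```

The loop runs at most as many times as the longest path in any rigid fragment, and every iteration is a single batched tensor operation.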

Geometric Computation Methods

StructurePostExtractionProcessor._compute_fragment_com_centroid(frag_coords, frag_frac_coords, frag_weight, frag_mask)[source]

Compute each fragment’s center of mass and centroid.

Parameters:
  • frag_coords (torch.Tensor, shape (B, F, Nf, 3)) – Cartesian fragment atom coordinates.

  • frag_frac_coords (torch.Tensor, shape (B, F, Nf, 3)) – Fractional fragment atom coordinates.

  • frag_weight (torch.Tensor, shape (B, F, Nf)) – Atomic weights per fragment.

  • frag_mask (torch.BoolTensor, shape (B, F, Nf)) – Validity mask for fragment atoms.

Returns:

Dictionary with keys ‘fragment_com_coords’, ‘fragment_com_frac_coords’, ‘fragment_centroid_coords’ and ‘fragment_centroid_frac_coords’, each a torch.Tensor.

Return type:

Dict[str, torch.Tensor]

Compute Fragment Centers of Mass and Centroids

Calculates both mass-weighted and geometric centers for fragments.

Parameters:
  • frag_coords (torch.Tensor) - Cartesian atomic coordinates per fragment

  • frag_frac_coords (torch.Tensor) - Fractional coordinates

  • frag_weight (torch.Tensor) - Atomic masses

  • frag_mask (torch.BoolTensor) - Valid atom flags

Center of Mass Calculation:

# Mass-weighted center: sum(m_i * r_i) / sum(m_i) over valid atoms
masked_weight = frag_weight * frag_mask                  # (B, F, Nf), 0 for padding
total_mass = masked_weight.sum(dim=-1).clamp(min=1e-12)  # avoid 0/0 for empty fragment slots
com_coords = (frag_coords * masked_weight.unsqueeze(-1)).sum(dim=-2) / total_mass.unsqueeze(-1)

Geometric Centroid:

# Unweighted geometric center over valid atoms
n_atoms = frag_mask.sum(dim=-1).clamp(min=1)             # (B, F)
centroid_coords = (frag_coords * frag_mask.unsqueeze(-1)).sum(dim=-2) / n_atoms.unsqueeze(-1)

Returns:

Dictionary with computed centers in both Cartesian and fractional coordinates

Contact Analysis Methods

StructurePostExtractionProcessor._expand_inter_contacts(inter_cc, cell_matrix)[source]

Expand all close contacts to include symmetry-equivalent images.

Parameters:
  • inter_cc (InterCCParams) – Intermolecular contact NamedTuple.

  • cell_matrix (torch.Tensor, shape (B, 3, 3)) – Real-space cell matrices.

Returns:

Expanded contact data with keys like ‘inter_cc_central_atom_coords’, ‘inter_cc_length’, etc.

Return type:

Dict[str, torch.Tensor]

Expand Intermolecular Contacts with Symmetry Operations

Applies crystallographic symmetry operations to generate the complete intermolecular contact network.

Parameters:
  • inter_cc (InterCCParams) - Raw intermolecular contacts

  • cell_matrix (torch.Tensor) - Real-space unit cell matrices, shape (B, 3, 3)

Symmetry Expansion Process:

# Symmetry operators are stored per contact in InterCCParams
rotation_matrices = inter_cc.symmetry_A    # Shape: (B, C, 3, 3)
translation_vectors = inter_cc.symmetry_T  # Shape: (B, C, 3)

# Transform contact atom fractional coordinates: x' = A x + T
contact_frac = inter_cc.contact_atom_frac_coords  # Shape: (B, C, 3)
contact_frac_transformed = torch.matmul(
    rotation_matrices,
    contact_frac.unsqueeze(-1)
).squeeze(-1) + translation_vectors

# Convert to Cartesian coordinates (rows of cell_matrix are the lattice vectors)
contact_coords_cartesian = torch.matmul(
    contact_frac_transformed,
    cell_matrix
)

Distance Recalculation:

# Compute actual intermolecular distances from the central atom
central_coords = torch.matmul(inter_cc.central_atom_frac_coords, cell_matrix)
distance_vectors = contact_coords_cartesian - central_coords
distances = torch.norm(distance_vectors, dim=-1)

# Verify contact validity (contact_cutoff_distance: distance threshold in Å)
valid_contacts = distances < contact_cutoff_distance

Contact Classification:

  • van der Waals Contacts - Within vdW radii sum + tolerance

  • Close Contacts - Shorter than vdW sum (potential strain)

  • Hydrogen Bonds - Specific geometric criteria

  • π-π Interactions - Aromatic ring stacking

  • Electrostatic Contacts - Charge-charge interactions

Returns:

Dictionary with expanded contact coordinates, distances, and classifications

StructurePostExtractionProcessor._flag_hbond_contacts(cc_central_idx, cc_contact_idx, cc_mask, hb_central_idx, hb_hydrogen_idx, hb_contact_idx, hb_mask)[source]

Flag which close contacts correspond to hydrogen bonds.

Parameters:
  • cc_central_idx (torch.LongTensor, shape (B, C)) – Central-atom indices for contacts.

  • cc_contact_idx (torch.LongTensor, shape (B, C)) – Contact-atom indices for contacts.

  • cc_mask (torch.BoolTensor, shape (B, C)) – Validity mask for contacts.

  • hb_central_idx (torch.LongTensor, shape (B, H)) – Donor-atom indices for H-bonds.

  • hb_hydrogen_idx (torch.LongTensor, shape (B, H)) – Hydrogen-atom indices for H-bonds.

  • hb_contact_idx (torch.LongTensor, shape (B, H)) – Acceptor-atom indices for H-bonds.

  • hb_mask (torch.BoolTensor, shape (B, H)) – Validity mask for H-bonds.

Returns:

Mask indicating which contacts are H-bonds.

Return type:

torch.BoolTensor, shape (B, C)

Identify Hydrogen Bond Contacts

Determines which close contacts are also hydrogen bonds by matching each contact's (central, contact) atom pair against the (donor, acceptor) pairs in the H-bond list.

Parameters:
  • cc_central_idx (torch.LongTensor) - Contact central atom indices

  • cc_contact_idx (torch.LongTensor) - Contact-atom indices

  • cc_mask (torch.BoolTensor) - Contact validity mask

  • hb_central_idx (torch.LongTensor) - H-bond donor indices

  • hb_hydrogen_idx (torch.LongTensor) - H-bond hydrogen indices

  • hb_contact_idx (torch.LongTensor) - H-bond acceptor indices

  • hb_mask (torch.BoolTensor) - H-bond validity mask

Classification Algorithm:

# Match contacts to hydrogen bonds by (donor, acceptor) atom indices
hbond_flags = torch.zeros_like(cc_mask, dtype=torch.bool)
batch_size, max_contacts = cc_mask.shape
max_hbonds = hb_mask.shape[1]

for b in range(batch_size):
    for c in range(max_contacts):
        if not cc_mask[b, c]:
            continue

        central_atom = cc_central_idx[b, c]
        contact_atom = cc_contact_idx[b, c]

        # Check if this contact matches any H-bond
        for h in range(max_hbonds):
            if not hb_mask[b, h]:
                continue

            # Match donor-acceptor pair
            if (hb_central_idx[b, h] == central_atom and
                hb_contact_idx[b, h] == contact_atom):
                hbond_flags[b, c] = True
                break

Geometric Criteria:

  • Distance Cutoff - D-A distance < 3.5 Å

  • Angular Cutoff - D-H-A angle > 120°

  • Linearity - Preference for linear arrangements

  • Chemical Validation - Appropriate donor/acceptor atoms

Returns:

torch.BoolTensor indicating which contacts are H-bonds
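The nested Python loop in the classification algorithm runs B·C·H interpreted iterations. A broadcast comparison computes the same flags in a few tensor operations, at the cost of a (B, C, H) intermediate. This is a sketch of an alternative formulation, not necessarily how _flag_hbond_contacts is implemented. Like the loop, it matches on donor and acceptor indices only, so hb_hydrogen_idx is not needed.

```python
import torch


def flag_hbond_contacts(cc_central_idx, cc_contact_idx, cc_mask,
                        hb_central_idx, hb_contact_idx, hb_mask):
    """Return a (B, C) bool mask: True where a contact's (central, contact)
    pair equals the (donor, acceptor) pair of some valid H-bond."""
    same_donor = cc_central_idx.unsqueeze(2) == hb_central_idx.unsqueeze(1)     # (B, C, H)
    same_acceptor = cc_contact_idx.unsqueeze(2) == hb_contact_idx.unsqueeze(1)  # (B, C, H)
    both_valid = cc_mask.unsqueeze(2) & hb_mask.unsqueeze(1)                    # (B, C, H)
    return (same_donor & same_acceptor & both_valid).any(dim=-1)


flags = flag_hbond_contacts(
    cc_central_idx=torch.tensor([[0, 1, 2]]),
    cc_contact_idx=torch.tensor([[5, 6, 7]]),
    cc_mask=torch.tensor([[True, True, True]]),
    hb_central_idx=torch.tensor([[1, 2]]),
    hb_contact_idx=torch.tensor([[6, 9]]),
    hb_mask=torch.tensor([[True, True]]),
)
```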

StructurePostExtractionProcessor._identify_contact_fragments(cc_central_idx, cc_contact_idx, atom_fragment_id)[source]

Map each contact’s atoms back to their rigid-fragment IDs.

Parameters:
  • cc_central_idx (torch.LongTensor, shape (B, C)) – Central-atom indices for contacts.

  • cc_contact_idx (torch.LongTensor, shape (B, C)) – Contact-atom indices for contacts.

  • atom_fragment_id (torch.LongTensor, shape (B, N)) – Fragment ID per atom.

Returns:

Dictionary with keys ‘inter_cc_central_atom_fragment_idx’ and ‘inter_cc_contact_atom_fragment_idx’, each a torch.LongTensor.

Return type:

Dict[str, torch.LongTensor]

Map Contacts to Fragment Pairs

Associates each intermolecular contact with the rigid fragments involved.

Parameters:
  • cc_central_idx (torch.LongTensor) - Central atom indices

  • cc_contact_idx (torch.LongTensor) - Contact atom indices

  • atom_fragment_id (torch.LongTensor) - Fragment ID per atom

Fragment Mapping:

# Map contact atoms to their fragments
central_fragment_ids = torch.gather(
    atom_fragment_id,
    dim=-1,
    index=cc_central_idx
)

contact_fragment_ids = torch.gather(
    atom_fragment_id,
    dim=-1,
    index=cc_contact_idx
)

# Create fragment pair identifiers
fragment_pairs = torch.stack([
    central_fragment_ids,
    contact_fragment_ids
], dim=-1)
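If padded contact slots carry placeholder indices such as -1, torch.gather raises an error, because it does not wrap negative or out-of-range indices. A hedged sketch that assumes a validity mask like cc_mask is available (the documented signature does not take one) redirects padded slots to atom 0 before gathering and then overwrites them with a fill value; the helper name is hypothetical.

```python
import torch

def contact_fragment_pairs(cc_central_idx, cc_contact_idx, cc_mask,
                           atom_fragment_id, fill=-1):
    """Return (B, C, 2) fragment pairs; padded contacts get `fill`."""
    # Redirect padded slots to atom 0 so every gather index is in range
    zero = torch.zeros_like(cc_central_idx)
    safe_central = torch.where(cc_mask, cc_central_idx, zero)
    safe_contact = torch.where(cc_mask, cc_contact_idx, zero)

    central_frag = torch.gather(atom_fragment_id, -1, safe_central)
    contact_frag = torch.gather(atom_fragment_id, -1, safe_contact)

    # Overwrite the placeholder results for padded contacts
    central_frag = central_frag.masked_fill(~cc_mask, fill)
    contact_frag = contact_frag.masked_fill(~cc_mask, fill)
    return torch.stack([central_frag, contact_frag], dim=-1)
```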

Contact Statistics:

  • Contacts per Fragment - Number of intermolecular contacts

  • Fragment Coordination - Number of neighboring fragments

  • Contact Directionality - Preferred contact directions

  • Fragment Accessibility - Solvent-accessible surface contacts
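Two of these statistics, contacts per fragment and fragment coordination, follow directly from the fragment pairs. A minimal sketch with a hypothetical helper: it counts valid contacts per central fragment and the number of distinct partner fragment IDs. Because symmetry-equivalent copies share a fragment ID, contacts between a fragment and its own images count as a single partner (itself).

```python
import torch

def fragment_contact_stats(pairs, mask, num_fragments):
    """pairs (B, C, 2) = [central_frag, contact_frag], mask (B, C) bool.

    Returns (contacts_per_fragment, coordination), each of shape (B, F).
    """
    batch = pairs.shape[0]
    central, partner = pairs[..., 0], pairs[..., 1]

    # Count valid contacts per central fragment; padded slots add 0
    counts = torch.zeros(batch, num_fragments, dtype=torch.long, device=pairs.device)
    counts.scatter_add_(1, central.clamp_min(0), mask.long())

    # Boolean fragment adjacency, symmetrized so A-B also marks B-A
    adj = torch.zeros(batch, num_fragments, num_fragments,
                      dtype=torch.bool, device=pairs.device)
    b_idx, c_idx = mask.nonzero(as_tuple=True)
    adj[b_idx, central[b_idx, c_idx], partner[b_idx, c_idx]] = True
    adj = adj | adj.transpose(1, 2)

    # Coordination = number of distinct partner fragment IDs
    coordination = adj.sum(dim=-1)
    return counts, coordination
```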

Returns:

Dictionary with fragment indices and contact-fragment mappings

Advanced Feature Computation

StructurePostExtractionProcessor._compute_contact_com_vectors(cc_coords, cc_frac_coords, cc_fragment_idx, frag_com_coords, frag_com_frac_coords, frag_structure_id, frag_local_ids)[source]

Compute vectors & distances from contact atoms to central-fragment COM.

Parameters:
  • cc_coords (torch.Tensor, shape (B, C, 3)) – Cartesian coords of contact atoms.

  • cc_frac_coords (torch.Tensor, shape (B, C, 3)) – Fractional coords of contact atoms.

  • cc_fragment_idx (torch.LongTensor, shape (B, C)) – Fragment indices for contact atoms.

  • frag_com_coords (torch.Tensor, shape (B, F, 3)) – Cartesian COM coords for each fragment.

  • frag_com_frac_coords (torch.Tensor, shape (B, F, 3)) – Fractional COM coords for each fragment.

  • frag_structure_id (torch.LongTensor, shape (B, F)) – Structure IDs for each fragment.

  • frag_local_ids (torch.LongTensor, shape (B, F)) – Local fragment indices.

Returns:

{
    'inter_cc_atom_to_central_com_vec': torch.Tensor,
    'inter_cc_atom_to_central_com_dist': torch.Tensor
}

Return type:

Dict[str, torch.Tensor]

Compute Contact-to-COM Vectors and Distances

Calculates vectors and distances from contact atoms to fragment centers of mass.

Parameters:
  • cc_coords (torch.Tensor) - Contact atom coordinates

  • cc_frac_coords (torch.Tensor) - Contact fractional coordinates

  • cc_fragment_idx (torch.LongTensor) - Contact fragment indices

  • frag_com_coords (torch.Tensor) - Fragment COM coordinates

  • frag_com_frac_coords (torch.Tensor) - Fragment COM fractional coordinates

  • frag_structure_id (torch.LongTensor) - Fragment structure IDs

  • frag_local_ids (torch.LongTensor) - Fragment local IDs

Vector Calculations:

import torch

# Map each contact to the COM of its fragment: expand (B, C) indices to (B, C, 3)
contact_com_coords = torch.gather(
    frag_com_coords,
    dim=-2,
    index=cc_fragment_idx.unsqueeze(-1).expand(-1, -1, 3)
)

# Vector from the fragment COM to the contact atom, and its length
com_vectors = cc_coords - contact_com_coords
com_distances = torch.norm(com_vectors, dim=-1)

# Normalize for directional analysis (clamp avoids division by zero)
com_unit_vectors = com_vectors / com_distances.clamp_min(1e-8).unsqueeze(-1)
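The method also receives fractional coordinates, which suggests periodic images are taken into account, whereas the snippet above works purely in Cartesian space. A minimal sketch of the common minimum-image approach, assuming a (B, 3, 3) lattice matrix whose rows are the a, b, c vectors in Å (for example built from CrystalParams); the function is illustrative, not the module's implementation.

```python
import torch

def minimum_image_vectors(cc_frac, com_frac, lattice):
    """cc_frac, com_frac: (B, C, 3) fractional; lattice: (B, 3, 3), rows = a, b, c in Å.

    Returns the COM-to-atom Cartesian vectors (B, C, 3) and their lengths (B, C).
    """
    # Fractional displacement, wrapped component-wise into [-0.5, 0.5]
    delta = cc_frac - com_frac
    delta = delta - torch.round(delta)
    # Fractional -> Cartesian: r = x*a + y*b + z*c
    vec = torch.einsum("bci,bij->bcj", delta, lattice)
    return vec, torch.linalg.norm(vec, dim=-1)
```

Rounding each fractional component independently finds the nearest image reliably for near-orthogonal cells; strongly oblique cells can require checking neighbouring images explicitly.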

Geometric Descriptors:

  • Contact Distance - Direct atom-atom distance

  • COM Distance - Distance from contact to fragment center

  • Contact Vector - Direction from COM to contact point

  • Surface Normal - Estimated surface normal at contact

  • Contact Accessibility - Geometric accessibility measure

Applications:

  • Packing Analysis - How fragments pack together

  • Surface Properties - Contact surface characterization

  • Intermolecular Forces - Direction and magnitude analysis

  • Crystal Engineering - Design principles extraction

Returns:

Dictionary with contact vectors, distances, and geometric descriptors

Data Writing and I/O Methods

Usage Examples

Basic Feature Engineering

from structure_post_extraction_processor import StructurePostExtractionProcessor
import torch
from pathlib import Path

# Initialize processor
processor = StructurePostExtractionProcessor(
    hdf5_path=Path("./analysis.h5"),
    batch_size=32,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)

# Execute feature engineering
processor.run()

GPU Memory Optimization

# Configure for large datasets with limited GPU memory
import torch

# Set memory fraction
torch.cuda.set_per_process_memory_fraction(0.7)

# Use smaller batch size
memory_optimized_processor = StructurePostExtractionProcessor(
    hdf5_path=Path("./large_dataset.h5"),
    batch_size=16,  # Reduced batch size
    device=torch.device("cuda")
)

# Monitor memory usage
def memory_monitor():
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

# Run with monitoring
memory_monitor()
memory_optimized_processor.run()
memory_monitor()
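memory_allocated() reports the current allocation, which after run() returns can be far below the high-water mark that actually triggers out-of-memory errors. A small helper (hypothetical, not part of the processor) records the peak with torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated():

```python
import contextlib

import torch

@contextlib.contextmanager
def peak_gpu_memory(report):
    """Call report(peak_gb) with the peak allocated GPU memory of the block."""
    if not torch.cuda.is_available():
        yield  # nothing to measure on CPU-only machines
        return
    torch.cuda.reset_peak_memory_stats()
    try:
        yield
    finally:
        # Report even if the block raised, e.g. on an out-of-memory error
        report(torch.cuda.max_memory_allocated() / 1024**3)
```

Usage: `with peak_gpu_memory(lambda gb: print(f"Peak: {gb:.2f}GB")): memory_optimized_processor.run()`.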

Custom Feature Selection

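# Subclass for custom feature computation.
# _write_custom_features, _compute_molecular_surface_area and
# _compute_ring_strain_energies are hypothetical helpers you would
# implement in the subclass.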
class CustomFeatureProcessor(StructurePostExtractionProcessor):

    def _process_batch(self, start, crystal, atom, bond, intra_cc, intra_hb, inter_cc, inter_hb):
        """Override to add custom features."""

        # Call parent processing
        super()._process_batch(start, crystal, atom, bond, intra_cc, intra_hb, inter_cc, inter_hb)

        # Add custom features
        custom_features = self._compute_custom_descriptors(
            atom.coords, atom.symbols, bond.atom1_idx, bond.atom2_idx
        )

        # Write custom features
        self._write_custom_features(start, custom_features)

    def _compute_custom_descriptors(self, coords, symbols, bond1, bond2):
        """Compute application-specific descriptors."""

        # Example: Compute molecular surface area
        surface_areas = self._compute_molecular_surface_area(coords, symbols)

        # Example: Compute ring strain energies
        ring_strains = self._compute_ring_strain_energies(coords, bond1, bond2)

        return {
            'surface_areas': surface_areas,
            'ring_strains': ring_strains
        }

High-Throughput Processing

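# Configure for maximum throughput
high_throughput_processor = StructurePostExtractionProcessor(
    hdf5_path=Path("./massive_dataset.h5"),
    batch_size=128,  # Large batches for throughput
    device=torch.device("cuda")
)

# Use multiple GPUs if available.
# nn.DataParallel only wraps nn.Module instances and scatters tensor inputs
# to forward(); the processor is not assumed to be a Module, so one option is
# one process per GPU, each on a separate pre-split input file (the shard
# file names below are hypothetical).
import torch.multiprocessing as mp


def run_shard(rank):
    StructurePostExtractionProcessor(
        hdf5_path=Path(f"./massive_dataset_shard_{rank}.h5"),
        batch_size=128,
        device=torch.device(f"cuda:{rank}")
    ).run()


if __name__ == "__main__" and torch.cuda.device_count() > 1:
    # mp.spawn calls run_shard(rank) for rank in range(nprocs)
    mp.spawn(run_shard, nprocs=torch.cuda.device_count())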

Quality Control and Validation

# Implement comprehensive validation
import logging

import torch


class ValidatedProcessor(StructurePostExtractionProcessor):

    def _validate_batch_results(self, computed_data):
        """Validate computed features for quality."""

        # Check floating-point outputs for NaN values
        for key, tensor in computed_data.items():
            if (torch.is_tensor(tensor) and tensor.is_floating_point()
                    and torch.isnan(tensor).any()):
                logging.warning(f"NaN values detected in {key}")

        # Check physical constraints (key names are illustrative)
        com_distances = computed_data.get('com_distances')
        if com_distances is not None and (com_distances < 0).any():
            logging.error("Negative distances detected")

        # Check that fragment IDs are contiguous from 0
        fragment_ids = computed_data.get('fragment_ids')
        if fragment_ids is not None:
            max_fragment_id = fragment_ids.max().item()
            expected_max = torch.unique(fragment_ids).numel() - 1
            if max_fragment_id != expected_max:
                logging.warning("Fragment ID inconsistency detected")

    def _process_batch(self, *args, **kwargs):
        # Process normally; this assumes the parent returns a dict of tensors
        result = super()._process_batch(*args, **kwargs)

        # Validate results when there is something to validate
        if isinstance(result, dict):
            self._validate_batch_results(result)

        return result

Performance Optimization

GPU Utilization

# Optional backend flags. cudnn.benchmark only affects cuDNN kernels
# (e.g. convolutions), so it likely has little effect on geometric ops.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False  # Allow non-deterministic for speed

# Use mixed precision for memory efficiency. GradScaler is only needed when
# training with backward passes, so it is omitted here.
class MixedPrecisionProcessor(StructurePostExtractionProcessor):

    def _process_batch(self, *args, **kwargs):
        # Matmul-type ops run in float16 inside this region; check that the
        # precision loss is acceptable for distance and angle cutoffs
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            return super()._process_batch(*args, **kwargs)
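Inside an autocast region, matmul-style ops (for example a fractional-to-Cartesian conversion) run in float16, which has only 10 mantissa bits. A quick sketch of what that means for coordinates of order 100 Å; the helper is illustrative:

```python
import torch

def half_precision_error(value):
    """Absolute error after rounding a float32 value to float16 and back."""
    x = torch.tensor(value, dtype=torch.float32)
    return (x - x.to(torch.float16).to(torch.float32)).abs().item()

# Values in [64, 128) land on a float16 grid with 0.0625 spacing
print(half_precision_error(123.456))  # roughly 0.0185
print(half_precision_error(1.5))      # 0.0 (exactly representable)
```

An error of about 0.02 Å is small in absolute terms but can flip contacts that sit close to a 3.5 Å cutoff, so it may be safer to keep coordinate arithmetic in float32.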

Memory Management

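# Gradient checkpointing (for reference). torch.utils.checkpoint saves memory
# only when autograd records a graph for a later backward() call; for
# inference-only feature extraction it gives no memory benefit.
import torch.utils.checkpoint as checkpoint

def memory_efficient_feature_computation(coords, weights):
    """Wrap a memory-intensive computation in activation checkpointing."""

    def compute_inertia_checkpoint(coords_chunk, weights_chunk):
        # compute_inertia_tensor is not defined here; it is assumed to be a
        # helper such as one from geometry_utils
        return compute_inertia_tensor(coords_chunk, weights_chunk)

    # Recompute the forward pass during backward instead of storing activations
    inertia_tensors = checkpoint.checkpoint(
        compute_inertia_checkpoint,
        coords,
        weights,
        use_reentrant=False
    )

    return inertia_tensors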
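Because feature extraction never calls backward(), the practical memory levers are disabling autograd and bounding the working set. A minimal sketch that computes mass-weighted inertia tensors I = sum_i m_i (|r_i|^2 E - r_i r_i^T) about each structure's COM in chunks under torch.inference_mode(); it is a stand-in for, not a copy of, the module's inertia computation, and assumes padded atoms are flagged by a boolean mask.

```python
import torch

def inertia_tensors_chunked(coords, weights, mask, chunk_size=64):
    """coords (B, N, 3), weights (B, N), mask (B, N) bool -> (B, 3, 3)."""
    eye = torch.eye(3, dtype=coords.dtype, device=coords.device)
    results = []
    with torch.inference_mode():  # no autograd graph is recorded
        for start in range(0, coords.shape[0], chunk_size):
            x = coords[start:start + chunk_size]
            # Zero the weight of padded atoms so they drop out of every sum
            m = weights[start:start + chunk_size] * mask[start:start + chunk_size]
            total = m.sum(dim=1).clamp_min(1e-12)
            com = (m.unsqueeze(-1) * x).sum(dim=1) / total.unsqueeze(-1)
            r = x - com.unsqueeze(1)
            r2 = (r * r).sum(dim=-1)
            # Per-atom term |r|^2 E - r r^T, weighted by mass and summed over atoms
            per_atom = r2[..., None, None] * eye - r.unsqueeze(-1) * r.unsqueeze(-2)
            results.append((m[..., None, None] * per_atom).sum(dim=1))
    return torch.cat(results, dim=0)
```

Peak memory then scales with chunk_size rather than with the full batch, at the cost of more kernel launches.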

Batch Size Optimization

def find_optimal_batch_size(processor, max_batch_size=256):
    """Find the largest batch size that fits in GPU memory via binary search.

    Assumes memory use grows monotonically with batch size. The
    _create_test_batch helper is hypothetical and must be provided.
    """

    low, high = 1, max_batch_size
    optimal_size = 1

    while low <= high:
        mid = (low + high) // 2

        try:
            # Test with current batch size
            processor.batch_size = mid
            test_batch = processor._create_test_batch()
            processor._process_batch(*test_batch)

            # Success - try larger
            optimal_size = mid
            low = mid + 1

        except RuntimeError as e:
            if "out of memory" in str(e):
                # OOM - try smaller
                high = mid - 1
                torch.cuda.empty_cache()
            else:
                raise

    # Leave the processor on the size that succeeded, not the last one tried
    processor.batch_size = optimal_size
    return optimal_size
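The search can also be decoupled from processor internals so it is testable without a GPU. A minimal sketch with hypothetical helper names: largest_feasible runs the binary search over any monotone predicate, and oom_probe turns a batch-running callable into such a predicate by catching torch.cuda.OutOfMemoryError (available in PyTorch 1.13 and later).

```python
import torch

def largest_feasible(try_size, max_size=256):
    """Largest n in [1, max_size] with try_size(n) True, or 0 if none.

    Assumes feasibility is monotone: if n fits, every smaller n fits.
    """
    low, high, best = 1, max_size, 0
    while low <= high:
        mid = (low + high) // 2
        if try_size(mid):
            best, low = mid, mid + 1
        else:
            high = mid - 1
    return best


def oom_probe(run_batch):
    """Wrap run_batch(n) so that a CUDA out-of-memory error returns False."""
    def try_size(n):
        try:
            run_batch(n)
            return True
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            return False
    return try_size
```

Here run_batch would build and process a synthetic batch of size n; how to build one depends on the processor's internals.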

See Also

crystal_analyzer module : Pipeline orchestration

structure_data_extractor module : Raw data extraction

fragment_utils module : Fragment analysis utilities

geometry_utils module : Geometric computation utilities

data_writer module : HDF5 data writing utilities