# Source code for data_reader

"""
Module: data_reader.py

Provides RawDataReader for extracting raw NumPy arrays from an input HDF5 file
generated by the StructureDataExtractor. Supports reading crystal parameters,
atom data, bond data, and both inter- and intramolecular contacts/H-bonds,
with zero-padding to fixed maximum dimensions.

Dependencies
------------
h5py
numpy
"""
import h5py
import numpy as np
from typing import List, Dict, Any

class RawDataReader:
    """
    Read and pad raw per-structure data from an HDF5 file for batch processing.

    Each structure is stored in its own group ``/structures/<refcode>``.
    Variable-length per-structure tables (atoms, bonds, contacts, H-bonds) are
    copied into zero-padded NumPy arrays of a fixed maximum length. String
    columns are returned as ragged Python lists of ``str``.

    Attributes
    ----------
    h5_in : h5py.File
        Open HDF5 file containing raw structure datasets under
        '/structures/<refcode>'.

    Methods
    -------
    read_crystal_parameters(batch: List[str]) -> Dict[str, np.ndarray]
        Read unit-cell lengths, angles, and scalar crystal properties.
    read_atoms(batch: List[str], N_max: int) -> Dict[str, Any]
        Read and pad per-atom labels, symbols, coordinates, weights, charges,
        SYBYL types, neighbour lists, and mask.
    read_bonds(batch: List[str], max_bonds: int) -> Dict[str, Any]
        Read and pad per-bond endpoint indices, bond-type strings, rotatability
        flags, cyclic flags, bond lengths, and mask.
    read_intermolecular_contacts(batch: List[str], max_contacts: int) -> Dict[str, Any]
        Read and pad raw intermolecular contact labels, symmetry ops,
        Cartesian/fractional coords, lengths, strengths, and in-LOS flags.
    read_intermolecular_hbonds(batch: List[str], max_hbonds: int) -> Dict[str, Any]
        Read and pad raw intermolecular H-bond labels, symmetry ops,
        Cartesian/fractional coords, lengths, angles, and in-LOS flags.
    read_intramolecular_contacts(batch: List[str], max_contacts: int) -> Dict[str, Any]
        Read and pad intramolecular contact data (no symmetry ops).
    read_intramolecular_hbonds(batch: List[str], max_hbonds: int) -> Dict[str, Any]
        Read and pad intramolecular H-bond data (no symmetry ops).
    """

    # Field specifications consumed by ``_read_padded``. Each entry is
    # ``(dataset_name, kind)`` or ``(dataset_name, kind, output_key)``, where
    # ``kind`` is one of:
    #   * ``str``  -> per-structure list of decoded strings,
    #   * ``list`` -> per-structure raw ``tolist()`` of the dataset,
    #   * ``(dtype, trailing_shape)`` -> zero-padded array of shape
    #     ``(B, max_n) + trailing_shape``.
    # Entry order defines the key order of the returned dict.
    _ATOM_FIELDS = (
        ("atom_label", str),
        ("atom_symbol", str),
        ("atom_number", (np.int32, ())),
        ("atom_coords", (np.float32, (3,))),
        ("atom_frac_coords", (np.float32, (3,))),
        ("atom_weight", (np.float32, ())),
        ("atom_charge", (np.float32, ())),
        ("atom_sybyl_type", str),
        ("atom_neighbour_list", str),
    )

    _BOND_FIELDS = (
        ("bond_id", str),
        ("bond_atom1", str),
        ("bond_atom2", str),
        ("bond_atom1_idx", (np.int32, ())),
        ("bond_atom2_idx", (np.int32, ())),
        ("bond_type", str),
        # Exposed under a different key than the dataset name.
        ("bond_is_rotatable", list, "bond_is_rotatable_raw"),
        ("bond_is_cyclic", (bool, ())),
        ("bond_length", (np.float32, ())),
    )

    def __init__(self, h5_in: "h5py.File"):
        """
        Initialize RawDataReader.

        Parameters
        ----------
        h5_in : h5py.File
            Open HDF5 file handle containing '/structures' groups. Only item
            access (``h5_in['structures'][ref][name][()]``) and dataset
            ``.shape`` are used, so any object with the same nested-mapping
            layout works as well.
        """
        self.h5_in = h5_in

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_str(value: Any) -> str:
        """
        Convert a single HDF5 string value to ``str``.

        Depending on how a string dataset was written and on the h5py
        version, values come back either as ``bytes`` or as ``str``. Both are
        accepted; ``bytes`` are decoded as UTF-8.
        """
        if isinstance(value, bytes):  # also matches np.bytes_
            return value.decode("utf-8")
        return str(value)

    @classmethod
    def _to_str_list(cls, values: Any) -> List[str]:
        """
        Convert a 1-D array of HDF5 strings to a list of ``str``.

        ``ndarray.astype(str)`` is deliberately avoided. On an object array of
        ``bytes`` (how h5py returns variable-length strings) it produces the
        repr, e.g. ``"b'C1'"``, rather than the decoded text.
        """
        return [cls._to_str(v) for v in np.asarray(values).tolist()]

    @staticmethod
    def _interaction_fields(
        prefix: str,
        roles: tuple,
        scalars: tuple,
        with_symmetry: bool,
    ) -> tuple:
        """
        Build the ``_read_padded`` field spec for a contact/H-bond table.

        Parameters
        ----------
        prefix : str
            Dataset/key prefix, e.g. ``'inter_cc'`` or ``'intra_hb'``.
        roles : tuple of str
            Atom roles, e.g. ``('central', 'contact')``. Each role gets a
            label, an index, Cartesian coords, and fractional coords.
        scalars : tuple of str
            Per-row float32 measurements, e.g. ``('length', 'strength')``.
        with_symmetry : bool
            Whether the table carries a ``<prefix>_symmetry`` string column.

        Returns
        -------
        tuple
            Field spec in the documented output key order.
        """
        fields = [(f"{prefix}_id", str)]
        fields += [(f"{prefix}_{r}_atom", str) for r in roles]
        fields += [(f"{prefix}_{r}_atom_idx", (np.int32, ())) for r in roles]
        if with_symmetry:
            fields.append((f"{prefix}_symmetry", str))
        fields += [(f"{prefix}_{r}_atom_coords", (np.float32, (3,))) for r in roles]
        fields += [
            (f"{prefix}_{r}_atom_frac_coords", (np.float32, (3,))) for r in roles
        ]
        fields += [(f"{prefix}_{s}", (np.float32, ())) for s in scalars]
        fields.append((f"{prefix}_in_los", (bool, ())))
        return tuple(fields)

    def _read_padded(
        self,
        batch: List[str],
        max_n: int,
        fields: tuple,
        count_key: str,
        count_source: str,
        mask_key: str = "",
    ) -> Dict[str, Any]:
        """
        Read a variable-length per-structure table and zero-pad it to ``max_n``.

        Parameters
        ----------
        batch : List[str]
            Refcodes to read, in output order.
        max_n : int
            Padding length of the second axis of every array output.
        fields : tuple
            Field specification (format described above ``_ATOM_FIELDS``).
        count_key : str
            Output key for the per-structure row counts (``List[int]``).
        count_source : str
            Dataset whose leading dimension defines each structure's row count.
        mask_key : str, optional
            If non-empty, also return a boolean ``(B, max_n)`` validity mask
            under this key (placed last).

        Returns
        -------
        dict
            ``{count_key: counts, <field keys in spec order>, [mask_key: mask]}``.

        Raises
        ------
        ValueError
            If any structure has more than ``max_n`` rows.
        """
        n_batch = len(batch)
        structures = self.h5_in["structures"]
        # Normalise to (dataset_name, kind, output_key).
        spec = [(f[0], f[1], f[2] if len(f) > 2 else f[0]) for f in fields]

        counts: List[int] = []
        out: Dict[str, Any] = {count_key: counts}
        for _, kind, key in spec:
            if kind is str or kind is list:
                out[key] = []
            else:
                dtype, trailing = kind
                out[key] = np.zeros((n_batch, max_n) + tuple(trailing), dtype=dtype)
        mask = np.zeros((n_batch, max_n), dtype=bool)

        for i, ref in enumerate(batch):
            grp = structures[ref]
            n = int(grp[count_source].shape[0])
            # Fail with a clear message instead of an opaque broadcast error.
            if n > max_n:
                raise ValueError(
                    f"Structure {ref!r} has {n} rows in {count_source!r}, "
                    f"exceeding the padding limit of {max_n}."
                )
            counts.append(n)
            for name, kind, key in spec:
                data = grp[name][()]
                if kind is str:
                    out[key].append(self._to_str_list(data))
                elif kind is list:
                    out[key].append(np.asarray(data).tolist())
                else:
                    out[key][i, :n] = data
            mask[i, :n] = True

        if mask_key:
            out[mask_key] = mask
        return out

    # ------------------------------------------------------------------
    # Public readers
    # ------------------------------------------------------------------

    def read_crystal_parameters(self, batch: List[str]) -> Dict[str, np.ndarray]:
        """
        Read unit-cell lengths, angles, and scalar crystal metrics for a batch.

        Parameters
        ----------
        batch : List[str]
            Refcode strings for structures to read.

        Returns
        -------
        result : dict
            {
              'cell_lengths': np.ndarray, shape (B,3), dtype float32,
              'cell_angles': np.ndarray, shape (B,3), dtype float32,
              'z_value': np.ndarray, shape (B,), dtype float32,
              'z_prime': np.ndarray, shape (B,), dtype float32,
              'cell_volume': np.ndarray, shape (B,), dtype float32,
              'cell_density': np.ndarray, shape (B,), dtype float32,
              'packing_coefficient': np.ndarray, shape (B,), dtype float32,
              'identifier': np.ndarray, shape (B,), dtype object (str),
              'space_group': np.ndarray, shape (B,), dtype object (str)
            }
        """
        n_batch = len(batch)
        structures = self.h5_in["structures"]
        # Output keys equal the dataset names; insertion order is the
        # documented key order.
        out: Dict[str, np.ndarray] = {
            "cell_lengths": np.zeros((n_batch, 3), dtype=np.float32),
            "cell_angles": np.zeros((n_batch, 3), dtype=np.float32),
            "z_value": np.zeros(n_batch, dtype=np.float32),
            "z_prime": np.zeros(n_batch, dtype=np.float32),
            "cell_volume": np.zeros(n_batch, dtype=np.float32),
            "cell_density": np.zeros(n_batch, dtype=np.float32),
            "packing_coefficient": np.zeros(n_batch, dtype=np.float32),
            "identifier": np.empty(n_batch, dtype=object),
            "space_group": np.empty(n_batch, dtype=object),
        }
        for i, ref in enumerate(batch):
            grp = structures[ref]
            for key, arr in out.items():
                value = grp[key][()]
                # Object columns hold strings, which may arrive as bytes or str.
                arr[i] = self._to_str(value) if arr.dtype == object else value
        return out

    def read_atoms(self, batch: List[str], N_max: int) -> Dict[str, Any]:
        """
        Read and pad per-atom data for a batch of structures.

        Parameters
        ----------
        batch : List[str]
            Refcode strings to read from.
        N_max : int
            Maximum atom count for padding.

        Returns
        -------
        result : dict
            {
              'n_atoms': List[int],
              'atom_label': List[List[str]],
              'atom_symbol': List[List[str]],
              'atom_number': np.ndarray, shape (B, N_max), dtype int32,
              'atom_coords': np.ndarray, shape (B, N_max, 3), dtype float32,
              'atom_frac_coords': np.ndarray, shape (B, N_max, 3), dtype float32,
              'atom_weight': np.ndarray, shape (B, N_max), dtype float32,
              'atom_charge': np.ndarray, shape (B, N_max), dtype float32,
              'atom_sybyl_type': List[List[str]],
              'atom_neighbour_list': List[List[str]],
              'atom_mask': np.ndarray, shape (B, N_max), dtype bool
            }

        Raises
        ------
        ValueError
            If a structure has more than ``N_max`` atoms.
        """
        return self._read_padded(
            batch,
            N_max,
            self._ATOM_FIELDS,
            count_key="n_atoms",
            count_source="atom_label",
            mask_key="atom_mask",
        )

    def read_bonds(self, batch: List[str], max_bonds: int) -> Dict[str, Any]:
        """
        Read and pad per-bond data for a batch of structures.

        Parameters
        ----------
        batch : List[str]
            Refcode strings to read bonds from.
        max_bonds : int
            Maximum bond count for padding.

        Returns
        -------
        result : dict
            {
              'n_bonds': List[int],
              'bond_id': List[List[str]],
              'bond_atom1': List[List[str]],
              'bond_atom2': List[List[str]],
              'bond_atom1_idx': np.ndarray, shape (B, max_bonds), dtype int32,
              'bond_atom2_idx': np.ndarray, shape (B, max_bonds), dtype int32,
              'bond_type': List[List[str]],
              'bond_is_rotatable_raw': List[List[bool]],
              'bond_is_cyclic': np.ndarray, shape (B, max_bonds), dtype bool,
              'bond_length': np.ndarray, shape (B, max_bonds), dtype float32,
              'bond_mask': np.ndarray, shape (B, max_bonds), dtype bool
            }

        Raises
        ------
        ValueError
            If a structure has more than ``max_bonds`` bonds.
        """
        return self._read_padded(
            batch,
            max_bonds,
            self._BOND_FIELDS,
            count_key="n_bonds",
            count_source="bond_atom1_idx",
            mask_key="bond_mask",
        )

    def read_intermolecular_contacts(
        self, batch: List[str], max_contacts: int
    ) -> Dict[str, Any]:
        """
        Read and pad raw intermolecular contact data.

        Parameters
        ----------
        batch : List[str]
            Refcode strings to read contacts from.
        max_contacts : int
            Maximum contact count for padding.

        Returns
        -------
        result : dict
            {
              'n_inter_cc': List[int],
              'inter_cc_id': List[List[str]],
              'inter_cc_central_atom': List[List[str]],
              'inter_cc_contact_atom': List[List[str]],
              'inter_cc_central_atom_idx': np.ndarray, shape (B, max_contacts), dtype int32,
              'inter_cc_contact_atom_idx': np.ndarray, shape (B, max_contacts), dtype int32,
              'inter_cc_symmetry': List[List[str]],
              'inter_cc_central_atom_coords': np.ndarray, shape (B, max_contacts, 3), dtype float32,
              'inter_cc_contact_atom_coords': np.ndarray, shape (B, max_contacts, 3), dtype float32,
              'inter_cc_central_atom_frac_coords': np.ndarray, shape (B, max_contacts, 3), dtype float32,
              'inter_cc_contact_atom_frac_coords': np.ndarray, shape (B, max_contacts, 3), dtype float32,
              'inter_cc_length': np.ndarray, shape (B, max_contacts), dtype float32,
              'inter_cc_strength': np.ndarray, shape (B, max_contacts), dtype float32,
              'inter_cc_in_los': np.ndarray, shape (B, max_contacts), dtype bool
            }

        Raises
        ------
        ValueError
            If a structure has more than ``max_contacts`` contacts.
        """
        fields = self._interaction_fields(
            "inter_cc", ("central", "contact"), ("length", "strength"), True
        )
        return self._read_padded(
            batch,
            max_contacts,
            fields,
            count_key="n_inter_cc",
            count_source="inter_cc_id",
        )

    def read_intermolecular_hbonds(
        self, batch: List[str], max_hbonds: int
    ) -> Dict[str, Any]:
        """
        Read and pad raw intermolecular H-bond data.

        Parameters
        ----------
        batch : List[str]
            Refcode strings to read H-bonds from.
        max_hbonds : int
            Maximum H-bond count for padding.

        Returns
        -------
        result : dict
            {
              'n_inter_hb': List[int],
              'inter_hb_id': List[List[str]],
              'inter_hb_central_atom': List[List[str]],
              'inter_hb_hydrogen_atom': List[List[str]],
              'inter_hb_contact_atom': List[List[str]],
              'inter_hb_central_atom_idx': np.ndarray, shape (B, max_hbonds), dtype int32,
              'inter_hb_hydrogen_atom_idx': np.ndarray, shape (B, max_hbonds), dtype int32,
              'inter_hb_contact_atom_idx': np.ndarray, shape (B, max_hbonds), dtype int32,
              'inter_hb_symmetry': List[List[str]],
              'inter_hb_central_atom_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'inter_hb_hydrogen_atom_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'inter_hb_contact_atom_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'inter_hb_central_atom_frac_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'inter_hb_hydrogen_atom_frac_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'inter_hb_contact_atom_frac_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'inter_hb_length': np.ndarray, shape (B, max_hbonds), dtype float32,
              'inter_hb_angle': np.ndarray, shape (B, max_hbonds), dtype float32,
              'inter_hb_in_los': np.ndarray, shape (B, max_hbonds), dtype bool
            }

        Raises
        ------
        ValueError
            If a structure has more than ``max_hbonds`` H-bonds.
        """
        fields = self._interaction_fields(
            "inter_hb",
            ("central", "hydrogen", "contact"),
            ("length", "angle"),
            True,
        )
        return self._read_padded(
            batch,
            max_hbonds,
            fields,
            count_key="n_inter_hb",
            count_source="inter_hb_central_atom",
        )

    def read_intramolecular_contacts(
        self, batch: List[str], max_contacts: int
    ) -> Dict[str, Any]:
        """
        Read and pad raw intramolecular contact data.

        Parameters
        ----------
        batch : List[str]
            Structure identifiers to read intra-contacts from.
        max_contacts : int
            Maximum intra-contact count for padding.

        Returns
        -------
        result : dict
            {
              'n_intra_cc': List[int],
              'intra_cc_id': List[List[str]],
              'intra_cc_central_atom': List[List[str]],
              'intra_cc_contact_atom': List[List[str]],
              'intra_cc_central_atom_idx': np.ndarray, shape (B, max_contacts), dtype int32,
              'intra_cc_contact_atom_idx': np.ndarray, shape (B, max_contacts), dtype int32,
              'intra_cc_central_atom_coords': np.ndarray, shape (B, max_contacts, 3), dtype float32,
              'intra_cc_contact_atom_coords': np.ndarray, shape (B, max_contacts, 3), dtype float32,
              'intra_cc_central_atom_frac_coords': np.ndarray, shape (B, max_contacts, 3), dtype float32,
              'intra_cc_contact_atom_frac_coords': np.ndarray, shape (B, max_contacts, 3), dtype float32,
              'intra_cc_length': np.ndarray, shape (B, max_contacts), dtype float32,
              'intra_cc_strength': np.ndarray, shape (B, max_contacts), dtype float32,
              'intra_cc_in_los': np.ndarray, shape (B, max_contacts), dtype bool
            }

        Raises
        ------
        ValueError
            If a structure has more than ``max_contacts`` intra-contacts.
        """
        fields = self._interaction_fields(
            "intra_cc", ("central", "contact"), ("length", "strength"), False
        )
        return self._read_padded(
            batch,
            max_contacts,
            fields,
            count_key="n_intra_cc",
            count_source="intra_cc_central_atom",
        )

    def read_intramolecular_hbonds(
        self, batch: List[str], max_hbonds: int
    ) -> Dict[str, Any]:
        """
        Read and pad raw intramolecular H-bond data.

        Parameters
        ----------
        batch : List[str]
            Structure identifiers to read intra-H-bonds from.
        max_hbonds : int
            Maximum intra-H-bond count for padding.

        Returns
        -------
        result : dict
            {
              'n_intra_hb': List[int],
              'intra_hb_id': List[List[str]],
              'intra_hb_central_atom': List[List[str]],
              'intra_hb_hydrogen_atom': List[List[str]],
              'intra_hb_contact_atom': List[List[str]],
              'intra_hb_central_atom_idx': np.ndarray, shape (B, max_hbonds), dtype int32,
              'intra_hb_hydrogen_atom_idx': np.ndarray, shape (B, max_hbonds), dtype int32,
              'intra_hb_contact_atom_idx': np.ndarray, shape (B, max_hbonds), dtype int32,
              'intra_hb_central_atom_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'intra_hb_hydrogen_atom_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'intra_hb_contact_atom_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'intra_hb_central_atom_frac_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'intra_hb_hydrogen_atom_frac_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'intra_hb_contact_atom_frac_coords': np.ndarray, shape (B, max_hbonds, 3), dtype float32,
              'intra_hb_length': np.ndarray, shape (B, max_hbonds), dtype float32,
              'intra_hb_angle': np.ndarray, shape (B, max_hbonds), dtype float32,
              'intra_hb_in_los': np.ndarray, shape (B, max_hbonds), dtype bool
            }

        Raises
        ------
        ValueError
            If a structure has more than ``max_hbonds`` intra-H-bonds.
        """
        fields = self._interaction_fields(
            "intra_hb",
            ("central", "hydrogen", "contact"),
            ("length", "angle"),
            False,
        )
        return self._read_padded(
            batch,
            max_hbonds,
            fields,
            count_key="n_intra_hb",
            count_source="intra_hb_central_atom",
        )