# Source code for contact_utils

"""
Module: contact_utils.py

Batch utilities for computing and expanding intermolecular contacts and
hydrogen bonds using symmetry operations and mapping to fragment-level
descriptors.

Dependencies
------------
torch
"""
from typing import List, Tuple
import torch

def compute_symmetric_contacts_batch(
    central_atom_label: List[List[str]],
    contact_atom_label: List[List[str]],
    central_atom_idx: torch.Tensor,
    contact_atom_idx: torch.Tensor,
    central_atom_frac_coords: torch.Tensor,
    contact_atom_frac_coords: torch.Tensor,
    lengths: torch.Tensor,
    strengths: torch.Tensor,
    in_los: torch.Tensor,
    symmetry_A: torch.Tensor,
    symmetry_T: torch.Tensor,
    symmetry_A_inv: torch.Tensor,
    symmetry_T_inv: torch.Tensor,
    cell_matrix: torch.Tensor,
    device: torch.device,
) -> dict:
    """
    Expand intermolecular contacts by applying precomputed symmetry operations.

    Every contact (central -> contact) is complemented by its swapped
    counterpart (contact -> central). The swapped fractional coordinates are
    obtained by applying the inverse symmetry operation
    (``symmetry_A_inv @ x + symmetry_T_inv``) to the partner atom. A contact
    is treated as valid when its length is > 0, and valid contacts are
    assumed to occupy the leading slots of each row. For a structure with
    ``n`` valid contacts the outputs hold the originals in slots ``[0, n)``,
    the swapped copies in ``[n, 2n)`` and zeros elsewhere.

    Parameters
    ----------
    central_atom_label : List[List[str]]
        Original central-atom labels for each structure.
    contact_atom_label : List[List[str]]
        Original contact-atom labels for each structure.
    central_atom_idx : torch.LongTensor, shape (B, C)
        Central-atom indices per contact.
    contact_atom_idx : torch.LongTensor, shape (B, C)
        Contact-atom indices per contact.
    central_atom_frac_coords : torch.Tensor, shape (B, C, 3)
        Fractional coordinates of central atoms. Its dtype is the working
        dtype that all other floating-point inputs are cast to.
    contact_atom_frac_coords : torch.Tensor, shape (B, C, 3)
        Fractional coordinates of contact atoms.
    lengths : torch.Tensor, shape (B, C)
        Contact distances; values <= 0 mark padding.
    strengths : torch.Tensor, shape (B, C)
        Contact strength metrics.
    in_los : torch.Tensor, shape (B, C)
        Line-of-sight contact mask (cast to the working dtype).
    symmetry_A : torch.Tensor, shape (B, C, 3, 3)
        Symmetry rotation matrices.
    symmetry_T : torch.Tensor, shape (B, C, 3)
        Symmetry translation vectors.
    symmetry_A_inv : torch.Tensor, shape (B, C, 3, 3)
        Inverse symmetry rotation matrices.
    symmetry_T_inv : torch.Tensor, shape (B, C, 3)
        Inverse symmetry translation vectors.
    cell_matrix : torch.Tensor, shape (B, 3, 3)
        Real-space cell matrices for each structure; Cartesian coordinates
        are computed as ``frac @ cell_matrix``.
    device : torch.device
        Device on which to perform computations.

    Returns
    -------
    dict
        Dictionary of extended contact parameters:
        - inter_cc_central_atom : List[List[str]]
        - inter_cc_contact_atom : List[List[str]]
        - inter_cc_central_atom_idx : torch.LongTensor, shape (B, 2C)
        - inter_cc_contact_atom_idx : torch.LongTensor, shape (B, 2C)
        - inter_cc_central_atom_coords : torch.Tensor, shape (B, 2C, 3)
        - inter_cc_contact_atom_coords : torch.Tensor, shape (B, 2C, 3)
        - inter_cc_central_atom_frac_coords : torch.Tensor, shape (B, 2C, 3)
        - inter_cc_contact_atom_frac_coords : torch.Tensor, shape (B, 2C, 3)
        - inter_cc_length : torch.Tensor, shape (B, 2C)
        - inter_cc_strength : torch.Tensor, shape (B, 2C)
        - inter_cc_in_los : torch.Tensor, shape (B, 2C)
        - inter_cc_symmetry_A : torch.Tensor, shape (B, 2C, 3, 3)
        - inter_cc_symmetry_T : torch.Tensor, shape (B, 2C, 3)
        - inter_cc_symmetry_A_inv : torch.Tensor, shape (B, 2C, 3, 3)
        - inter_cc_symmetry_T_inv : torch.Tensor, shape (B, 2C, 3)
        - inter_cc_mask : torch.BoolTensor, shape (B, 2C)

    Notes
    -----
    The symmetry tensors are packed with the same slot layout as the
    coordinates, metrics and indices, so slot ``k`` of every output describes
    the same (possibly swapped) contact. Swapped contacts carry the inverse
    operation as their forward operation and vice versa.
    """
    # Fractional coordinates define the working device and dtype.
    central_atom_frac_coords = central_atom_frac_coords.to(device=device)
    dtype = central_atom_frac_coords.dtype
    contact_atom_frac_coords = contact_atom_frac_coords.to(device=device, dtype=dtype)

    # Indices keep their integer dtype; everything else is cast to `dtype`.
    central_atom_idx = central_atom_idx.to(device=device)
    contact_atom_idx = contact_atom_idx.to(device=device)
    lengths = lengths.to(device=device, dtype=dtype)
    strengths = strengths.to(device=device, dtype=dtype)
    in_los = in_los.to(device=device, dtype=dtype)
    symmetry_A = symmetry_A.to(device=device, dtype=dtype)
    symmetry_T = symmetry_T.to(device=device, dtype=dtype)
    symmetry_A_inv = symmetry_A_inv.to(device=device, dtype=dtype)
    symmetry_T_inv = symmetry_T_inv.to(device=device, dtype=dtype)
    cell_matrix = cell_matrix.to(device=device, dtype=dtype)

    B, C, _ = central_atom_frac_coords.shape

    # 1) Swapped contacts: the old contact atom becomes the new central atom
    #    (and vice versa), both mapped through the inverse symmetry operation.
    central_frac_rev = (
        torch.einsum('bcij,bcj->bci', symmetry_A_inv, contact_atom_frac_coords)
        + symmetry_T_inv
    )
    contact_frac_rev = (
        torch.einsum('bcij,bcj->bci', symmetry_A_inv, central_atom_frac_coords)
        + symmetry_T_inv
    )

    # 2) Unpacked layout: originals in slots [0, C), swapped copies in [C, 2C).
    unpacked = {
        "central_frac": torch.cat([central_atom_frac_coords, central_frac_rev], dim=1),
        "contact_frac": torch.cat([contact_atom_frac_coords, contact_frac_rev], dim=1),
        "lengths": torch.cat([lengths, lengths], dim=1),
        "strengths": torch.cat([strengths, strengths], dim=1),
        "in_los": torch.cat([in_los, in_los], dim=1),
        "central_idx": torch.cat([central_atom_idx, contact_atom_idx], dim=1),
        "contact_idx": torch.cat([contact_atom_idx, central_atom_idx], dim=1),
        # For a swapped contact the roles of the forward/inverse ops swap too.
        "sym_A": torch.cat([symmetry_A, symmetry_A_inv], dim=1),
        "sym_T": torch.cat([symmetry_T, symmetry_T_inv], dim=1),
        "sym_A_inv": torch.cat([symmetry_A_inv, symmetry_A], dim=1),
        "sym_T_inv": torch.cat([symmetry_T_inv, symmetry_T], dim=1),
    }

    # 3) Cartesian coordinates from fractional ones.
    unpacked["central_cart"] = torch.matmul(unpacked["central_frac"], cell_matrix)
    unpacked["contact_cart"] = torch.matmul(unpacked["contact_frac"], cell_matrix)

    # 4) Pack valid entries to the front: [originals | swapped | zero padding].
    #    The per-structure counts are fetched once to avoid a device sync per row.
    packed = {key: torch.zeros_like(value) for key, value in unpacked.items()}
    valid_counts = (lengths > 0).sum(dim=1).tolist()
    for b, n_valid in enumerate(valid_counts):
        if n_valid == 0:
            continue
        for key, source in unpacked.items():
            target = packed[key]
            target[b, :n_valid] = source[b, :n_valid]
            target[b, n_valid:2 * n_valid] = source[b, C:C + n_valid]

    # 5) Extend labels (original labels followed by the swapped-role labels).
    # NOTE(review): this lines up with the packed tensors only if each label
    # list holds just the valid contacts of its structure - confirm with caller.
    central_atom_labels_ext = [
        orig + rev for orig, rev in zip(central_atom_label, contact_atom_label)
    ]
    contact_atom_labels_ext = [
        orig + rev for orig, rev in zip(contact_atom_label, central_atom_label)
    ]

    # 6) Validity mask of the packed layout.
    inter_cc_mask = packed["lengths"] > 0

    return {
        "inter_cc_central_atom": central_atom_labels_ext,
        "inter_cc_contact_atom": contact_atom_labels_ext,
        "inter_cc_central_atom_idx": packed["central_idx"],
        "inter_cc_contact_atom_idx": packed["contact_idx"],
        "inter_cc_central_atom_coords": packed["central_cart"],
        "inter_cc_contact_atom_coords": packed["contact_cart"],
        "inter_cc_central_atom_frac_coords": packed["central_frac"],
        "inter_cc_contact_atom_frac_coords": packed["contact_frac"],
        "inter_cc_length": packed["lengths"],
        "inter_cc_strength": packed["strengths"],
        "inter_cc_in_los": packed["in_los"],
        "inter_cc_symmetry_A": packed["sym_A"],
        "inter_cc_symmetry_T": packed["sym_T"],
        "inter_cc_symmetry_A_inv": packed["sym_A_inv"],
        "inter_cc_symmetry_T_inv": packed["sym_T_inv"],
        "inter_cc_mask": inter_cc_mask,
    }
def compute_symmetric_hbonds_batch(
    central_atom_label: List[List[str]],
    hydrogen_atom_label: List[List[str]],
    contact_atom_label: List[List[str]],
    central_atom_idx: torch.Tensor,
    hydrogen_atom_idx: torch.Tensor,
    contact_atom_idx: torch.Tensor,
    central_atom_frac_coords: torch.Tensor,
    hydrogen_atom_frac_coords: torch.Tensor,
    contact_atom_frac_coords: torch.Tensor,
    lengths: torch.Tensor,
    angles: torch.Tensor,
    in_los: torch.Tensor,
    symmetry_A: torch.Tensor,
    symmetry_T: torch.Tensor,
    symmetry_A_inv: torch.Tensor,
    symmetry_T_inv: torch.Tensor,
    cell_matrix: torch.Tensor,
    device: torch.device,
) -> dict:
    """
    Expand intermolecular H-bonds by applying precomputed symmetry operations.

    Every H-bond (donor, hydrogen, acceptor) is complemented by a swapped
    copy in which donor and acceptor exchange roles and all three atoms are
    mapped through the inverse symmetry operation
    (``symmetry_A_inv @ x + symmetry_T_inv``). An H-bond is treated as valid
    when its length is > 0, and valid H-bonds are assumed to occupy the
    leading slots of each row. For a structure with ``n`` valid H-bonds the
    outputs hold the originals in slots ``[0, n)``, the swapped copies in
    ``[n, 2n)`` and zeros elsewhere.

    Parameters
    ----------
    central_atom_label : List[List[str]]
        Original donor-atom labels for each structure.
    hydrogen_atom_label : List[List[str]]
        Original hydrogen-atom labels for each structure.
    contact_atom_label : List[List[str]]
        Original acceptor-atom labels for each structure.
    central_atom_idx : torch.LongTensor, shape (B, H)
        Donor-atom indices per H-bond.
    hydrogen_atom_idx : torch.LongTensor, shape (B, H)
        Hydrogen-atom indices per H-bond.
    contact_atom_idx : torch.LongTensor, shape (B, H)
        Acceptor-atom indices per H-bond.
    central_atom_frac_coords : torch.Tensor, shape (B, H, 3)
        Fractional coordinates of donor atoms. Its dtype is the working
        dtype that all other floating-point inputs are cast to.
    hydrogen_atom_frac_coords : torch.Tensor, shape (B, H, 3)
        Fractional coordinates of hydrogen atoms.
    contact_atom_frac_coords : torch.Tensor, shape (B, H, 3)
        Fractional coordinates of acceptor atoms.
    lengths : torch.Tensor, shape (B, H)
        H-bond lengths; values <= 0 mark padding.
    angles : torch.Tensor, shape (B, H)
        H-bond angles.
    in_los : torch.Tensor, shape (B, H)
        Line-of-sight flag per H-bond (cast to the working dtype).
    symmetry_A : torch.Tensor, shape (B, H, 3, 3)
        Symmetry rotation matrices.
    symmetry_T : torch.Tensor, shape (B, H, 3)
        Symmetry translation vectors.
    symmetry_A_inv : torch.Tensor, shape (B, H, 3, 3)
        Inverse symmetry rotation matrices.
    symmetry_T_inv : torch.Tensor, shape (B, H, 3)
        Inverse symmetry translation vectors.
    cell_matrix : torch.Tensor, shape (B, 3, 3)
        Real-space cell matrices for each structure; Cartesian coordinates
        are computed as ``frac @ cell_matrix``.
    device : torch.device
        Device on which to perform computations.

    Returns
    -------
    dict
        Dictionary of extended H-bond parameters:
        - inter_hb_central_atom : List[List[str]]
        - inter_hb_hydrogen_atom : List[List[str]]
        - inter_hb_contact_atom : List[List[str]]
        - inter_hb_central_atom_idx : torch.LongTensor, shape (B, 2H)
        - inter_hb_hydrogen_atom_idx : torch.LongTensor, shape (B, 2H)
        - inter_hb_contact_atom_idx : torch.LongTensor, shape (B, 2H)
        - inter_hb_central_atom_coords : torch.Tensor, shape (B, 2H, 3)
        - inter_hb_hydrogen_atom_coords : torch.Tensor, shape (B, 2H, 3)
        - inter_hb_contact_atom_coords : torch.Tensor, shape (B, 2H, 3)
        - inter_hb_central_atom_frac_coords : torch.Tensor, shape (B, 2H, 3)
        - inter_hb_hydrogen_atom_frac_coords : torch.Tensor, shape (B, 2H, 3)
        - inter_hb_contact_atom_frac_coords : torch.Tensor, shape (B, 2H, 3)
        - inter_hb_length : torch.Tensor, shape (B, 2H)
        - inter_hb_angle : torch.Tensor, shape (B, 2H)
        - inter_hb_in_los : torch.Tensor, shape (B, 2H)
        - inter_hb_symmetry_A : torch.Tensor, shape (B, 2H, 3, 3)
        - inter_hb_symmetry_T : torch.Tensor, shape (B, 2H, 3)
        - inter_hb_symmetry_A_inv : torch.Tensor, shape (B, 2H, 3, 3)
        - inter_hb_symmetry_T_inv : torch.Tensor, shape (B, 2H, 3)
        - inter_hb_mask : torch.BoolTensor, shape (B, 2H)

    Notes
    -----
    The symmetry tensors are packed with the same slot layout as the
    coordinates, metrics and indices, so slot ``k`` of every output describes
    the same (possibly swapped) H-bond. Swapped H-bonds carry the inverse
    operation as their forward operation and vice versa.
    """
    # Fractional coordinates define the working device and dtype.
    central_atom_frac_coords = central_atom_frac_coords.to(device=device)
    dtype = central_atom_frac_coords.dtype
    hydrogen_atom_frac_coords = hydrogen_atom_frac_coords.to(device=device, dtype=dtype)
    contact_atom_frac_coords = contact_atom_frac_coords.to(device=device, dtype=dtype)

    # Indices keep their integer dtype; everything else is cast to `dtype`.
    central_atom_idx = central_atom_idx.to(device=device)
    hydrogen_atom_idx = hydrogen_atom_idx.to(device=device)
    contact_atom_idx = contact_atom_idx.to(device=device)
    lengths = lengths.to(device=device, dtype=dtype)
    angles = angles.to(device=device, dtype=dtype)
    in_los = in_los.to(device=device, dtype=dtype)
    symmetry_A = symmetry_A.to(device=device, dtype=dtype)
    symmetry_T = symmetry_T.to(device=device, dtype=dtype)
    symmetry_A_inv = symmetry_A_inv.to(device=device, dtype=dtype)
    symmetry_T_inv = symmetry_T_inv.to(device=device, dtype=dtype)
    cell_matrix = cell_matrix.to(device=device, dtype=dtype)

    B, H, _ = central_atom_frac_coords.shape

    # 1) Swapped H-bonds: donor and acceptor exchange roles, the hydrogen
    #    stays in the middle; all atoms go through the inverse operation.
    central_frac_rev = (
        torch.einsum('bcij,bcj->bci', symmetry_A_inv, contact_atom_frac_coords)
        + symmetry_T_inv
    )
    hydrogen_frac_rev = (
        torch.einsum('bcij,bcj->bci', symmetry_A_inv, hydrogen_atom_frac_coords)
        + symmetry_T_inv
    )
    contact_frac_rev = (
        torch.einsum('bcij,bcj->bci', symmetry_A_inv, central_atom_frac_coords)
        + symmetry_T_inv
    )

    # 2) Unpacked layout: originals in slots [0, H), swapped copies in [H, 2H).
    unpacked = {
        "central_frac": torch.cat([central_atom_frac_coords, central_frac_rev], dim=1),
        "hydrogen_frac": torch.cat([hydrogen_atom_frac_coords, hydrogen_frac_rev], dim=1),
        "contact_frac": torch.cat([contact_atom_frac_coords, contact_frac_rev], dim=1),
        "lengths": torch.cat([lengths, lengths], dim=1),
        "angles": torch.cat([angles, angles], dim=1),
        "in_los": torch.cat([in_los, in_los], dim=1),
        "central_idx": torch.cat([central_atom_idx, contact_atom_idx], dim=1),
        "hydrogen_idx": torch.cat([hydrogen_atom_idx, hydrogen_atom_idx], dim=1),
        "contact_idx": torch.cat([contact_atom_idx, central_atom_idx], dim=1),
        # For a swapped H-bond the roles of the forward/inverse ops swap too.
        "sym_A": torch.cat([symmetry_A, symmetry_A_inv], dim=1),
        "sym_T": torch.cat([symmetry_T, symmetry_T_inv], dim=1),
        "sym_A_inv": torch.cat([symmetry_A_inv, symmetry_A], dim=1),
        "sym_T_inv": torch.cat([symmetry_T_inv, symmetry_T], dim=1),
    }

    # 3) Cartesian coordinates from fractional ones.
    unpacked["central_cart"] = torch.matmul(unpacked["central_frac"], cell_matrix)
    unpacked["hydrogen_cart"] = torch.matmul(unpacked["hydrogen_frac"], cell_matrix)
    unpacked["contact_cart"] = torch.matmul(unpacked["contact_frac"], cell_matrix)

    # 4) Pack valid entries to the front: [originals | swapped | zero padding].
    #    The per-structure counts are fetched once to avoid a device sync per row.
    packed = {key: torch.zeros_like(value) for key, value in unpacked.items()}
    valid_counts = (lengths > 0).sum(dim=1).tolist()
    for b, n_valid in enumerate(valid_counts):
        if n_valid == 0:
            continue
        for key, source in unpacked.items():
            target = packed[key]
            target[b, :n_valid] = source[b, :n_valid]
            target[b, n_valid:2 * n_valid] = source[b, H:H + n_valid]

    # 5) Extend label lists (original labels followed by swapped-role labels).
    # NOTE(review): this lines up with the packed tensors only if each label
    # list holds just the valid H-bonds of its structure - confirm with caller.
    central_atom_labels_ext = [
        orig + rev for orig, rev in zip(central_atom_label, contact_atom_label)
    ]
    hydrogen_atom_labels_ext = [
        orig + rev for orig, rev in zip(hydrogen_atom_label, hydrogen_atom_label)
    ]
    contact_atom_labels_ext = [
        orig + rev for orig, rev in zip(contact_atom_label, central_atom_label)
    ]

    # 6) Validity mask of the packed layout.
    inter_hb_mask = packed["lengths"] > 0

    return {
        "inter_hb_central_atom": central_atom_labels_ext,
        "inter_hb_hydrogen_atom": hydrogen_atom_labels_ext,
        "inter_hb_contact_atom": contact_atom_labels_ext,
        "inter_hb_central_atom_idx": packed["central_idx"],
        "inter_hb_hydrogen_atom_idx": packed["hydrogen_idx"],
        "inter_hb_contact_atom_idx": packed["contact_idx"],
        "inter_hb_central_atom_coords": packed["central_cart"],
        "inter_hb_hydrogen_atom_coords": packed["hydrogen_cart"],
        "inter_hb_contact_atom_coords": packed["contact_cart"],
        "inter_hb_central_atom_frac_coords": packed["central_frac"],
        "inter_hb_hydrogen_atom_frac_coords": packed["hydrogen_frac"],
        "inter_hb_contact_atom_frac_coords": packed["contact_frac"],
        "inter_hb_length": packed["lengths"],
        "inter_hb_angle": packed["angles"],
        "inter_hb_in_los": packed["in_los"],
        "inter_hb_symmetry_A": packed["sym_A"],
        "inter_hb_symmetry_T": packed["sym_T"],
        "inter_hb_symmetry_A_inv": packed["sym_A_inv"],
        "inter_hb_symmetry_T_inv": packed["sym_T_inv"],
        "inter_hb_mask": inter_hb_mask,
    }
def compute_contact_is_hbond(
    cc_central_idx: torch.Tensor,
    cc_contact_idx: torch.Tensor,
    cc_mask: torch.Tensor,
    hb_central_idx: torch.Tensor,
    hb_hydrogen_idx: torch.Tensor,
    hb_contact_idx: torch.Tensor,
    hb_mask: torch.Tensor,
    device: torch.device
) -> torch.Tensor:
    """
    Flag which contacts correspond to hydrogen bonds.

    A contact (central, contact) is flagged when, within the same structure,
    it coincides with one of the three atom pairs of any valid H-bond:
    donor/acceptor, donor/hydrogen or hydrogen/acceptor. The comparison is
    directional: the contact's central atom is matched against the first atom
    of the pair and its contact atom against the second.

    Parameters
    ----------
    cc_central_idx : torch.LongTensor, shape (B, C)
        Central-atom indices for each intermolecular contact.
    cc_contact_idx : torch.LongTensor, shape (B, C)
        Contact-atom indices for each intermolecular contact.
    cc_mask : torch.BoolTensor, shape (B, C)
        Validity mask for contacts.
    hb_central_idx : torch.LongTensor, shape (B, H)
        Donor-atom indices for each H-bond.
    hb_hydrogen_idx : torch.LongTensor, shape (B, H)
        Hydrogen-atom indices for each H-bond.
    hb_contact_idx : torch.LongTensor, shape (B, H)
        Acceptor-atom indices for each H-bond.
    hb_mask : torch.BoolTensor, shape (B, H)
        Validity mask for H-bonds.
    device : torch.device
        Device on which to perform computations.

    Returns
    -------
    torch.BoolTensor, shape (B, C)
        True where each contact participates in any hydrogen-bond triplet.
    """
    # Bring all inputs onto the requested device.
    first_atom, second_atom, contact_valid = (
        t.to(device) for t in (cc_central_idx, cc_contact_idx, cc_mask)
    )
    donor, hydrogen, acceptor, hbond_valid = (
        t.to(device)
        for t in (hb_central_idx, hb_hydrogen_idx, hb_contact_idx, hb_mask)
    )

    # Unpacking the shapes also enforces that both index tensors are 2-D.
    _n_struct, _n_contacts = first_atom.shape
    _, _n_hbonds = donor.shape

    # Contacts vary along axis 1, H-bonds along axis 2 -> (B, C, H) grids.
    first_col = first_atom.unsqueeze(2)
    second_col = second_atom.unsqueeze(2)
    donor_row = donor.unsqueeze(1)
    hydrogen_row = hydrogen.unsqueeze(1)
    acceptor_row = acceptor.unsqueeze(1)
    hbond_valid_row = hbond_valid.unsqueeze(1)

    # (first, second) atom pairs of an H-bond that a contact may coincide with:
    # heavy-heavy, donor-to-H and H-to-acceptor.
    pair_patterns = (
        (donor_row, acceptor_row),
        (donor_row, hydrogen_row),
        (hydrogen_row, acceptor_row),
    )
    per_pattern = [
        ((first_col == pair_first) & (second_col == pair_second) & hbond_valid_row)
        .any(dim=2) & contact_valid
        for pair_first, pair_second in pair_patterns
    ]

    heavy_heavy, donor_to_h, h_to_acceptor = per_pattern
    return heavy_heavy | donor_to_h | h_to_acceptor
def compute_contact_fragment_indices_batch(
    central_atom_idx: torch.LongTensor,   # (B, C) atom-index of the central atom for each contact, -1 if none
    contact_atom_idx: torch.LongTensor,   # (B, C) atom-index of the other atom in each contact, -1 if none
    atom_fragment_ids: torch.LongTensor,  # (B, N) fragment ID per atom
    device: torch.device
) -> dict:
    """
    Map contact atom indices to fragment IDs.

    Parameters
    ----------
    central_atom_idx : torch.LongTensor, shape (B, C)
        Central-atom indices for each contact, or -1 for padding.
    contact_atom_idx : torch.LongTensor, shape (B, C)
        Contact-atom indices for each contact, or -1 for padding.
    atom_fragment_ids : torch.LongTensor, shape (B, N)
        Fragment ID assigned to each atom.
    device : torch.device
        Device on which to perform computations.

    Returns
    -------
    dict
        {
            'inter_cc_central_atom_fragment_idx': torch.LongTensor, shape (B, C),
            'inter_cc_contact_atom_fragment_idx': torch.LongTensor, shape (B, C)
        }
        Entries whose atom index is negative are -1.

    Notes
    -----
    Atom indices larger than ``N - 1`` are clamped to the last atom rather
    than rejected, so they yield the fragment ID of atom ``N - 1``. When
    ``N == 0`` there is nothing to look up and every entry is -1.
    """
    frag = atom_fragment_ids.to(device).long()
    # Unpacking the shapes also enforces that the inputs are 2-D.
    _, n_atoms = frag.shape
    _, _ = central_atom_idx.shape

    def _fragment_of(atom_idx: torch.Tensor) -> torch.Tensor:
        """Return the fragment ID of each atom index, keeping -1 for padding."""
        idx = atom_idx.to(device).long()
        if n_atoms == 0:
            # No atoms to gather from: every index is unresolved.
            return torch.full_like(idx, -1)
        # Clamp so that gather never sees an out-of-range index; negative
        # (padding) entries are restored to -1 right after the lookup.
        safe_idx = idx.clamp(min=0, max=n_atoms - 1)
        looked_up = frag.gather(1, safe_idx)
        return torch.where(idx < 0, torch.full_like(looked_up, -1), looked_up)

    return {
        'inter_cc_central_atom_fragment_idx': _fragment_of(central_atom_idx),
        'inter_cc_contact_atom_fragment_idx': _fragment_of(contact_atom_idx),
    }
def compute_contact_atom_to_central_fragment_com_batch(
    inter_cc_contact_coords: torch.Tensor,       # (B, C_max, 3)
    inter_cc_contact_frac_coords: torch.Tensor,  # (B, C_max, 3)
    central_frag_idx: torch.Tensor,              # (B, Cc), int or long, -1 for padding
    fragment_com_coords: torch.Tensor,           # (F_total, 3)
    fragment_com_frac_coords: torch.Tensor,      # (F_total, 3)
    struct_ids: torch.Tensor,                    # (F_total,), long
    fragment_local_ids: torch.Tensor,            # (F_total,), long
    device: torch.device
) -> dict:
    """
    Compute vectors and distances from contact atoms to central-fragment
    center of mass.

    For every contact the fragment of its central atom is resolved to a row
    of the flat fragment tables via ``(structure id, local fragment id)``,
    and the displacement ``contact_atom - fragment_COM`` is returned in both
    Cartesian and fractional coordinates together with its Euclidean norm.

    Parameters
    ----------
    inter_cc_contact_coords : torch.Tensor, shape (B, C, 3)
        Cartesian coordinates of contact atoms.
    inter_cc_contact_frac_coords : torch.Tensor, shape (B, C, 3)
        Fractional coordinates of contact atoms.
    central_frag_idx : torch.LongTensor, shape (B, Cc) with Cc <= C
        Local fragment IDs of central atoms for each contact, -1 for padding.
        Missing trailing columns are treated as padding.
    fragment_com_coords : torch.Tensor, shape (F_total, 3)
        Cartesian COM coordinates for all fragments.
    fragment_com_frac_coords : torch.Tensor, shape (F_total, 3)
        Fractional COM coordinates for all fragments.
    struct_ids : torch.Tensor, shape (F_total,)
        Structure IDs corresponding to each fragment (used as the batch row).
    fragment_local_ids : torch.Tensor, shape (F_total,)
        Local fragment indices within each structure.
    device : torch.device
        Device on which to perform computations.

    Returns
    -------
    dict
        {
            'inter_cc_contact_atom_to_fragment_com_dist': torch.Tensor, shape (B, C),
            'inter_cc_contact_atom_to_fragment_com_frac_dist': torch.Tensor, shape (B, C),
            'inter_cc_contact_atom_to_fragment_com_vec': torch.Tensor, shape (B, C, 3),
            'inter_cc_contact_atom_to_fragment_com_frac_vec': torch.Tensor, shape (B, C, 3)
        }
        Vectors and distances are zero for padded contacts and for contacts
        whose fragment cannot be resolved in the lookup table.
    """
    # 1) Move inputs to device; indices become int64, COM tables float32.
    #    The COM tables must NOT pass through an integer dtype: that would
    #    truncate the coordinates (fractional COMs in [0, 1) would all be 0).
    coords_cart = inter_cc_contact_coords.to(device)
    coords_frac = inter_cc_contact_frac_coords.to(device)
    centr_idx = central_frag_idx.to(device).long()
    com_cart = fragment_com_coords.to(device).float()
    com_frac = fragment_com_frac_coords.to(device).float()
    s_ids = struct_ids.to(device).long()
    local_ids = fragment_local_ids.to(device).long()

    B, C_max, _ = coords_cart.shape
    F_total = com_cart.shape[0]

    # 2) Pad central-fragment indices to C_max columns if needed.
    n_missing = C_max - centr_idx.shape[1]
    if n_missing > 0:
        pad = torch.full((B, n_missing), -1, dtype=torch.long, device=device)
        centr_full = torch.cat([centr_idx, pad], dim=1)
    else:
        centr_full = centr_idx

    # 3) Build a lookup table so that lookup[b, local_id] is the row of that
    #    fragment in the flat COM tables, or -1 where no fragment exists.
    n_frags_per_struct = torch.bincount(s_ids, minlength=B)  # (B,)
    n_frags_max = int(n_frags_per_struct.max().item())
    fragment_rows = torch.arange(F_total, device=device, dtype=torch.long)
    lookup_flat = torch.full((B * n_frags_max,), -1, dtype=torch.long, device=device)
    lookup_flat = lookup_flat.scatter(0, s_ids * n_frags_max + local_ids, fragment_rows)
    lookup = lookup_flat.view(B, n_frags_max)

    # 4) Resolve each contact's fragment to a COM row. Indices are clamped
    #    only to keep the indexing in bounds; unresolved entries are masked.
    local_clamped = centr_full.clamp(min=0, max=n_frags_max - 1)              # (B, C_max)
    batch_idx = torch.arange(B, device=device).unsqueeze(1).expand(B, C_max)  # (B, C_max)
    com_row_idx = lookup[batch_idx, local_clamped]                            # in [-1, F_total-1]
    safe_rows = com_row_idx.clamp(min=0, max=F_total - 1)

    # 5) Gather the COM coordinates per contact.
    point_cart = com_cart[safe_rows]  # (B, C_max, 3)
    point_frac = com_frac[safe_rows]  # (B, C_max, 3)

    # 6) A contact is valid only if it is not padding AND its fragment id
    #    resolved to a real table row; otherwise the clamps above would
    #    silently substitute the COM of an unrelated fragment.
    valid_mask = (centr_full >= 0) & (centr_full < n_frags_max) & (com_row_idx >= 0)
    mask3 = valid_mask.unsqueeze(-1).to(coords_cart.dtype)  # (B, C_max, 1)

    # 7) Displacement vectors (zeroed where invalid) and their norms.
    vecs_cart = (coords_cart - point_cart) * mask3
    vecs_frac = (coords_frac - point_frac) * mask3
    dists_cart = vecs_cart.norm(dim=-1)  # (B, C_max)
    dists_frac = vecs_frac.norm(dim=-1)  # (B, C_max)

    return {
        'inter_cc_contact_atom_to_fragment_com_dist': dists_cart,
        'inter_cc_contact_atom_to_fragment_com_frac_dist': dists_frac,
        'inter_cc_contact_atom_to_fragment_com_vec': vecs_cart,
        'inter_cc_contact_atom_to_fragment_com_frac_vec': vecs_frac,
    }