fragment_utils.identify_rigid_fragments_batch
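- fragment_utils.identify_rigid_fragments_batch(atom_mask, bond_atom1, bond_atom2, bond_is_rotatable, device)
Identify the rigid fragments of each molecule in a padded batch by iterative label propagation on the GPU. A rigid fragment is a set of atoms connected to one another through non-rotatable bonds.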
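Rigid fragments are the connected components of the graph whose edges are the non-rotatable bonds, so label propagation can be illustrated on a single molecule. The sketch below is a plain-Python CPU reference, not the library's GPU implementation: every atom starts with its own index as its label, each non-rotatable bond pushes the smaller of its two labels onto both endpoints, and the loop stops once a full sweep changes nothing. The function name `rigid_fragments_single` is hypothetical, and renumbering fragments to 0..K-1 in order of first appearance is an assumed convention; the real function may number fragments differently.

```python
from typing import List, Sequence


def rigid_fragments_single(
    num_atoms: int,
    bond_atom1: Sequence[int],
    bond_atom2: Sequence[int],
    bond_is_rotatable: Sequence[bool],
) -> List[int]:
    """Hypothetical CPU reference: label the rigid fragments of one molecule."""
    # Every atom starts in its own fragment, labelled by its own index.
    labels = list(range(num_atoms))

    # Only non-rotatable bonds connect atoms inside a rigid fragment;
    # bonds with a -1 (padding) endpoint are skipped.
    edges = [
        (a, b)
        for a, b, rot in zip(bond_atom1, bond_atom2, bond_is_rotatable)
        if not rot and a >= 0 and b >= 0
    ]

    # Iterative label propagation: both endpoints of each edge adopt the
    # smaller of their two labels until a full sweep changes nothing.
    changed = True
    while changed:
        changed = False
        for a, b in edges:
            m = min(labels[a], labels[b])
            if labels[a] != m or labels[b] != m:
                labels[a] = labels[b] = m
                changed = True

    # Renumber the surviving labels to 0..K-1 in order of first appearance
    # (assumed convention; the real function may order fragments differently).
    remap = {}
    return [remap.setdefault(lab, len(remap)) for lab in labels]


# Bonds 0-1 and 2-3 are rigid, 1-2 is rotatable, the last bond is padding,
# and atom 4 has no non-rotatable bond, so it forms a singleton fragment.
print(rigid_fragments_single(5, [0, 1, 2, -1], [1, 2, 3, -1], [False, True, False, False]))
# [0, 0, 1, 1, 2]
```

Each sweep is a direct analogue of one propagation step on the GPU, where all bonds would be processed in parallel rather than in a Python loop; the number of sweeps needed grows with the longest chain of rigid bonds.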
- Parameters:
atom_mask (torch.BoolTensor of shape (B, N)) – True for real atoms, False for padding slots.
bond_atom1 (torch.LongTensor of shape (B, M)) – Index of the first atom of each bond (-1 for padding).
bond_atom2 (torch.LongTensor of shape (B, M)) – Index of the second atom of each bond (-1 for padding).
bond_is_rotatable (torch.BoolTensor of shape (B, M)) – True if the bond is rotatable. Non-rotatable bonds join their two atoms into the same fragment; rotatable bonds do not.
device (torch.device) – Device on which to run the computation (e.g. 'cuda').
- Returns:
frag_id – Fragment ID for each atom: a value in 0..K-1 for real atoms, where K is the number of rigid fragments, and -1 for padding slots.
- Return type:
torch.LongTensor of shape (B, N)
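The parameters use a padded batch layout: N is the largest atom count and M the largest bond count in the batch, padding atom slots are False in atom_mask, and padding bonds carry -1 in both index tensors. A minimal plain-Python sketch of building these inputs from per-molecule data follows. `pad_batch` and the `Molecule` tuple layout are hypothetical, real atoms are assumed to occupy the first slots of each row, and padded bonds are flagged rotatable as a defensive assumption so they could never merge fragments.

```python
from typing import List, Sequence, Tuple

# One molecule: (number of atoms, list of (atom1, atom2, is_rotatable) bonds).
Molecule = Tuple[int, Sequence[Tuple[int, int, bool]]]


def pad_batch(
    mols: Sequence[Molecule],
) -> Tuple[List[List[bool]], List[List[int]], List[List[int]], List[List[bool]]]:
    """Hypothetical helper: pad per-molecule data to shapes (B, N) and (B, M)."""
    n_max = max(n for n, _ in mols)
    m_max = max(len(bonds) for _, bonds in mols)

    atom_mask, bond_atom1, bond_atom2, bond_is_rotatable = [], [], [], []
    for n, bonds in mols:
        pad = m_max - len(bonds)
        # Real atoms fill the first n slots; the remaining slots are padding.
        atom_mask.append([True] * n + [False] * (n_max - n))
        # Padding bonds use -1 for both atom indices.
        bond_atom1.append([a for a, _, _ in bonds] + [-1] * pad)
        bond_atom2.append([b for _, b, _ in bonds] + [-1] * pad)
        # Padding bonds are flagged rotatable (an assumption) so that they
        # could never join fragments even if -1 indices were not filtered out.
        bond_is_rotatable.append([r for _, _, r in bonds] + [True] * pad)
    return atom_mask, bond_atom1, bond_atom2, bond_is_rotatable


mols = [
    (3, [(0, 1, False), (1, 2, True)]),
    (2, [(0, 1, False)]),
]
mask, a1, a2, rot = pad_batch(mols)
print(mask)  # [[True, True, True], [True, True, False]]
print(a1)    # [[0, 1], [0, -1]]
```

With PyTorch installed, the nested lists convert directly to the documented types, for example torch.tensor(mask, dtype=torch.bool) and torch.tensor(a1, dtype=torch.long).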