fragment_utils.identify_rigid_fragments_batch

fragment_utils.identify_rigid_fragments_batch(atom_mask, bond_atom1, bond_atom2, bond_is_rotatable, device)

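Identify rigid fragments in a batch of molecules via iterative label propagation on the GPU. A rigid fragment is a set of real atoms connected through non-rotatable bonds; rotatable bonds separate fragments.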
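
The label-propagation idea can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the code behind fragment_utils. The name rigid_fragments_sketch is hypothetical. Padded bonds are assumed to carry -1 in both index columns, and fragment IDs are assumed to be numbered per molecule, in order of each fragment's lowest atom index. Every real atom starts with its own index as its label. Each pass then pushes the smaller label across every non-rotatable bond with Tensor.scatter_reduce_ (PyTorch 1.12 or later), and the loop stops once no label changes.

```python
import torch


def rigid_fragments_sketch(atom_mask, bond_atom1, bond_atom2, bond_is_rotatable, device="cpu"):
    """Hypothetical sketch of batched rigid-fragment labeling (not the library's code)."""
    atom_mask = atom_mask.to(device)
    bond_atom1 = bond_atom1.to(device)
    bond_atom2 = bond_atom2.to(device)
    bond_is_rotatable = bond_is_rotatable.to(device)
    B, N = atom_mask.shape

    # Initial labels: each real atom gets its own index; padding gets N (larger than any index).
    idx = torch.arange(N, device=device).expand(B, N)
    labels = torch.where(atom_mask, idx, torch.full_like(idx, N))

    # Only real (non -1), non-rotatable bonds merge atoms into one rigid fragment.
    keep = (bond_atom1 >= 0) & (bond_atom2 >= 0) & ~bond_is_rotatable

    # Flatten the batch: atom i of molecule b becomes b * N + i, so one pass covers all molecules.
    offset = torch.arange(B, device=device).unsqueeze(1) * N
    src = (bond_atom1 + offset)[keep]
    dst = (bond_atom2 + offset)[keep]
    flat = labels.reshape(-1).clone()

    # Propagate the minimum label across kept bonds until no label changes.
    while True:
        prev = flat.clone()
        m = torch.minimum(flat[src], flat[dst])
        flat.scatter_reduce_(0, src, m, reduce="amin")
        flat.scatter_reduce_(0, dst, m, reduce="amin")
        if torch.equal(flat, prev):
            break
    labels = flat.view(B, N)

    # Each fragment's final label is its lowest atom index (its "root"). Ranking the
    # roots within a molecule turns labels into contiguous IDs 0..K_b-1 (assumed numbering).
    is_root = atom_mask & (labels == idx)
    rank = torch.cumsum(is_root.long(), dim=1) - 1
    frag_id = torch.gather(rank, 1, labels.clamp(max=N - 1))
    return torch.where(atom_mask, frag_id, torch.full_like(frag_id, -1))


# Two molecules padded to N = 4 atoms and M = 3 bonds.
atom_mask = torch.tensor([[True, True, True, True], [True, True, True, False]])
bond_atom1 = torch.tensor([[0, 1, 2], [0, -1, -1]])
bond_atom2 = torch.tensor([[1, 2, 3], [2, -1, -1]])
bond_is_rotatable = torch.tensor([[False, True, False], [False, False, False]])
print(rigid_fragments_sketch(atom_mask, bond_atom1, bond_atom2, bond_is_rotatable).tolist())
# [[0, 0, 1, 1], [0, 1, 0, -1]]
```

Because each pass carries a fragment's minimum label one bond further, the number of passes is bounded by the largest fragment's diameter in bonds, plus one final pass that detects convergence.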

Parameters:
  • atom_mask (torch.BoolTensor of shape (B, N)) – True for real atoms, False for padding slots.

  • bond_atom1 (torch.LongTensor of shape (B, M)) – Index of the first atom of each bond (-1 for padding).

  • bond_atom2 (torch.LongTensor of shape (B, M)) – Index of the second atom of each bond (-1 for padding).

  • bond_is_rotatable (torch.BoolTensor of shape (B, M)) – True if the bond is rotatable. Non-rotatable bonds join their two atoms into the same fragment.

  • device (torch.device) – Device on which to run the computation (e.g. torch.device('cuda')).
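
Molecules of different sizes have to be padded to a common N and M before they can be batched. The helper below is a hypothetical sketch of one way to do that. It assumes each molecule is given as an atom count plus a list of (i, j, is_rotatable) tuples with 0-based local atom indices; that input format and the name pad_batch are assumptions, not part of fragment_utils.

```python
import torch


def pad_batch(molecules):
    """Hypothetical helper: pad per-molecule bond lists into (B, N) / (B, M) tensors.

    molecules: list of (num_atoms, [(i, j, is_rotatable), ...]) with 0-based atom indices.
    """
    B = len(molecules)
    N = max(n for n, _ in molecules)
    M = max(1, max(len(bonds) for _, bonds in molecules))  # keep at least one bond slot

    atom_mask = torch.zeros(B, N, dtype=torch.bool)
    bond_atom1 = torch.full((B, M), -1, dtype=torch.long)  # -1 marks padded bond slots
    bond_atom2 = torch.full((B, M), -1, dtype=torch.long)
    bond_is_rotatable = torch.zeros(B, M, dtype=torch.bool)

    for b, (n, bonds) in enumerate(molecules):
        atom_mask[b, :n] = True  # the first n slots hold real atoms
        for k, (i, j, rotatable) in enumerate(bonds):
            bond_atom1[b, k] = i
            bond_atom2[b, k] = j
            bond_is_rotatable[b, k] = rotatable
    return atom_mask, bond_atom1, bond_atom2, bond_is_rotatable


# A 4-atom molecule with one rotatable bond, and a 3-atom molecule with a single bond.
inputs = pad_batch([(4, [(0, 1, False), (1, 2, True), (2, 3, False)]),
                    (3, [(0, 2, False)])])
for name, t in zip(["atom_mask", "bond_atom1", "bond_atom2", "bond_is_rotatable"], inputs):
    print(name, tuple(t.shape))
```

The four resulting tensors line up with the first four parameters above; padded atom slots are False in atom_mask and padded bond slots hold -1 in both index tensors.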

Returns:

frag_id – Fragment ID for each atom: 0..K−1 for real atoms, where K is the number of rigid fragments, and −1 for padding slots.

Return type:

torch.LongTensor of shape (B, N)
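
Downstream code often needs per-molecule fragment counts, atoms per fragment, or a pairwise same-fragment mask. The following is a small sketch built on the returned tensor. It assumes the IDs are contiguous 0..K−1 within each molecule; the text above does not say whether K is counted per molecule or across the whole batch. If the IDs were batch-global, num_frags below would no longer be a per-molecule count, although the same-fragment mask would still hold. fragment_summary is a hypothetical helper name.

```python
import torch


def fragment_summary(frag_id):
    """Hypothetical helper. Assumes per-molecule IDs 0..K_b-1 and -1 for padding."""
    B = frag_id.shape[0]
    valid = frag_id >= 0

    # Number of fragments per molecule: largest ID + 1 (0 for an all-padding row).
    num_frags = frag_id.max(dim=1).values + 1

    # Atoms per fragment: shift IDs by +1 so padding (-1) lands in column 0, then drop it.
    k_max = int(num_frags.max().item())
    counts = torch.zeros(B, k_max + 1, dtype=torch.long, device=frag_id.device)
    counts.scatter_add_(1, frag_id + 1, torch.ones_like(frag_id))
    sizes = counts[:, 1:]

    # same[b, i, j] is True when real atoms i and j lie in the same rigid fragment.
    same = (frag_id.unsqueeze(2) == frag_id.unsqueeze(1)) & valid.unsqueeze(2) & valid.unsqueeze(1)
    return num_frags, sizes, same


frag_id = torch.tensor([[0, 0, 1, 1], [0, 1, 0, -1]])
num_frags, sizes, same = fragment_summary(frag_id)
print(num_frags.tolist())  # [2, 2]
print(sizes.tolist())      # [[2, 2], [2, 1]]
```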