"""
Module: symmetry_utils.py
Functions to parse and apply crystallographic symmetry operators, augmenting
parameter dictionaries with rotation matrices and translation vectors.
Dependencies
------------
torch
"""
import re
from fractions import Fraction
from typing import Dict, Any, List, Tuple
import torch
# Public API: the names exported by a wildcard import of this module.
__all__ = [
'parse_sym_op',
'invert_sym_op',
'add_symmetry_matrices',
'add_inter_cc_symmetry',
'add_inter_hb_symmetry',
]
# One token of a single axis expression: either a signed axis letter (optionally
# written with an explicit unit coefficient, e.g. "-1x") or a signed numeric
# constant. The constant alternatives are ordered so that "1/2" and "0.5" are
# consumed whole before the bare-integer fallback could split them apart.
_SYM_TOKEN_RE = re.compile(
    r'(?P<axis_sign>[+-]?)1?(?P<axis>[xyz])'
    r'|(?P<const>[+-]?(?:\d+/\d+|\d*\.\d+|\d+))'
)


def parse_sym_op(sym: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Parse a crystallographic symmetry-operator string into rotation matrix and translation vector.

    The string is read case-insensitively and all whitespace is ignored, so
    'X+1/2, -Y, Z' and 'x+1/2,-y,z' are equivalent. Within each of the three
    comma-separated expressions, terms may appear in any order
    (e.g. '1/2+x' or 'x-y'). Characters that do not form an axis term or a
    numeric constant are skipped.

    Parameters
    ----------
    sym : str
        Symmetry operator string (e.g. 'x+1/2, -y+1/2, z').

    Returns
    -------
    A : torch.Tensor, shape (3, 3)
        Integer (int64) rotation matrix. Row i describes output coordinate i;
        each axis term contributes +1 or -1 in the column of its axis. Only
        unit coefficients are supported.
    t : torch.Tensor, shape (3,)
        Float (float32) translation vector; constants within one expression
        are summed.

    Raises
    ------
    ValueError
        If the input string does not contain exactly three comma-separated
        expressions, or if a fractional constant has a zero denominator.
    """
    # Strip every whitespace character and fold case before splitting per axis.
    exprs = ''.join(sym.split()).lower().split(',')
    if len(exprs) != 3:
        raise ValueError(f"Invalid symmetry operator: {sym!r}")

    axis_cols = {'x': 0, 'y': 1, 'z': 2}
    rot_rows = [[0, 0, 0] for _ in range(3)]
    # Accumulate translations exactly; convert to float once at the end.
    shifts = [Fraction(0)] * 3

    for row, expr in enumerate(exprs):
        for tok in _SYM_TOKEN_RE.finditer(expr):
            axis = tok.group('axis')
            if axis is not None:
                sign = -1 if tok.group('axis_sign') == '-' else 1
                rot_rows[row][axis_cols[axis]] = sign
                continue

            text = tok.group('const')
            try:
                value = Fraction(text)
            except ZeroDivisionError as err:
                # e.g. "1/0": a malformed operator, not a zero translation.
                raise ValueError(f"Invalid symmetry operator: {sym!r}") from err
            if '/' not in text:
                # Snap long decimals such as 0.3333333 to the nearby simple
                # fraction (1/3); short decimals and integers are unchanged.
                value = value.limit_denominator()
            shifts[row] += value

    A = torch.tensor(rot_rows, dtype=torch.int64)
    t = torch.tensor([float(s) for s in shifts], dtype=torch.float32)
    return A, t
def invert_sym_op(A: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the inverse of the symmetry operator ``x -> A @ x + t``.

    The true matrix inverse is used rather than the transpose: operators
    expressed in fractional coordinates (e.g. '-y, x-y, z' in hexagonal
    settings) are unimodular but not orthogonal, so their transpose is not
    their inverse. For orthogonal matrices both coincide.

    Parameters
    ----------
    A : torch.Tensor, shape (3, 3)
        Rotation matrix. Integer matrices are inverted exactly and must have
        determinant +1 or -1.
    t : torch.Tensor, shape (3,)
        Translation vector.

    Returns
    -------
    A_inv : torch.Tensor, shape (3, 3)
        Inverse rotation, in the dtype and on the device of `A`.
    t_inv : torch.Tensor, shape (3,)
        Inverse translation (-A_inv @ t), in the dtype of `t` when `t` is
        floating point, otherwise float32.

    Raises
    ------
    ValueError
        If A is not shape (3,3) or t is not shape (3,), or if A cannot be
        inverted (singular, or an integer matrix whose inverse is not integer).
    """
    if A.shape != (3, 3) or t.shape != (3,):
        raise ValueError("A must be (3,3) and t must be (3,)")

    exact = not (A.is_floating_point() or A.is_complex())
    try:
        if exact:
            # Integer input: invert the tiny matrix on CPU in float64 and round,
            # which is exact for unimodular matrices and avoids integer matmul
            # (unsupported on some devices).
            A64 = A.detach().to(device='cpu', dtype=torch.float64)
            inv64 = torch.linalg.inv(A64).round()
        else:
            A_inv = torch.linalg.inv(A)
    except RuntimeError as err:
        raise ValueError("A is singular and cannot be inverted") from err

    if exact:
        # Rounding is only valid if the rounded matrix really is the inverse.
        if not torch.equal(inv64 @ A64, torch.eye(3, dtype=torch.float64)):
            raise ValueError("integer A must have determinant +1 or -1 to be inverted")
        A_inv = inv64.to(device=A.device, dtype=A.dtype)

    t_float = t if t.is_floating_point() else t.to(torch.float32)
    t_inv = -(A_inv.to(t_float.dtype) @ t_float)
    return A_inv, t_inv
def add_symmetry_matrices(
    parameters: Dict[str, Any],
    sym_key: str,
    coords_key: str,
    device: torch.device = None
) -> None:
    """
    Add parsed symmetry matrices and their inverses to a parameter dictionary.

    Each distinct operator string is parsed once with `parse_sym_op`; the
    results are then gathered into per-contact tensors. Inverses are true
    matrix inverses (not transposes), so non-orthogonal fractional-coordinate
    operators such as '-y, x-y, z' are handled correctly.

    Parameters
    ----------
    parameters : Dict[str, Any]
        Dictionary containing symmetry strings and coordinate tensors.
    sym_key : str
        Key for List[List[str]] of symmetry operator strings in `parameters`.
        There must be B rows, each with at most N entries.
    coords_key : str
        Key for torch.Tensor of shape (B, N, …) holding coordinates. Only its
        first two dimensions are used.
    device : torch.device, optional
        Device on which to store the resulting tensors. If None, taken from
        ``parameters['__device__']`` when present, otherwise CUDA if available,
        else CPU.

    Modifies
    --------
    parameters : adds the following keys
        '{sym_key}_A'     : torch.Tensor, shape (B, N, 3, 3)
        '{sym_key}_T'     : torch.Tensor, shape (B, N, 3)
        '{sym_key}_A_inv' : torch.Tensor, shape (B, N, 3, 3)
        '{sym_key}_T_inv' : torch.Tensor, shape (B, N, 3)

    Notes
    -----
    Positions beyond the end of a short row are filled with the first distinct
    operator encountered in the batch, or with the identity operator when the
    batch contains no operator strings at all.

    Raises
    ------
    ValueError
        If input tensor shapes or list lengths are invalid, or if a parsed
        operator is not invertible.
    """
    # determine device
    if device is None:
        device = parameters.get(
            '__device__',
            torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        )

    # batch size and contacts-per-item come from the coordinate tensor
    coords = parameters[coords_key]
    if not isinstance(coords, torch.Tensor) or coords.dim() < 2:
        raise ValueError(f"Expected torch.Tensor for '{coords_key}' with dim >=2")
    B, N = coords.shape[0], coords.shape[1]

    sym_ops_batch: List[List[str]] = parameters[sym_key]
    if len(sym_ops_batch) != B:
        raise ValueError(f"Expected {B} lists in '{sym_key}', got {len(sym_ops_batch)}")
    for i, row in enumerate(sym_ops_batch):
        if len(row) > N:
            raise ValueError(
                f"Row {i} of '{sym_key}' has {len(row)} operators, expected at most {N}"
            )

    # distinct operator strings, first-seen order preserved
    unique = list(dict.fromkeys(op for row in sym_ops_batch for op in row))

    if unique:
        parsed = [parse_sym_op(op) for op in unique]
        A_all = torch.stack([rot for rot, _ in parsed], dim=0)
        T_all = torch.stack([shift for _, shift in parsed], dim=0)
    else:
        # Nothing to parse (empty batch or only empty rows): fall back to a
        # single identity operator so the gathers below stay well-defined.
        A_all = torch.eye(3, dtype=torch.int64).unsqueeze(0)
        T_all = torch.zeros((1, 3), dtype=torch.float32)

    # Exact inverses of the small integer matrices, computed on CPU in float64.
    A64 = A_all.to(torch.float64)
    dets = torch.linalg.det(A64).round().tolist()
    bad_ops = [op for op, det in zip(unique, dets) if abs(det) != 1.0]
    if bad_ops:
        raise ValueError(f"Non-invertible symmetry operator(s) in '{sym_key}': {bad_ops!r}")
    A_inv_all = torch.linalg.inv(A64).round().to(A_all.dtype)
    # T_inv = -A_inv @ T
    T_inv_all = -(A_inv_all.to(torch.float32) @ T_all.unsqueeze(-1)).squeeze(-1)

    # Per-contact operator indices, built in one shot; short rows padded with 0.
    index_of = {op: k for k, op in enumerate(unique)}
    index_rows = [
        [index_of[op] for op in row] + [0] * (N - len(row))
        for row in sym_ops_batch
    ]
    idx = torch.tensor(index_rows, dtype=torch.long).reshape(B, N).to(device)

    # gather per-contact tensors on the target device and store them
    parameters[f'{sym_key}_A'] = A_all.to(device)[idx]          # (B,N,3,3)
    parameters[f'{sym_key}_T'] = T_all.to(device)[idx]          # (B,N,3)
    parameters[f'{sym_key}_A_inv'] = A_inv_all.to(device)[idx]  # (B,N,3,3)
    parameters[f'{sym_key}_T_inv'] = T_inv_all.to(device)[idx]  # (B,N,3)
def add_inter_cc_symmetry(parameters: Dict[str, Any], device: torch.device = None) -> None:
    """
    Attach symmetry matrices for the inter_cc contacts to `parameters`.

    Convenience wrapper around `add_symmetry_matrices` with the operator
    strings read from 'inter_cc_symmetry' and the batch layout taken from
    'inter_cc_central_atom_coords'.

    Parameters
    ----------
    parameters : Dict[str, Any]
        Must contain 'inter_cc_symmetry' and 'inter_cc_central_atom_coords'.
        Updated in place.
    device : torch.device, optional
        Device for the resulting tensors; forwarded unchanged.
    """
    add_symmetry_matrices(
        parameters,
        sym_key='inter_cc_symmetry',
        coords_key='inter_cc_central_atom_coords',
        device=device,
    )
def add_inter_hb_symmetry(parameters: Dict[str, Any], device: torch.device = None) -> None:
    """
    Attach symmetry matrices for the inter_hb contacts to `parameters`.

    Convenience wrapper around `add_symmetry_matrices` with the operator
    strings read from 'inter_hb_symmetry' and the batch layout taken from
    'inter_hb_central_atom_coords'.

    Parameters
    ----------
    parameters : Dict[str, Any]
        Must contain 'inter_hb_symmetry' and 'inter_hb_central_atom_coords'.
        Updated in place.
    device : torch.device, optional
        Device for the resulting tensors; forwarded unchanged.
    """
    add_symmetry_matrices(
        parameters,
        sym_key='inter_hb_symmetry',
        coords_key='inter_hb_central_atom_coords',
        device=device,
    )