# Source code for cell_utils

"""
Module: cell_utils.py

Batch utilities for converting unit-cell parameters to normalized/scaled forms
and constructing cell matrices, leveraging PyTorch for GPU acceleration.

Dependencies
------------
torch
"""
import math
import torch
from typing import Optional

# Public API: the names exported by `from cell_utils import *`.
__all__ = [
    'compute_scaled_cell',
    'compute_cell_matrix_batch',
    ]


def compute_scaled_cell(
    lengths: torch.Tensor,
    angles: torch.Tensor,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Compute scaled cell parameters for a batch of structures.

    Lengths are normalized by the first cell length ``a`` and angles are
    divided by 90 degrees.

    Parameters
    ----------
    lengths : torch.Tensor, shape (..., 3)
        Unit-cell lengths [a, b, c]. Usually (B, 3) for a batch of B
        structures; any number of leading dimensions (including none)
        is accepted.
    angles : torch.Tensor, shape (..., 3)
        Unit-cell angles [α, β, γ] in degrees, with the same leading
        dimensions as `lengths`.
    device : torch.device, optional
        Target device for output. If None, uses the device of `lengths`.

    Returns
    -------
    torch.Tensor, shape (..., 6)
        Scaled cell feature vectors on the specified device, where each
        row is [1.0, b/a, c/a, α/90, β/90, γ/90].

    Notes
    -----
    Inputs are not validated. In particular, a zero ``a`` is divided by
    as-is and produces non-finite ratios.
    """
    if device is None:
        device = lengths.device

    # Bring both inputs onto the target device before combining them.
    lengths = lengths.to(device)
    scaled_angles = angles.to(device) / 90.0

    # Slice (rather than index) so each component keeps a trailing axis of
    # size 1 and can be concatenated directly along the feature axis.
    a = lengths[..., 0:1]
    b = lengths[..., 1:2]
    c = lengths[..., 2:3]

    # The leading 1.0 is a/a, kept so the layout mirrors [a, b, c].
    ones = torch.ones_like(a)
    return torch.cat([ones, b / a, c / a, scaled_angles], dim=-1)
def compute_cell_matrix_batch(
    lengths: torch.Tensor,
    angles: torch.Tensor,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Compute a batch of real-space cell matrices from unit-cell parameters.

    Parameters
    ----------
    lengths : torch.Tensor, shape (..., 3)
        Unit-cell lengths [a, b, c]. Usually (B, 3) for a batch of B
        structures; any number of leading dimensions (including none)
        is accepted.
    angles : torch.Tensor, shape (..., 3)
        Unit-cell angles [α, β, γ] in degrees, with the same leading
        dimensions as `lengths`.
    device : torch.device, optional
        Target device for output. If None, uses the device of `lengths`.

    Returns
    -------
    torch.Tensor, shape (..., 3, 3)
        Real-space cell matrices on the specified device, where each 3×3
        matrix has rows:
            [a,        0,                                   0  ]
            [b·cos(γ), b·sin(γ),                            0  ]
            [c·cos(β), c·(cos(α) – cos(β)·cos(γ))/sin(γ),   c_z]
        with c_z = sqrt(max(c² – M[2,0]² – M[2,1]², 0)).

    Notes
    -----
    Inputs are not validated. The third row divides by sin(γ), so γ at or
    near 0° or 180° yields non-finite or very large entries. c_z² is
    clamped at zero before the square root, so geometrically inconsistent
    angle triples give c_z = 0 rather than NaN.
    """
    if device is None:
        device = lengths.device

    # Move to the target device and convert angles from degrees to radians.
    lengths = lengths.to(device)
    angles_rad = angles.to(device) * (math.pi / 180.0)

    # Split the trailing axis into per-parameter tensors of shape (...,).
    a, b, c = lengths.unbind(dim=-1)
    cos_a, cos_b, cos_g = torch.cos(angles_rad).unbind(dim=-1)
    sin_g = torch.sin(angles_rad[..., 2])

    # Components of the third lattice vector. c_y is needed both for c_z and
    # for the matrix entry itself, so compute it once.
    c_x = c * cos_b
    c_y = c * (cos_a - cos_b * cos_g) / sin_g
    # Clamp guards against small negative values from round-off (or from
    # inconsistent angles) before taking the square root.
    cz_sq = c**2 - c_x**2 - c_y**2
    c_z = torch.sqrt(torch.clamp(cz_sq, min=0.0))

    zeros = torch.zeros_like(a)
    row0 = torch.stack([a, zeros, zeros], dim=-1)
    row1 = torch.stack([b * cos_g, b * sin_g, zeros], dim=-1)
    row2 = torch.stack([c_x, c_y, c_z], dim=-1)

    # Rows become the second-to-last axis: (..., 3, 3).
    return torch.stack([row0, row1, row2], dim=-2)