cell_utils module

Module: cell_utils.py

Batch utilities for converting unit-cell parameters to normalized/scaled forms and constructing cell matrices, leveraging PyTorch for GPU acceleration.

Dependencies

torch

cell_utils.compute_scaled_cell(lengths, angles, device=None)[source]

Compute scaled cell parameters for a batch of structures.

Parameters:
  • lengths (torch.Tensor, shape (B, 3)) – Unit-cell lengths [a, b, c] for each structure in the batch.

  • angles (torch.Tensor, shape (B, 3)) – Unit-cell angles [α, β, γ] in degrees for each structure in the batch.

  • device (torch.device, optional) – Target device for output. If None, uses the device of lengths.

Returns:

Scaled cell feature vectors on the specified device, where each row is [1.0, b/a, c/a, α/90, β/90, γ/90].

Return type:

torch.Tensor, shape (B, 6)

cell_utils.compute_cell_matrix_batch(lengths, angles, device=None)[source]

Compute a batch of real-space cell matrices from unit-cell parameters.

Parameters:
  • lengths (torch.Tensor, shape (B, 3)) – Unit-cell lengths [a, b, c] for each structure in the batch.

  • angles (torch.Tensor, shape (B, 3)) – Unit-cell angles [α, β, γ] in degrees for each structure in the batch.

  • device (torch.device, optional) – Target device for output. If None, uses the device of lengths.

Returns:

Real-space cell matrices on the specified device, where each 3×3 matrix M has rows:

[a, 0, 0]
[b·cos(γ), b·sin(γ), 0]
[c·cos(β), c·(cos(α) – cos(β)·cos(γ))/sin(γ), c_z]

with c_z = sqrt(c² – M[2,0]² – M[2,1]²).

Return type:

torch.Tensor, shape (B, 3, 3)

Crystallographic Unit Cell Transformations and Matrix Operations

The cell_utils module provides GPU-accelerated batch processing for crystallographic unit cell transformations, real-space matrix computations, and coordinate system conversions. It handles the mathematical operations required to convert between different representations of crystal structures.

Key Features:

  • Unit cell matrix construction - Convert cell parameters to transformation matrices

  • Coordinate transformations - Fractional ↔ Cartesian conversions

  • Cell parameter scaling - Normalized cell parameter representations

  • Batch processing - Efficient GPU operations for large structure datasets

  • Numerical stability - Robust handling of edge cases and degenerate cells

  • Volume calculations - Unit cell volumes and geometric properties

Core Functions

cell_utils.compute_cell_matrix_batch(lengths, angles, device=None)[source]

Real-Space Unit Cell Matrix Construction

Converts crystallographic unit cell parameters (a, b, c, α, β, γ) into real-space transformation matrices for coordinate conversions and distance calculations.

Mathematical Foundation:

Each row of M is a lattice vector, so fractional coordinates (u, v, w) map to the Cartesian point u·a + v·b + w·c. In column-vector form this reads:

\[\begin{split}\begin{pmatrix} x \\ y \\ z \end{pmatrix} = \mathbf{M}^{\mathsf{T}} \begin{pmatrix} u \\ v \\ w \end{pmatrix}\end{split}\]

Where the transformation matrix M is constructed as:

\[\begin{split}\mathbf{M} = \begin{pmatrix} a & 0 & 0 \\ b \cos\gamma & b \sin\gamma & 0 \\ c \cos\beta & c \frac{\cos\alpha - \cos\beta \cos\gamma}{\sin\gamma} & c_z \end{pmatrix}\end{split}\]

With \(c_z = \sqrt{c^2 - c^2\cos^2\beta - \left(c \frac{\cos\alpha - \cos\beta \cos\gamma}{\sin\gamma}\right)^2}\)
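The construction above can be checked by hand for a single cell. The following is a minimal plain-Python sketch (no PyTorch) and not the module's implementation; `cell_matrix_reference`, `norm` and `angle_between_deg` are hypothetical helper names, and the row layout is taken from the matrix given above.

```python
import math

def cell_matrix_reference(lengths, angles_deg):
    """Build one 3x3 cell matrix whose rows are the lattice vectors a, b, c."""
    a, b, c = lengths
    alpha, beta, gamma = (math.radians(x) for x in angles_deg)
    c_x = c * math.cos(beta)
    c_y = c * (math.cos(alpha) - math.cos(beta) * math.cos(gamma)) / math.sin(gamma)
    c_z = math.sqrt(c * c - c_x * c_x - c_y * c_y)
    return [
        [a, 0.0, 0.0],
        [b * math.cos(gamma), b * math.sin(gamma), 0.0],
        [c_x, c_y, c_z],
    ]

def norm(v):
    """Euclidean length of a 3-vector."""
    return math.sqrt(sum(x * x for x in v))

def angle_between_deg(u, v):
    """Angle between two 3-vectors, in degrees."""
    dot = sum(x * y for x, y in zip(u, v))
    return math.degrees(math.acos(dot / (norm(u) * norm(v))))

M = cell_matrix_reference([6.0, 8.0, 10.0], [75.0, 85.0, 95.0])
a_vec, b_vec, c_vec = M

# The rows should reproduce the input lengths and angles up to rounding error
print([round(norm(row), 6) for row in M])
print(round(angle_between_deg(b_vec, c_vec), 6),  # alpha
      round(angle_between_deg(a_vec, c_vec), 6),  # beta
      round(angle_between_deg(a_vec, b_vec), 6))  # gamma
```

Recovering the lengths and angles from the rows in this way is a convenient spot check for the batched output as well.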

Parameters:

  • lengths (torch.Tensor, shape (B, 3)) - Unit cell lengths [a, b, c] in Ångstroms

  • angles (torch.Tensor, shape (B, 3)) - Unit cell angles [α, β, γ] in degrees

  • device (torch.device, optional) - Target computation device

Returns:

  • torch.Tensor, shape (B, 3, 3) - Real-space cell transformation matrices

Physical Interpretation:

  • Row 1: a vector along the x-axis

  • Row 2: b vector in the xy-plane

  • Row 3: c vector completing the right-handed system
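Because the lattice vectors are stored as matrix rows, a fractional point (u, v, w) maps to u·a + v·b + w·c, which is the row vector of fractional coordinates multiplied by M. The sketch below illustrates this in plain Python for one hexagonal cell; the cell values and the helper name `frac_to_cart` are assumptions chosen for illustration.

```python
import math

# Hexagonal cell (a = b = 2.46 Å, c = 6.71 Å, γ = 120°), rows = lattice vectors
a, c = 2.46, 6.71
gamma = math.radians(120.0)
M = [
    [a, 0.0, 0.0],
    [a * math.cos(gamma), a * math.sin(gamma), 0.0],
    [0.0, 0.0, c],
]

def frac_to_cart(frac, cell):
    """Return u*a_vec + v*b_vec + w*c_vec, i.e. the row vector frac times cell."""
    return [sum(frac[i] * cell[i][j] for i in range(3)) for j in range(3)]

r = frac_to_cart([1 / 3, 2 / 3, 0.5], M)
print(r)  # approximately (0, 1.420, 3.355) Å
```

Multiplying by the transpose instead would give a different point for any non-orthogonal cell, so a skewed cell like this one is a more useful test case than a cubic one.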

Usage Examples:

import torch
from cell_utils import compute_cell_matrix_batch

# Cubic crystal system (simple case)
lengths_cubic = torch.tensor([[5.0, 5.0, 5.0]])  # a=b=c=5Å
angles_cubic = torch.tensor([[90.0, 90.0, 90.0]])  # α=β=γ=90°

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

cell_matrix_cubic = compute_cell_matrix_batch(lengths_cubic, angles_cubic, device)
print("Cubic cell matrix:")
print(cell_matrix_cubic[0])
# Expected: approximately 5 × identity (cos 90° is not exactly zero in floating point)

# Triclinic crystal system (general case)
lengths_triclinic = torch.tensor([[6.0, 8.0, 10.0]])
angles_triclinic = torch.tensor([[75.0, 85.0, 95.0]])

cell_matrix_triclinic = compute_cell_matrix_batch(lengths_triclinic, angles_triclinic, device)
print("Triclinic cell matrix:")
print(cell_matrix_triclinic[0])

# Calculate unit cell volume using determinant
volume = torch.det(cell_matrix_triclinic)
print(f"Unit cell volume: {volume.item():.2f} Å³")

Coordinate Conversion Applications:

def fractional_to_cartesian(frac_coords, cell_matrix):
    """Convert fractional to Cartesian coordinates."""
    # frac_coords: (B, N, 3), cell_matrix: (B, 3, 3) with lattice vectors as rows,
    # so each Cartesian point is the row vector frac @ M
    return torch.matmul(frac_coords, cell_matrix)

def cartesian_to_fractional(cart_coords, cell_matrix):
    """Convert Cartesian to fractional coordinates."""
    # Inverse of the mapping above: frac = cart @ M^(-1)
    cell_matrix_inv = torch.inverse(cell_matrix)
    return torch.matmul(cart_coords, cell_matrix_inv)

# Example usage (coordinates must live on the same device as the cell matrix)
frac_coords = torch.tensor([[[0.25, 0.25, 0.25]]], device=device)  # Fractional coordinates
cart_coords = fractional_to_cartesian(frac_coords, cell_matrix_cubic)
print(f"Cartesian coordinates: {cart_coords[0, 0]}")  # [1.25, 1.25, 1.25]

Crystal System Applications:

def analyze_crystal_system(lengths, angles, tolerance=1e-3):
    """Classify crystal system from unit cell parameters."""
    a, b, c = lengths.unbind(dim=-1)
    alpha, beta, gamma = angles.unbind(dim=-1)

    # Check for cubic
    cubic = (torch.abs(a - b) < tolerance) & (torch.abs(b - c) < tolerance) & \
            (torch.abs(alpha - 90) < tolerance) & (torch.abs(beta - 90) < tolerance) & \
            (torch.abs(gamma - 90) < tolerance)

    # Check for tetragonal
    tetragonal = (torch.abs(a - b) < tolerance) & (torch.abs(c - a) > tolerance) & \
                 (torch.abs(alpha - 90) < tolerance) & (torch.abs(beta - 90) < tolerance) & \
                 (torch.abs(gamma - 90) < tolerance)

    # Check for hexagonal
    hexagonal = (torch.abs(a - b) < tolerance) & (torch.abs(c - a) > tolerance) & \
                (torch.abs(alpha - 90) < tolerance) & (torch.abs(beta - 90) < tolerance) & \
                (torch.abs(gamma - 120) < tolerance)

    # Classify
    systems = []
    for i in range(len(lengths)):
        if cubic[i]:
            systems.append("cubic")
        elif tetragonal[i]:
            systems.append("tetragonal")
        elif hexagonal[i]:
            systems.append("hexagonal")
        else:
            # Orthorhombic, trigonal, monoclinic and triclinic cells are not distinguished here
            systems.append("other")

    return systems

Distance Calculations:

def compute_distances_with_pbc(coord1, coord2, cell_matrix, use_minimum_image=True):
    """Compute distances with periodic boundary conditions."""

    # Convert to fractional coordinates for PBC.
    # cell_matrix stores lattice vectors as rows, so frac = cart @ M^(-1)
    cell_inv = torch.inverse(cell_matrix)
    frac1 = torch.matmul(coord1, cell_inv)
    frac2 = torch.matmul(coord2, cell_inv)

    # Compute difference vector
    diff_frac = frac2 - frac1

    if use_minimum_image:
        # Apply minimum image convention
        diff_frac = diff_frac - torch.round(diff_frac)

    # Convert back to Cartesian (cart = frac @ M)
    diff_cart = torch.matmul(diff_frac, cell_matrix)

    # Calculate distance
    distances = torch.norm(diff_cart, dim=-1)
    return distances
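The minimum image step can be illustrated with a small worked example. This is a plain-Python sketch for an assumed orthogonal toy cell with 10 Å edges; `minimum_image_delta` is a hypothetical helper, not part of the module.

```python
def minimum_image_delta(frac1, frac2):
    """Wrap a fractional difference vector into [-0.5, 0.5] per component."""
    return [(q - p) - round(q - p) for p, q in zip(frac1, frac2)]

edge = 10.0  # assumed cubic cell edge in Å

# Two points close to opposite faces of the cell along x
p1 = [0.05, 0.5, 0.5]
p2 = [0.95, 0.5, 0.5]

raw = [q - p for p, q in zip(p1, p2)]
wrapped = minimum_image_delta(p1, p2)

raw_distance = edge * sum(d * d for d in raw) ** 0.5          # about 9 Å through the cell
pbc_distance = edge * sum(d * d for d in wrapped) ** 0.5      # about 1 Å across the boundary
print(raw_distance, pbc_distance)
```

Rounding each fractional component independently is only guaranteed to find the nearest periodic image for orthogonal cells; for strongly skewed cells a search over neighbouring images may be needed.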

cell_utils.compute_scaled_cell(lengths, angles, device=None)[source]

Normalized Cell Parameter Representation

Computes scaled unit cell parameters for machine learning applications and comparative analysis by normalizing lengths and angles.

Scaling Scheme:

The scaling transforms unit cell parameters into a dimensionless 6-component feature vector:

\[\vec{f}_{\text{cell}} = \left[1.0,\ \frac{b}{a},\ \frac{c}{a},\ \frac{\alpha}{90^\circ},\ \frac{\beta}{90^\circ},\ \frac{\gamma}{90^\circ}\right]\]
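As an illustration of the scaling scheme, the following plain-Python sketch evaluates the feature vector for a single cell and shows that it does not change when all lengths are multiplied by a common factor. `scaled_cell_reference` is a hypothetical helper written from the formula above, not the module's code.

```python
def scaled_cell_reference(lengths, angles_deg):
    """Return [1.0, b/a, c/a, alpha/90, beta/90, gamma/90] for one cell."""
    a, b, c = lengths
    return [1.0, b / a, c / a] + [angle / 90.0 for angle in angles_deg]

lengths = [2.46, 2.46, 6.71]   # Å
angles = [90.0, 90.0, 120.0]   # degrees

features = scaled_cell_reference(lengths, angles)

# Doubling every length leaves the dimensionless features unchanged
doubled = scaled_cell_reference([2.0 * x for x in lengths], angles)

print(features)
print(doubled)
```

The first component is always 1.0, so only five of the six entries carry information.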

Parameters:

  • lengths (torch.Tensor, shape (B, 3)) - Unit cell lengths [a, b, c]

  • angles (torch.Tensor, shape (B, 3)) - Unit cell angles [α, β, γ] in degrees

  • device (torch.device, optional) - Target computation device

Returns:

  • torch.Tensor, shape (B, 6) - Scaled cell feature vectors

Advantages of Scaling:

  • Dimensionless representation - Removes length scale dependence

  • Machine learning compatibility - Normalized features for ML models

  • Comparative analysis - Enables cross-dataset comparisons

  • Reduced parameter space - From 6 to 5 independent dimensions

Usage Examples:

import torch
from cell_utils import compute_scaled_cell

# Compare different crystal structures
structures = {
    'diamond': {'lengths': [3.57, 3.57, 3.57], 'angles': [90, 90, 90]},
    'graphite': {'lengths': [2.46, 2.46, 6.71], 'angles': [90, 90, 120]},
    'quartz': {'lengths': [4.91, 4.91, 5.41], 'angles': [90, 90, 120]}
}

for name, params in structures.items():
    # Use an explicit float dtype so integer angle lists do not produce integer tensors
    lengths = torch.tensor([params['lengths']], dtype=torch.float32)
    angles = torch.tensor([params['angles']], dtype=torch.float32)
    scaled = compute_scaled_cell(lengths, angles)

    print(f"{name}: {scaled[0].tolist()}")

# Use for machine learning features
def extract_cell_features(crystal_dataset):
    """Extract scaled cell features for ML."""
    # Each entry is assumed to hold 1-D tensors of shape (3,);
    # stacking them gives the (B, 3) inputs the function expects
    lengths = torch.stack([s['cell_lengths'] for s in crystal_dataset])
    angles = torch.stack([s['cell_angles'] for s in crystal_dataset])
    return compute_scaled_cell(lengths, angles)  # (B, 6)

Statistical Analysis Applications:

def analyze_cell_distributions(scaled_cells):
    """Analyze statistical distributions of scaled cell parameters."""

    # Component labels
    labels = ['1.0', 'b/a', 'c/a', 'α/90°', 'β/90°', 'γ/90°']

    # Compute statistics
    means = scaled_cells.mean(dim=0)
    stds = scaled_cells.std(dim=0)

    print("Cell parameter statistics:")
    for i, label in enumerate(labels):
        print(f"  {label}: {means[i]:.3f} ± {stds[i]:.3f}")

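    # Identify outliers; the first component is constant (1.0), so its std is zero
    # and the division is guarded to avoid NaN z-scores
    z_scores = torch.abs((scaled_cells - means) / stds.clamp_min(1e-12))
    outliers = z_scores > 3.0  # 3-sigma outliers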

    outlier_structures = torch.any(outliers, dim=1)
    print(f"Outlier structures: {outlier_structures.sum().item()}")

    return {
        'means': means,
        'stds': stds,
        'outliers': outlier_structures
    }

Advanced Cell Operations

Volume and Density Calculations

def compute_cell_volumes(cell_matrices):
    """Compute unit cell volumes from transformation matrices."""
    # Volume is the determinant of the cell matrix
    volumes = torch.det(cell_matrices)
    return torch.abs(volumes)  # Ensure positive volumes

def compute_cell_densities(cell_matrices, molecular_weights, z_values):
    """Compute crystal densities from cell parameters."""
    volumes = compute_cell_volumes(cell_matrices)  # Å³

    # Convert to cm³ and compute density
    avogadro = 6.022e23
    volumes_cm3 = volumes * 1e-24  # Å³ to cm³

    # Density = (Z × MW) / (V × N_A)
    densities = (z_values * molecular_weights) / (volumes_cm3 * avogadro)
    return densities  # g/cm³

def compute_reciprocal_lattice(cell_matrices):
    """Compute reciprocal lattice vectors."""
    # Reciprocal lattice is 2π × (direct lattice)^(-T)
    volumes = compute_cell_volumes(cell_matrices)

    # B = 2π × A^(-T)
    reciprocal = 2 * torch.pi * torch.inverse(cell_matrices).transpose(-2, -1)
    reciprocal_volumes = (2 * torch.pi)**3 / volumes

    return reciprocal, reciprocal_volumes
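The determinant-based volume can be cross-checked against the closed-form expression V = a·b·c·sqrt(1 + 2·cosα·cosβ·cosγ − cos²α − cos²β − cos²γ), which uses the same discriminant as the validation check on this page. The sketch below is plain Python for a single cell; `cell_volume_reference` is a hypothetical helper name.

```python
import math

def cell_volume_reference(lengths, angles_deg):
    """Unit-cell volume (in Å³ for lengths in Å) from lengths and angles in degrees."""
    a, b, c = lengths
    ca, cb, cg = (math.cos(math.radians(x)) for x in angles_deg)
    discriminant = 1.0 + 2.0 * ca * cb * cg - ca * ca - cb * cb - cg * cg
    if discriminant <= 0.0:
        raise ValueError("angles do not describe a cell with positive volume")
    return a * b * c * math.sqrt(discriminant)

print(cell_volume_reference([5.0, 5.0, 5.0], [90.0, 90.0, 90.0]))      # cubic, 5³ = 125 Å³
print(cell_volume_reference([2.46, 2.46, 6.71], [90.0, 90.0, 120.0]))  # hexagonal: a²·c·sin(120°)
```

A disagreement between this value and `torch.det` of the corresponding cell matrix (beyond rounding error) usually points to a unit mix-up such as radians versus degrees.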

Metric Tensor Operations

def compute_metric_tensor(cell_matrices):
    """Compute metric tensor for distance calculations."""
    # Lattice vectors are stored as rows, so G_ij = a_i · a_j, i.e. G = A × A^T
    metric = torch.matmul(cell_matrices, cell_matrices.transpose(-2, -1))
    return metric

def compute_distance_squared_fractional(frac_coords1, frac_coords2, metric_tensor):
    """Compute squared distances in fractional coordinates using metric tensor."""
    # frac_coords1, frac_coords2: (B, N, 3), metric_tensor: (B, 3, 3)
    diff = frac_coords2 - frac_coords1

    # d² = Δu × G × Δu^T, evaluated for every point in every structure
    dist_sq = torch.einsum('bni,bij,bnj->bn', diff, metric_tensor, diff)
    return dist_sq  # (B, N)
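The metric tensor can also be written directly from the cell parameters, since its entries are the dot products of the lattice vectors (a², a·b·cosγ, and so on). The following plain-Python sketch, using hypothetical helper names, evaluates a distance from fractional coordinates without building the cell matrix.

```python
import math

def metric_tensor_reference(lengths, angles_deg):
    """Metric tensor G with G[i][j] equal to the dot product of lattice vectors i and j."""
    a, b, c = lengths
    ca, cb, cg = (math.cos(math.radians(x)) for x in angles_deg)
    return [
        [a * a, a * b * cg, a * c * cb],
        [a * b * cg, b * b, b * c * ca],
        [a * c * cb, b * c * ca, c * c],
    ]

def distance_from_fractional(frac1, frac2, G):
    """Distance between two fractional points using d² = Δu · G · Δu."""
    d = [q - p for p, q in zip(frac1, frac2)]
    return math.sqrt(sum(d[i] * G[i][j] * d[j] for i in range(3) for j in range(3)))

# Hexagonal cell: the points (0, 0, 0) and (1/3, 2/3, 0) are a/sqrt(3) apart
G = metric_tensor_reference([2.46, 2.46, 6.71], [90.0, 90.0, 120.0])
dist = distance_from_fractional([0.0, 0.0, 0.0], [1 / 3, 2 / 3, 0.0], G)
print(dist)
```

No periodic wrapping is applied here; the difference vector is used as given.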

Crystal System Classification

def classify_crystal_systems_batch(lengths, angles, tolerance=1.0):
    """Classify crystal systems for a batch of structures."""
    a, b, c = lengths.unbind(dim=-1)
    alpha, beta, gamma = angles.unbind(dim=-1)

    # Define classification criteria
    same_ab = torch.abs(a - b) < tolerance
    same_bc = torch.abs(b - c) < tolerance
    same_ac = torch.abs(a - c) < tolerance

    angle_90 = lambda x: torch.abs(x - 90) < tolerance
    angle_120 = lambda x: torch.abs(x - 120) < tolerance

    all_90 = angle_90(alpha) & angle_90(beta) & angle_90(gamma)
    alpha_beta_90 = angle_90(alpha) & angle_90(beta)
    gamma_120 = angle_120(gamma)

    # Classification logic (based on cell metrics only, not on symmetry operations)
    cubic = same_ab & same_bc & all_90
    tetragonal = same_ab & ~same_bc & all_90
    orthorhombic = ~same_ab & ~same_bc & ~same_ac & all_90
    hexagonal = same_ab & ~same_bc & alpha_beta_90 & gamma_120
    trigonal = same_ab & same_bc & ~all_90  # Simplified (rhombohedral setting)
    monoclinic = angle_90(alpha) & angle_90(gamma) & ~angle_90(beta)  # β ≠ 90°

    # Convert to categorical labels. Triclinic (6) is the fallback for anything
    # unmatched, and more specific systems are assigned last so they take precedence.
    systems = torch.full((lengths.shape[0],), 6, dtype=torch.long, device=lengths.device)
    systems[trigonal] = 4      # Trigonal
    systems[monoclinic] = 5    # Monoclinic
    systems[hexagonal] = 3     # Hexagonal
    systems[orthorhombic] = 2  # Orthorhombic
    systems[tetragonal] = 1    # Tetragonal
    systems[cubic] = 0         # Cubic

    return systems

Performance Optimization

Memory-Efficient Matrix Operations

def batch_cell_operations_memory_efficient(lengths, angles, batch_size=1000):
    """Process large datasets with memory management."""
    n_structures = lengths.shape[0]
    device = lengths.device

    cell_matrices = []
    scaled_cells = []

    for i in range(0, n_structures, batch_size):
        end_idx = min(i + batch_size, n_structures)
        batch_lengths = lengths[i:end_idx]
        batch_angles = angles[i:end_idx]

        # Process batch
        batch_matrices = compute_cell_matrix_batch(batch_lengths, batch_angles, device)
        batch_scaled = compute_scaled_cell(batch_lengths, batch_angles, device)

        # Move to CPU to save GPU memory
        cell_matrices.append(batch_matrices.cpu())
        scaled_cells.append(batch_scaled.cpu())

        # Clear GPU cache
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return torch.cat(cell_matrices), torch.cat(scaled_cells)

Vectorized Distance Calculations

def compute_pairwise_distances_efficient(coords, cell_matrices, max_distance=15.0):
    """Efficiently compute pairwise distances with distance cutoff."""
    B, N, _ = coords.shape
    device = coords.device

    # Convert to fractional coordinates (lattice vectors are rows, so frac = cart @ M^(-1))
    cell_inv = torch.inverse(cell_matrices)
    frac_coords = torch.matmul(coords, cell_inv)

    # Compute all pairwise differences
    diff = frac_coords.unsqueeze(2) - frac_coords.unsqueeze(1)  # (B, N, N, 3)

    # Apply minimum image convention
    diff = diff - torch.round(diff)

    # Convert back to Cartesian for distance calculation
    diff_cart = torch.matmul(diff, cell_matrices.unsqueeze(1))  # Broadcasting
    distances = torch.norm(diff_cart, dim=-1)

    # Apply distance cutoff
    mask = distances < max_distance
    distances = distances * mask.float()

    return distances, mask

Error Handling and Validation

Cell Parameter Validation

def validate_cell_parameters(lengths, angles):
    """Validate crystallographic cell parameters."""

    # Check for positive lengths
    if torch.any(lengths <= 0):
        raise ValueError("Cell lengths must be positive")

    # Check angle ranges
    if torch.any((angles <= 0) | (angles >= 180)):
        raise ValueError("Cell angles must be between 0° and 180°")

    # Check for physically reasonable values
    if torch.any(lengths > 1000):  # Unreasonably large cell
        raise ValueError("Cell lengths exceed reasonable limits (>1000 Å)")

    if torch.any(lengths < 1):  # Unreasonably small cell
        raise ValueError("Cell lengths below reasonable limits (<1 Å)")

    # Check for degenerate cases: the cell has non-zero volume only if
    # 1 + 2·cosα·cosβ·cosγ − cos²α − cos²β − cos²γ is positive
    cos_alpha, cos_beta, cos_gamma = torch.cos(torch.deg2rad(angles.float())).unbind(dim=-1)
    discriminant = 1 + 2*cos_alpha*cos_beta*cos_gamma - cos_alpha**2 - cos_beta**2 - cos_gamma**2

    invalid = torch.nonzero(discriminant <= 0).flatten().tolist()
    if invalid:
        raise ValueError(f"Invalid cell parameters for structures {invalid}: cannot construct valid unit cell")

Debugging and Diagnostics

def diagnose_cell_matrices(cell_matrices, expected_volumes=None):
    """Diagnose cell matrix computations for quality control."""

    print("Cell Matrix Diagnostics:")

    # Check for valid matrices
    determinants = torch.det(cell_matrices)
    volumes = torch.abs(determinants)

    print(f"  Volume range: {volumes.min():.2f} - {volumes.max():.2f} Å³")
    print(f"  Mean volume: {volumes.mean():.2f} ± {volumes.std():.2f} Å³")

    # Check for negative determinants (left-handed systems)
    negative_det = determinants < 0
    if negative_det.any():
        print(f"  Warning: {negative_det.sum().item()} structures have negative determinants")

    # Check for degenerate cells (zero volume)
    zero_volume = volumes < 1e-6
    if zero_volume.any():
        print(f"  Error: {zero_volume.sum().item()} structures have near-zero volumes")

    # Recover the inter-vector angles as a consistency check
    for i in range(min(5, len(cell_matrices))):  # Check first 5 structures
        M = cell_matrices[i]

        # Rows of M are the lattice vectors a, b, c
        a_vec = M[0]
        b_vec = M[1]
        c_vec = M[2]

        ab_angle = torch.acos(torch.dot(a_vec, b_vec) / (torch.norm(a_vec) * torch.norm(b_vec)))
        ac_angle = torch.acos(torch.dot(a_vec, c_vec) / (torch.norm(a_vec) * torch.norm(c_vec)))
        bc_angle = torch.acos(torch.dot(b_vec, c_vec) / (torch.norm(b_vec) * torch.norm(c_vec)))

        ab_deg = ab_angle * 180 / torch.pi
        ac_deg = ac_angle * 180 / torch.pi
        bc_deg = bc_angle * 180 / torch.pi

        # The a-b, a-c and b-c angles correspond to γ, β and α respectively
        print(f"  Structure {i} vector angles (γ, β, α): {ab_deg:.1f}°, {ac_deg:.1f}°, {bc_deg:.1f}°")

Integration Examples

Complete Crystal Analysis Pipeline

def comprehensive_cell_analysis(crystal_dataset):
    """Perform complete crystallographic cell analysis."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Extract cell parameters
    lengths = torch.stack([s['cell_lengths'] for s in crystal_dataset])
    angles = torch.stack([s['cell_angles'] for s in crystal_dataset])

    # Validate parameters
    validate_cell_parameters(lengths, angles)

    # Compute transformations
    cell_matrices = compute_cell_matrix_batch(lengths, angles, device)
    scaled_cells = compute_scaled_cell(lengths, angles, device)

    # Compute derived properties
    volumes = compute_cell_volumes(cell_matrices)
    crystal_systems = classify_crystal_systems_batch(lengths, angles)

    # Analyze distributions
    stats = analyze_cell_distributions(scaled_cells)

    return {
        'cell_matrices': cell_matrices,
        'scaled_cells': scaled_cells,
        'volumes': volumes,
        'crystal_systems': crystal_systems,
        'statistics': stats
    }

Cross-References

External Dependencies:

  • PyTorch - Tensor operations and GPU acceleration

  • NumPy - Mathematical functions and array operations
