# Source code for vachoppy.vibration

"""
vachoppy.vibration
==================

Provides the `Vibration` class for calculating the characteristic atomic
vibrational frequency from a molecular dynamics trajectory.

The primary purpose of this module is to determine a suitable time scale for
coarse-graining the main trajectory analysis. The inverse of the calculated mean
vibrational frequency (in THz) gives an estimate of the `t_interval` parameter
(in ps) used throughout `vachoppy`. This allows for a clear, data-driven
distinction between rapid atomic vibrations and slower, diffusive hopping events.

Main Components
---------------
- **Vibration**: A class that analyzes atomic displacements within a short
  trajectory segment. It uses a statistical approach and a Fast Fourier Transform
  (FFT) to determine the mean vibrational frequency.

Typical Usage
-------------
This class is often used to automatically estimate the `t_interval` parameter
before running a full diffusion analysis.

.. code-block:: python

    from vachoppy.core import Site
    from vachoppy.vibration import Vibration

    # 1. First, set up the site information
    site_info = Site(path_structure="path/to/POSCAR", symbol="O")

    # 2. Initialize the Vibration class with a trajectory
    vib_analyzer = Vibration(
        path_traj="path/to/TRAJ_O.h5",
        site=site_info
    )

    # 3. Run the frequency calculation
    vib_analyzer.calculate()

    # 4. Access the result to determine a suitable t_interval
    if vib_analyzer.mean_frequency > 0:
        estimated_t_interval = 1 / vib_analyzer.mean_frequency
        print(f"Estimated t_interval: {estimated_t_interval:.3f} ps")

    # This estimated_t_interval can then be passed to a Calculator object.
"""

from __future__ import annotations

__all__ =['Vibration']

import os
import h5py
import json
import itertools
import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from scipy.stats import norm
# from scipy.spatial.distance import cdist
from joblib import Parallel, delayed

from vachoppy.utils import monitor_performance


# ============================================
#   Helper functions for parallel processing
# ============================================

def _helper_distance_pbc(coord1, coord2, lattice):
    """(Helper) Calculates PBC-aware distance between two fractional coordinates."""
    displacement_frac = coord1 - coord2
    displacement_frac -= np.round(displacement_frac)
    return np.linalg.norm(np.dot(displacement_frac, lattice))

def _helper_get_segment_trajectory(trajectory, jump_detection_radius, lattice):
    """(Helper) Segments a single atom's trajectory into vibrational periods.

    Frames are visited in order while a running mean of the current segment is
    maintained. When frame ``i`` lies farther than ``jump_detection_radius``
    (PBC-aware, Cartesian) from that mean, the current segment is closed and a
    new one starts at frame ``i``. Only segments with more than 10 frames are
    kept.

    Args:
        trajectory (numpy.ndarray): Fractional coordinates of one atom,
            indexed as ``[frame, xyz]``.
        jump_detection_radius (float): Distance threshold, in the Cartesian
            units of ``lattice``, separating vibration from a jump.
        lattice (numpy.ndarray): Lattice matrix used for the fractional ->
            Cartesian conversion.

    Returns:
        list[numpy.ndarray]: Consecutive slices of ``trajectory``. The list is
        empty when the trajectory has fewer than 10 frames or no segment is
        long enough.
    """
    segments = []
    n_frames = len(trajectory)
    if n_frames < 10:
        return segments

    start_index = 0
    # Running sum/count over trajectory[start_index:i]. Updating these
    # incrementally avoids re-averaging the whole growing window on every
    # frame (which made the loop O(n^2)).
    window_sum = np.array(trajectory[0], dtype=np.float64)
    window_count = 1
    for i in range(1, n_frames):
        segment_center = window_sum / window_count
        distance = _helper_distance_pbc(trajectory[i], segment_center, lattice)
        if distance > jump_detection_radius:
            segment = trajectory[start_index:i]
            if len(segment) > 10:
                segments.append(segment)
            # Frame i becomes the first frame of the next segment.
            start_index = i
            window_sum = np.array(trajectory[i], dtype=np.float64)
            window_count = 1
        else:
            window_sum += trajectory[i]
            window_count += 1

    segment_final = trajectory[start_index:]
    if len(segment_final) > 10:
        segments.append(segment_final)
    return segments

def _worker_get_displacements(args):
    """[Parallel Worker] Calculates displacements for a single atom's trajectory.

    Args:
        args (tuple): ``(atom_traj_frac, jump_detection_radius, lattice)``.
            The trajectory is split into vibrational segments with
            ``_helper_get_segment_trajectory``.

    Returns:
        list[numpy.ndarray]: One array per segment. Each array holds, for every
        frame, the Cartesian displacement from that segment's mean fractional
        position, wrapped to the nearest periodic image before conversion.
    """
    traj_frac, radius, lattice = args

    def to_cartesian(seg):
        # Offset from the segment's mean position, minimum-image wrapped.
        rel = seg - np.mean(seg, axis=0)
        return np.dot(rel - np.round(rel), lattice)

    return [to_cartesian(seg)
            for seg in _helper_get_segment_trajectory(traj_frac, radius, lattice)]

def _worker_assign_sites(args):
    """[Parallel Worker] Assigns atoms to the nearest lattice site for a single timestep."""
    positions_at_t, site_positions_frac, site_radius, lattice = args
    positions_at_t, site_positions_frac, site_radius, lattice = args
    # distance_matrix = cdist(positions_at_t, site_positions_frac,
    #                         lambda u, v: _helper_distance_pbc(u, v, lattice))
    all_displacements_frac = positions_at_t[:, np.newaxis, :] - site_positions_frac[np.newaxis, :, :]
    all_displacements_frac -= np.round(all_displacements_frac)
    all_displacements_cart = np.einsum('ijk,kl->ijl', all_displacements_frac, lattice)
    distance_matrix = np.linalg.norm(all_displacements_cart, axis=2)
    
    closest_site_indices = np.argmin(distance_matrix, axis=1)
    min_distances = np.min(distance_matrix, axis=1)
    assignments_at_t = np.full(positions_at_t.shape[0], -1, dtype=int)
    assigned_mask = min_distances < site_radius
    assignments_at_t[assigned_mask] = closest_site_indices[assigned_mask]
    return assignments_at_t

def _worker_get_frequencies(args):
    """[Parallel Worker] Calculates vibrational frequencies for a single atom."""
    atom_traj_frac, atom_assignments, site_positions_frac, dt_s, lattice = args
    frequencies = []
    
    first_valid_step_arr = np.where(atom_assignments > -1)[0]
    if len(first_valid_step_arr) == 0: return []
    first_valid_step = first_valid_step_arr[0]
    
    filtered_traj = atom_traj_frac[first_valid_step:]
    filtered_assign = atom_assignments[first_valid_step:]
    
    jump_indices = np.where(filtered_assign[:-1] != filtered_assign[1:])[0] + 1
    seg_starts = np.insert(jump_indices, 0, 0)
    seg_ends = np.append(jump_indices, len(filtered_assign))
    
    for start, end in zip(seg_starts, seg_ends):
        assigned_site_id = filtered_assign[start]
        if assigned_site_id == -1: continue
        segment = filtered_traj[start:end]
        if len(segment) < 20 : continue
        
        site_center = site_positions_frac[assigned_site_id]
        displacement = segment - site_center
        displacement -= np.round(displacement)
        displacement_cart = np.dot(displacement, lattice)
        
        for axis in range(3):
            disp_axis = displacement_cart[:, axis]
            n = len(disp_axis); hann = np.hanning(n)
            power = np.abs(np.fft.fft(disp_axis * hann))**2
            freq_hz = np.fft.fftfreq(n, d=dt_s)
            mask = freq_hz > 0
            if not np.any(mask): continue
            freqs, power = freq_hz[mask], power[mask]
            frequencies.append(freqs[np.argmax(power)] / 1e12)
    return frequencies

# ============================================

class Vibration:
    """Calculates atomic vibrational frequencies from a molecular dynamics trajectory.

    This class analyzes a short segment of a trajectory to determine the
    characteristic vibrational frequency of atoms. The workflow involves:

    1. Determining a data-driven "site radius" based on atomic displacements.
    2. Assigning atoms to their nearest lattice sites.
    3. Segmenting individual atom trajectories into periods of stable vibration.
    4. Calculating the frequency spectrum for all segments using FFT.

    The main entry point is the `.calculate()` method. The results can be
    visualized with the `.plot_*()` methods or inspected via attributes.

    Args:
        path_traj (str): Path to the HDF5 trajectory file.
        site (Site): An initialized `Site` object containing lattice site information.
        sampling_size (int, optional): Number of initial trajectory frames to use
            for the analysis. Defaults to 5000.
        filter_high_freq (bool, optional): If True, filters out high-frequency
            outliers using the IQR method. Defaults to True.
        verbose (bool, optional): Verbosity flag. Defaults to True.

    Attributes:
        mean_frequency (float): The mean of all calculated vibrational
            frequencies in THz (0 if none were found).
        frequencies (list[float]): All individual vibrational frequencies (THz)
            calculated from the trajectory segments.
        displacements (numpy.ndarray): A flattened array of all measured atomic
            displacements (Å) during vibrational periods.
        mu_displacements (float): Mean of the normal distribution fitted to
            `displacements` (Å).
        sigma_displacements (float): Standard deviation of that fit (Å).
        site_radius (float): The calculated site radius (2σ of displacements)
            in Å, used for atom-to-site assignment.

    Raises:
        FileNotFoundError: If the `path_traj` file does not exist.
        ValueError: If the HDF5 file is missing required data or metadata.
        IOError: If the HDF5 file cannot be read.
    """

    def __init__(self,
                 path_traj: str,
                 site: Site,
                 sampling_size: int = 5000,
                 filter_high_freq: bool = True,
                 verbose: bool = True):
        self.path_traj = path_traj
        self.site = site
        self.filter_high_freq = filter_high_freq
        self.verbose = verbose

        self._validate_traj(self.path_traj)
        self.site_positions = np.array([s['coord'] for s in self.site.lattice_sites])

        # Trajectory data, filled in by _read_traj().
        self.dt = None
        self.symbol = None
        self.lattice = None
        self.sampling_size = None
        self.positions = None  # fractional coordinates
        self._read_traj(sampling_size)

        # Analysis results, filled in by calculate().
        self.displacements = None
        self.mu_displacements = None
        self.sigma_displacements = None
        self.site_radius = None
        self.frequencies = None
        self.mean_frequency = None

    def _validate_traj(self, path_traj: str) -> None:
        """Validates the structure and content of the HDF5 trajectory file.

        Checks the file extension and existence, the presence of the datasets
        'positions' and 'forces', and the metadata keys 'symbol', 'nsw', 'dt',
        'temperature', 'atom_counts' and 'lattice'.

        Args:
            path_traj (str): The file path to validate.

        Raises:
            ValueError: If the file extension is not '.h5' or if required
                metadata or datasets are missing.
            FileNotFoundError: If the trajectory file does not exist.
            IOError: If the file cannot be read as an HDF5 file.
        """
        if not path_traj.endswith('.h5'):
            raise ValueError(f"Error: Trajectory file must have a .h5 extension, but got '{path_traj}'.")
        if not os.path.isfile(path_traj):
            raise FileNotFoundError(f"Error: Input file '{path_traj}' not found.")

        try:
            with h5py.File(path_traj, "r") as f:
                for dataset in ("positions", "forces"):
                    if dataset not in f:
                        raise ValueError(f"Error: Required dataset '{dataset}' not found in '{path_traj}'.")

                metadata_str = f.attrs.get("metadata")
                if not metadata_str:
                    raise ValueError(f"Error: Required attribute 'metadata' not found in '{path_traj}'.")

                cond = json.loads(metadata_str)
                for key in ("symbol", "nsw", "dt", "temperature", "atom_counts", "lattice"):
                    if key not in cond:
                        raise ValueError(f"Error: Required key '{key}' not found in metadata of '{path_traj}'.")
        except OSError as e:  # IOError is an alias of OSError
            raise IOError(f"Error: Failed to read '{path_traj}' as an HDF5 file. Reason: {e}") from e

    def _read_traj(self, sampling_size: int) -> None:
        """Reads metadata and the first `sampling_size` frames from the HDF5 file.

        The number of frames read is capped by both the metadata 'nsw' value
        and the actual length of the 'positions' dataset.
        """
        with h5py.File(self.path_traj, 'r') as f:
            cond = json.loads(f.attrs['metadata'])
            self.dt = cond.get('dt')
            self.symbol = cond.get('symbol')
            self.lattice = np.array(cond.get('lattice'), dtype=np.float64)

            num_frames = min(cond.get('nsw'), f['positions'].shape[0])
            self.sampling_size = min(sampling_size, num_frames)
            self.positions = f['positions'][:self.sampling_size].astype(np.float64)

    def _filter_frequencies_iqr(self, frequencies: list, verbose: bool | None = None) -> list:
        """Filters high-frequency outliers from a list of frequencies using the IQR method.

        Values strictly above ``Q3 + 1.5 * IQR`` (the upper Tukey fence) are
        removed; values equal to the fence are kept. This matters because the
        FFT frequencies are discrete bins, so exact ties with the fence are common.

        Args:
            frequencies (list): Frequencies in THz.
            verbose (bool | None, optional): Overrides `self.verbose` when given.

        Returns:
            list: The retained frequencies (empty if the input is empty).
        """
        if verbose is None:
            verbose = self.verbose
        if len(frequencies) == 0:
            # np.percentile cannot operate on an empty array.
            return []

        freq_array = np.asarray(frequencies, dtype=np.float64)
        q1, q3 = np.percentile(freq_array, [25, 75])
        upper_bound = q3 + 1.5 * (q3 - q1)
        filtered_frequencies = freq_array[freq_array <= upper_bound]
        removed_count = len(freq_array) - len(filtered_frequencies)

        if verbose:
            print("="*52)
            print(" High-Frequency Filtering Results (IQR)")
            print("="*52)
            print(f" - Cutoff Frequency : {upper_bound:.2f} THz")
            print(f" - Removed Outlier Frequencies : {removed_count} (out of {len(freq_array)})")
        return filtered_frequencies.tolist()

    def _get_site_radius(self,
                         n_jobs: int = -1,
                         jump_detection_radius: float = 1.0,
                         verbose: bool | None = None) -> None:
        """Calculates the vibrational amplitude and determines the site radius.

        Displacements from every vibrational segment (all Cartesian components
        pooled) are fitted with a normal distribution. The site radius is set to
        twice the fitted standard deviation.

        Raises:
            ValueError: If no sufficiently long vibrational segment was found.
        """
        if verbose is None:
            verbose = self.verbose

        n_atoms = self.positions.shape[1]
        tasks = [(self.positions[:, i, :], jump_detection_radius, self.lattice)
                 for i in range(n_atoms)]
        results = Parallel(n_jobs=n_jobs, verbose=0)(
            delayed(_worker_get_displacements)(task)
            for task in tqdm(tasks,
                             desc="Compute Displacement",
                             bar_format='{l_bar}{bar:30}{r_bar}',
                             ascii=True,
                             disable=not verbose)
        )

        all_displacements_cart = list(itertools.chain.from_iterable(results))
        if not all_displacements_cart:
            raise ValueError("Could not find any valid vibrational segments to analyze.")

        self.displacements = np.concatenate(all_displacements_cart).flatten()
        self.mu_displacements, self.sigma_displacements = norm.fit(self.displacements)
        self.site_radius = 2 * self.sigma_displacements

    def plot_displacements(self,
                           bins: int = 50,
                           disp: bool = True,
                           save: bool = True,
                           title: str | None = "Displacement Distribution",
                           filename: str = "displacement.png",
                           dpi: int = 300) -> None:
        """Plots a histogram of the atomic displacements with a Gaussian fit.

        Args:
            bins (int, optional): Number of bins for the histogram. Defaults to 50.
            disp (bool, optional): If True, displays the plot. Defaults to True.
            save (bool, optional): If True, saves the plot to a file. Defaults to True.
            title (str | None, optional): A custom title for the plot.
            filename (str, optional): Filename for the saved plot.
            dpi (int, optional): Resolution for the saved figure.

        Raises:
            AttributeError: If `.calculate()` has not been run yet.
        """
        if self.displacements is None:
            raise AttributeError("Displacement data not available. Please run the .calculate() method first.")

        fig = plt.figure(figsize=(10, 6))
        plt.hist(
            self.displacements, bins=bins, density=True, color='skyblue',
            alpha=0.7, edgecolor='black', label=f"{self.symbol} Displacements"
        )
        xmin, xmax = plt.xlim()
        x = np.linspace(xmin, xmax, 200)
        p = norm.pdf(x, self.mu_displacements, self.sigma_displacements)
        plt.plot(x, p, 'r-', linewidth=2, label="Gaussian Fit")
        plt.title(title)
        plt.xlabel("Displacement (Ang)", fontsize=12)
        plt.ylabel("Probability Density", fontsize=12)
        plt.legend()
        plt.grid(True, linestyle='--')
        if save:
            plt.savefig(filename, dpi=dpi)
        if disp:
            plt.show()
        else:
            # Release figures that are never shown so repeated calls do not
            # accumulate open figures in pyplot's global state.
            plt.close(fig)

    def plot_frequencies(self,
                         bins: int = 50,
                         disp: bool = True,
                         save: bool = True,
                         title: str | None = "Frequency Distribution",
                         filename: str = "frequency.png",
                         dpi: int = 300) -> None:
        """Plots a histogram of the calculated vibrational frequencies.

        Args:
            bins (int, optional): Number of bins for the histogram. Defaults to 50.
            disp (bool, optional): If True, displays the plot. Defaults to True.
            save (bool, optional): If True, saves the plot to a file. Defaults to True.
            title (str | None, optional): A custom title for the plot.
            filename (str, optional): Filename for the saved plot.
            dpi (int, optional): Resolution for the saved figure.

        Raises:
            AttributeError: If `.calculate()` has not been run yet.
        """
        if self.frequencies is None:
            raise AttributeError("Frequency data not available. Please run the .calculate() method first.")

        fig = plt.figure(figsize=(10, 6))
        plt.hist(
            self.frequencies, bins=bins, density=True, color='mediumpurple',
            alpha=0.7, edgecolor='black', label=f'{self.symbol} Frequencies'
        )
        plt.axvline(
            self.mean_frequency, color='r', linestyle='--', linewidth=2,
            label=f'Mean: {self.mean_frequency:.2f} THz'
        )
        plt.title(title)
        plt.xlabel('Frequency (THz)', fontsize=12)
        plt.ylabel('Probability Density', fontsize=12)
        plt.legend()
        plt.grid(True, linestyle='--')
        if save:
            plt.savefig(filename, dpi=dpi)
        if disp:
            plt.show()
        else:
            # Release figures that are never shown (see plot_displacements).
            plt.close(fig)

    @monitor_performance
    def calculate(self,
                  n_jobs: int = -1,
                  jump_detection_radius: float = 1.0,
                  verbose: bool | None = None) -> None:
        """Executes the full vibrational frequency analysis workflow.

        Steps:

        1. Calculates the vibrational amplitude to determine a site radius.
        2. Assigns each atom to its nearest lattice site at each timestep.
        3. Segments the trajectory for each atom based on these site assignments.
        4. Calculates the vibrational frequencies for all segments using FFT.
        5. Optionally filters high-frequency outliers.

        The results are stored in the object's attributes (e.g. `self.mean_frequency`).

        Args:
            n_jobs (int, optional): The number of CPU cores for parallel
                processing. -1 uses all available cores. Defaults to -1.
            jump_detection_radius (float, optional): The radius (Å) used to
                distinguish between vibrations and jumps during the initial
                amplitude estimation. Defaults to 1.0.
            verbose (bool | None, optional): Overrides the class-level verbosity
                for this run. If None, the class-level setting is used.
                Defaults to None.
        """
        if verbose is None:
            verbose = self.verbose

        # 1. Site radius estimation
        self._get_site_radius(n_jobs=n_jobs,
                              jump_detection_radius=jump_detection_radius,
                              verbose=verbose)

        # 2. Site assignment, one task per time step
        n_steps, n_atoms, _ = self.positions.shape
        dt_s = self.dt * 1e-15  # dt is treated as fs; FFT frequencies then come out in Hz
        site_assign_tasks = [(self.positions[i], self.site_positions, self.site_radius, self.lattice)
                             for i in range(n_steps)]
        site_assignments_list = Parallel(n_jobs=n_jobs, verbose=0)(
            delayed(_worker_assign_sites)(task)
            for task in tqdm(site_assign_tasks,
                             desc='Capture Vibrations ',
                             bar_format='{l_bar}{bar:30}{r_bar}',
                             ascii=True,
                             disable=not verbose)
        )
        site_assignments = np.array(site_assignments_list)  # (n_steps, n_atoms)

        # 3-4. Segmentation and FFT, one task per atom
        freq_tasks = [(self.positions[:, i, :], site_assignments[:, i], self.site_positions, dt_s, self.lattice)
                      for i in range(n_atoms)]
        results = Parallel(n_jobs=n_jobs, verbose=0)(
            delayed(_worker_get_frequencies)(task)
            for task in tqdm(freq_tasks,
                             desc="Compute Frequency   ",
                             bar_format='{l_bar}{bar:30}{r_bar}',
                             ascii=True,
                             disable=not verbose)
        )
        frequencies = list(itertools.chain.from_iterable(results))
        if verbose:
            print("")

        # 5. Optional high-frequency outlier removal
        if self.filter_high_freq:
            frequencies = self._filter_frequencies_iqr(frequencies, verbose=verbose)

        self.frequencies = frequencies
        self.mean_frequency = np.mean(self.frequencies) if self.frequencies else 0

        if verbose:
            self.summary()

    def summary(self):
        """Prints a formatted summary of the vibrational analysis results.

        The summary includes the fitted displacement width, the site radius,
        the number of frequencies and the mean vibrational frequency.

        Raises:
            AttributeError: If `.calculate()` has not been run yet.
        """
        if self.frequencies is None or self.sigma_displacements is None:
            raise AttributeError("Analysis results not available. Please run the .calculate() method first.")

        print("="*52)
        print(" Vibrational Analysis Results Summary")
        print("="*52)
        print(f" - Mean Vibrational Amplitude (σ) : {self.sigma_displacements:.3f} Ang")
        print(f" - Determined Site Radius (2 x σ) : {self.site_radius:.3f} Ang")
        print(f" - Total Vibrational Frequencies : {len(self.frequencies)} found")
        print(f" - Mean Vibrational Frequency : {self.mean_frequency:.3f} THz")
        print("="*52 + "\n")