Source code for tad_multicharge.model.eeq

# This file is part of tad-multicharge.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Electronegativity equilibration charge model
============================================

Implementation of the electronegativity equilibration model for obtaining
atomic partial charges as well as atom-resolved electrostatic energies.

Example
-------
>>> import torch
>>> from tad_multicharge import eeq
>>> numbers = torch.tensor([7, 7, 1, 1, 1, 1, 1, 1])
>>> positions = torch.tensor([
...     [-2.98334550857544, -0.08808205276728, +0.00000000000000],
...     [+2.98334550857544, +0.08808205276728, +0.00000000000000],
...     [-4.07920360565186, +0.25775116682053, +1.52985656261444],
...     [-1.60526800155640, +1.24380481243134, +0.00000000000000],
...     [-4.07920360565186, +0.25775116682053, -1.52985656261444],
...     [+4.07920360565186, -0.25775116682053, -1.52985656261444],
...     [+1.60526800155640, -1.24380481243134, +0.00000000000000],
...     [+4.07920360565186, -0.25775116682053, +1.52985656261444],
... ])
>>> total_charge = torch.tensor(0.0)
>>> cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
>>> eeq_model = eeq.EEQModel.param2019()
>>> qat, energy = eeq_model.solve(
...     numbers, positions, total_charge, cn, return_energy=True
... )
>>> print(torch.sum(energy, -1))
tensor(-0.1750)
>>> print(qat)
tensor([-0.8347, -0.8347,  0.2731,  0.2886,  0.2731,  0.2731,  0.2886,  0.2731])
"""

from __future__ import annotations

import math
from typing import Literal, overload

import torch
from tad_mctc import storch
from tad_mctc.batch import real_atoms, real_pairs
from tad_mctc.ncoord import coordination_number, erf_count

from ..param import defaults, eeq2019
from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype
from .base import ChargeModel

# Names exported by a star-import of this module; other public helpers
# defined below are still importable explicitly.
__all__ = ["EEQModel", "get_charges"]


class EEQModel(ChargeModel):
    """
    Electronegativity equilibration charge model published in

    - E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher,
      C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, 150, 154122.
      DOI: `10.1063/1.5090222 <https://dx.doi.org/10.1063/1.5090222>`__
    """

    @classmethod
    def param2019(
        cls,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> EEQModel:
        """
        Create the EEQ model from the standard (2019) parametrization.

        Parameters
        ----------
        device : torch.device | None, optional
            PyTorch device for the tensors. Defaults to `None`.
        dtype : torch.dtype | None, optional
            PyTorch floating point type for the tensors. Defaults to `None`,
            in which case the package default dtype is used.

        Returns
        -------
        EEQModel
            Instance of the EEQ charge model class.
        """
        dd: DD = {
            "device": device,
            "dtype": dtype if dtype is not None else get_default_dtype(),
        }

        return cls(
            eeq2019.chi.to(**dd),
            eeq2019.kcn.to(**dd),
            eeq2019.eta.to(**dd),
            eeq2019.rad.to(**dd),
            **dd,
        )

    @overload
    def solve(
        self,
        numbers: Tensor,
        positions: Tensor,
        total_charge: Tensor,
        cn: Tensor,
        return_energy: Literal[False] = False,
        solve_mode: Literal["schur", "linear"] = "schur",
    ) -> Tensor: ...

    @overload
    def solve(
        self,
        numbers: Tensor,
        positions: Tensor,
        total_charge: Tensor,
        cn: Tensor,
        return_energy: Literal[True],
        solve_mode: Literal["schur", "linear"] = "schur",
    ) -> tuple[Tensor, Tensor]: ...

    @overload
    def solve(
        self,
        numbers: Tensor,
        positions: Tensor,
        total_charge: Tensor,
        cn: Tensor,
        return_energy: bool,
        solve_mode: Literal["schur", "linear"] = "schur",
    ) -> Tensor | tuple[Tensor, Tensor]: ...

    def solve(
        self,
        numbers: Tensor,
        positions: Tensor,
        total_charge: Tensor,
        cn: Tensor,
        return_energy: bool = False,
        solve_mode: Literal["schur", "linear"] = "schur",
    ) -> Tensor | tuple[Tensor, Tensor]:
        """
        Solve the electronegativity equilibration for the partial charges
        minimizing the electrostatic energy.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers of all atoms in the system (shape: ``(..., nat)``).
            Padding atoms (atomic number 0) receive zero charge.
        positions : Tensor
            Cartesian coordinates of the atoms in system
            (shape: ``(..., nat, 3)``).
        total_charge : Tensor
            Total charge of the system. For a single system a scalar or a
            one-element tensor. For a batch, either one value per system
            (shape: ``(...,)``) or ``(..., 1)``; the former is reshaped to
            the latter. A batch of size one is handled as well.
        cn : Tensor
            Coordination numbers for all atoms in the system.
        return_energy : bool, optional
            Return the EEQ energy as well. Defaults to `False`.
        solve_mode : Literal["schur", "linear"], optional
            Choose the solution method for the linear system.

            - ``"schur"``: Use Schur-complement based method with Cholesky
              factorization (default, recommended).
            - ``"linear"``: Solve the full bordered linear system directly.
              Less stable and slower for large systems.

            Defaults to ``"schur"``.

        Returns
        -------
        Tensor | (Tensor, Tensor)
            Partial charges (shape: ``(..., nat)``) or, if
            ``return_energy=True``, a tuple of partial charges and
            atom-resolved electrostatic energies (both ``(..., nat)``).

        Raises
        ------
        RuntimeError
            If the model parameters and ``positions`` live on different
            devices or have different dtypes.
        ValueError
            If ``total_charge`` cannot be brought to the batch shape of
            ``numbers``, or if ``solve_mode`` is unknown.

        Example
        -------
        >>> import torch
        >>> from tad_multicharge import eeq
        >>> numbers = torch.tensor([7, 1, 1, 1])
        >>> positions = torch.tensor([
        ...     [+0.00000000000000, +0.00000000000000, -0.54524837997150],
        ...     [-0.88451840382282, +1.53203081565085, +0.18174945999050],
        ...     [-0.88451840382282, -1.53203081565085, +0.18174945999050],
        ...     [+1.76903680764564, +0.00000000000000, +0.18174945999050],
        ... ], requires_grad=True)
        >>> total_charge = torch.tensor(0.0, requires_grad=True)
        >>> cn = torch.tensor([3.0, 1.0, 1.0, 1.0])
        >>> eeq_model = eeq.EEQModel.param2019()
        >>> qat = eeq_model.solve(numbers, positions, total_charge, cn)
        >>> qat[0].backward()  # gradient of the nitrogen partial charge
        >>> print(positions.grad)
        tensor([[-9.3132e-09,  7.4506e-09, -4.8064e-02],
                [-1.2595e-02,  2.1816e-02,  1.6021e-02],
                [-1.2595e-02, -2.1816e-02,  1.6021e-02],
                [ 2.5191e-02, -6.9849e-10,  1.6021e-02]])
        >>> print(total_charge.grad)
        tensor(0.6312)
        """
        if self.device != positions.device:
            name = self.__class__.__name__
            raise RuntimeError(
                f"All tensors of '{name}' must be on the same device!\n"
                f"Use `{name}.param2019(device=device)` to correctly set it."
            )

        if self.dtype != positions.dtype:
            name = self.__class__.__name__
            raise RuntimeError(
                f"All tensors of '{name}' must have the same dtype!\n"
                f"Use `{name}.param2019(dtype=dtype)` to correctly set it."
            )

        total_charge = torch.atleast_1d(total_charge)

        # Attempt reshaping to proper batch shape: (n,) -> (n, 1).
        # Comparing against the batch shape of `numbers` also covers a batch
        # of size one, which a pure length check cannot tell apart from a
        # single (unbatched) total charge.
        if (
            total_charge.ndim == numbers.ndim - 1
            and total_charge.shape == numbers.shape[:-1]
        ):
            total_charge = total_charge.unsqueeze(-1)
        elif total_charge.ndim == 1 and len(total_charge) != 1:
            total_charge = total_charge.view(-1, 1)

        if total_charge.ndim != numbers.ndim:
            raise ValueError(
                f"Total charge must have the same number of dimensions as "
                f"the atomic numbers tensor. Got\n"
                f"- atomic numbers: {numbers.shape}\n"
                f"- total charge: {total_charge.shape}"
            )

        eps = torch.tensor(torch.finfo(positions.dtype).eps, **self.dd)
        zero = torch.tensor(0.0, **self.dd)
        stop = torch.sqrt(torch.tensor(2.0 / math.pi, **self.dd))  # sqrt(2/pi)

        real = real_atoms(numbers)
        mask = real_pairs(numbers, mask_diagonal=True)

        # Masked pairs get `eps` instead of 0 so that the division below
        # never produces NaN (their contribution is zeroed out anyway).
        distances = torch.where(
            mask,
            storch.cdist(positions, positions, p=2),
            eps,
        )

        diagonal = mask.new_zeros(mask.shape)
        diagonal.diagonal(dim1=-2, dim2=-1).fill_(True)

        #############
        # Build RHS #
        #############

        cc = torch.where(
            real,
            -self.chi[numbers] + storch.sqrt(cn) * self.kcn[numbers],
            zero,
        )

        ##################
        # Build A matrix #
        ##################

        # radii
        rad = self.rad[numbers]
        rads = rad.unsqueeze(-1) ** 2 + rad.unsqueeze(-2) ** 2
        gamma = torch.where(mask, 1.0 / storch.sqrt(rads), zero)

        # hardness (padding atoms get 1 to keep the matrix non-singular)
        eta = torch.where(
            real,
            self.eta[numbers] + stop / rad,
            torch.tensor(1.0, **self.dd),
        )

        coulomb = torch.where(
            diagonal,
            eta.unsqueeze(-1),
            torch.where(
                mask,
                torch.erf(distances * gamma) / distances,
                zero,
            ),
        )

        ##############
        # Constraint #
        ##############

        # Build 'ones' vector for the constraint (zero for padding atoms)
        constraint = torch.where(
            real,
            torch.ones(numbers.shape, **self.dd),
            torch.zeros(numbers.shape, **self.dd),
        )

        #######################
        # Solve linear system #
        #######################

        if solve_mode == "schur":
            return self._solve_schur(
                cc, constraint, coulomb, total_charge, return_energy
            )
        if solve_mode == "linear":
            return self._solve_linear(
                cc, constraint, coulomb, total_charge, return_energy
            )

        raise ValueError(f"Unknown EEQ solve mode '{solve_mode}'!")

    def _solve_linear(
        self,
        cc: Tensor,
        constraint: Tensor,
        coulomb: Tensor,
        total_charge: Tensor,
        return_energy: bool,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """
        Solve the EEQ linear system via standard linear solver.

        Parameters
        ----------
        cc : Tensor
            Right-hand side vector.
        constraint : Tensor
            Constraint vector (ones for real atoms, zeros else).
        coulomb : Tensor
            Coulomb interaction matrix.
        total_charge : Tensor
            Total charge of the system (shape: ``(..., 1)``).
        return_energy : bool
            Whether to return the electrostatic energy as well.

        Returns
        -------
        Tensor | (Tensor, Tensor)
            Partial charges or tuple of partial charges and energies.
        """
        zeros = torch.zeros(cc.shape[:-1], **self.dd)
        rhs = torch.concat((cc, total_charge), dim=-1)

        # | Coulomb     Constraint |
        # | Constraint  0          |
        matrix = torch.concat(
            (
                torch.concat((coulomb, constraint.unsqueeze(-1)), dim=-1),
                torch.concat(
                    (constraint, zeros.unsqueeze(-1)), dim=-1
                ).unsqueeze(-2),
            ),
            dim=-2,
        )

        x = torch.linalg.solve(matrix, rhs)

        # Drop the Lagrange multiplier (last entry) to obtain the charges.
        q = x[..., :-1]

        # Do not compute energy unless specifically requested
        if not return_energy:
            return q

        # Remove the constraint border for the energy calculation.
        a = matrix[..., :-1, :-1]
        b = rhs[..., :-1]

        # E_scalar = 0.5 * x^T @ A @ x - b @ x^T
        # E_vector = x * (0.5 * A @ x - b)
        e = q * (0.5 * torch.einsum("...ij,...j->...i", a, q) - b)
        return q, e

    def _solve_schur(
        self,
        cc: Tensor,
        constraint: Tensor,
        coulomb: Tensor,
        total_charge: Tensor,
        return_energy: bool,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """
        Solve the EEQ linear system via Schur-complement method.

            [ A    C ][ q ]   [ b ]
            [ C^T  0 ][ m ] = [ Q ]

            q = A^{-1}(b - C m)
            m = (C^T A^{-1} b - Q) / (C^T A^{-1} C)

        Parameters
        ----------
        cc : Tensor
            Right-hand side vector.
        constraint : Tensor
            Constraint vector (ones for real atoms, zeros else).
        coulomb : Tensor
            Coulomb interaction matrix. Must be positive definite, otherwise
            the Cholesky factorization raises.
        total_charge : Tensor
            Total charge of the system (shape: ``(..., 1)``).
        return_energy : bool
            Whether to return the electrostatic energy as well.

        Returns
        -------
        Tensor | (Tensor, Tensor)
            Partial charges or tuple of partial charges and energies.
        """
        # Solve A X = B for two RHS at once: B = [b, C].
        # Stack along last dimension giving `(..., nat, 2)`.
        B = torch.stack((cc, constraint), dim=-1)

        # Factor once via Cholesky: A = L L^T
        # (fast & stable since A is SPD; the bordered system is indefinite)
        L = torch.linalg.cholesky(coulomb)  # (..., nat, nat)

        # Solve A X = B for both RHS at once using the Cholesky factor
        # X[..., :, 0] = A^{-1} b ; X[..., :, 1] = A^{-1} C
        X = torch.cholesky_solve(B, L)  # (..., nat, 2)
        z = X[..., :, 0]  # A^{-1} b, (..., nat)
        y = X[..., :, 1]  # A^{-1} C, (..., nat)

        # m = (C^T z - Q) / (C^T y) ; shape (..., 1)
        num = (constraint * z).sum(dim=-1, keepdim=True) - total_charge
        den = (constraint * y).sum(dim=-1, keepdim=True)
        m = num / den

        # q = z - y * m (broadcast m over the `nat` dimension)
        q = z - y * m  # (..., nat)

        # Do not compute energy unless specifically requested
        if not return_energy:
            return q

        # E_scalar = 0.5 * x^T @ A @ x - b @ x^T
        # E_vector = x * (0.5 * A @ x - b)
        e = q * (0.5 * torch.einsum("...ij,...j->...i", coulomb, q) - cc)
        return q, e
@overload
def get_eeq(
    numbers: Tensor,
    positions: Tensor,
    chrg: Tensor,
    *,
    counting_function: CountingFunction = erf_count,
    rcov: Tensor | None = None,
    cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
    cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
    kcn: Tensor | float | int = defaults.EEQ_KCN,
    return_energy: Literal[False] = False,
    solve_mode: Literal["schur", "linear"] = "schur",
    **kwargs: Any,
) -> Tensor: ...


@overload
def get_eeq(
    numbers: Tensor,
    positions: Tensor,
    chrg: Tensor,
    *,
    counting_function: CountingFunction = erf_count,
    rcov: Tensor | None = None,
    cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
    cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
    kcn: Tensor | float | int = defaults.EEQ_KCN,
    return_energy: Literal[True],
    solve_mode: Literal["schur", "linear"] = "schur",
    **kwargs: Any,
) -> tuple[Tensor, Tensor]: ...


@overload
def get_eeq(
    numbers: Tensor,
    positions: Tensor,
    chrg: Tensor,
    *,
    counting_function: CountingFunction = erf_count,
    rcov: Tensor | None = None,
    cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
    cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
    kcn: Tensor | float | int = defaults.EEQ_KCN,
    return_energy: bool,
    solve_mode: Literal["schur", "linear"] = "schur",
    **kwargs: Any,
) -> Tensor | tuple[Tensor, Tensor]: ...


def get_eeq(
    numbers: Tensor,
    positions: Tensor,
    chrg: Tensor,
    *,
    counting_function: CountingFunction = erf_count,
    rcov: Tensor | None = None,
    cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
    cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
    kcn: Tensor | float | int = defaults.EEQ_KCN,
    return_energy: bool = False,
    solve_mode: Literal["schur", "linear"] = "schur",
    **kwargs: Any,
) -> Tensor | tuple[Tensor, Tensor]:
    """
    Calculate atomic EEQ charges and energies.

    The coordination numbers are computed from the geometry, and the
    standard (2019) EEQ model is then solved on the device/dtype of
    ``positions``.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system of shape ``(..., nat)``.
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    chrg : Tensor
        Total charge of system.
    counting_function : CountingFunction
        Calculate weight for pairs.
        Defaults to :func:`tad_mctc.ncoord.erf_count`.
    rcov : Tensor | None, optional
        Covalent radii for each species. Defaults to ``None``.
    cutoff : Tensor | float | int | None, optional
        Real-space cutoff for the coordination number.
        Defaults to :data:`tad_multicharge.param.defaults.EEQ_CN_CUTOFF`.
    cn_max : Tensor | float | int | None, optional
        Maximum coordination number.
        Defaults to :data:`tad_multicharge.param.defaults.EEQ_CN_MAX`.
    kcn : Tensor | float | int, optional
        Steepness of the counting function.
        Defaults to :data:`tad_multicharge.param.defaults.EEQ_KCN`.
    return_energy : bool, optional
        Return the EEQ energy as well. Defaults to ``False``.
    solve_mode : Literal["schur", "linear"], optional
        Choose the solution method for the linear system.

        - ``"schur"``: Use Schur-complement based method with Cholesky
          factorization (default, recommended).
        - ``"linear"``: Solve the full bordered linear system directly.
          Less stable and slower for large systems.

        Defaults to ``"schur"``.
    **kwargs : Any
        Additional keyword arguments forwarded to the coordination number
        calculation.

    Returns
    -------
    Tensor | (Tensor, Tensor)
        Partial charges or, if ``return_energy=True``, a tuple of partial
        charges and atom-resolved electrostatic energies.
    """
    eeq = EEQModel.param2019(device=positions.device, dtype=positions.dtype)
    cn = coordination_number(
        numbers,
        positions,
        counting_function=counting_function,
        rcov=rcov,
        cutoff=cutoff,
        cn_max=cn_max,
        kcn=kcn,
        **kwargs,
    )
    return eeq.solve(
        numbers,
        positions,
        chrg,
        cn,
        return_energy=return_energy,
        solve_mode=solve_mode,
    )
def get_charges(
    numbers: Tensor,
    positions: Tensor,
    chrg: Tensor,
    cutoff: Tensor | None = None,
) -> Tensor:
    """
    Compute atomic partial charges with the EEQ model.

    Thin convenience wrapper around :func:`get_eeq` that only requests the
    charges.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system of shape ``(..., nat)``.
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    chrg : Tensor
        Total charge of system.
    cutoff : Tensor | None, optional
        Real-space cutoff. Defaults to ``None``. The value, including
        ``None``, is passed on to :func:`get_eeq` as-is, so it replaces
        that function's own default cutoff.

    Returns
    -------
    Tensor
        Atomic charges.
    """
    charges = get_eeq(
        numbers,
        positions,
        chrg,
        cutoff=cutoff,
        return_energy=False,
    )
    return charges
def get_energy(
    numbers: Tensor,
    positions: Tensor,
    chrg: Tensor,
    cutoff: Tensor | None = None,
) -> Tensor:
    """
    Compute atom-resolved electrostatic energies with the EEQ model.

    Thin convenience wrapper around :func:`get_eeq` that requests energies
    and discards the charges.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system of shape ``(..., nat)``.
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    chrg : Tensor
        Total charge of system.
    cutoff : Tensor | None, optional
        Real-space cutoff. Defaults to ``None``. The value, including
        ``None``, is passed on to :func:`get_eeq` as-is, so it replaces
        that function's own default cutoff.

    Returns
    -------
    Tensor
        Atomic energies.
    """
    _charges, energies = get_eeq(
        numbers,
        positions,
        chrg,
        cutoff=cutoff,
        return_energy=True,
    )
    return energies