JAX-ReaxFF: A Gradient Based Framework for Extremely Fast Optimization of Reactive Force Fields


Molecular dynamics (MD) simulations facilitate the study of chemical processes of interest. Classical models that govern the interactions between atoms lack reactivity, while quantum-mechanics-based methods increase the computational cost drastically. ReaxFF fills the gap between these two ends of the spectrum by allowing bond breaking and dynamic charges. To achieve realistic simulations using ReaxFF, the model parameters need to be carefully tuned against training data generated with more accurate but more expensive methods. Current optimization approaches rely on black-box methods such as genetic algorithms (GAs), Monte Carlo methods, and the covariance matrix adaptation evolution strategy (CMA-ES). Due to the stochastic behavior of these methods, training requires hundreds of thousands of error evaluations for complex training tasks, and each error evaluation usually involves energy minimization of many molecules in the training data. In this work, we propose a novel approach that takes advantage of modern tools developed for machine learning to improve the training efficiency of force field development for ReaxFF. By calculating the gradients of the loss function using the JAX library developed by Google, we are able to use well-studied local optimization methods such as the limited-memory Broyden–Fletcher–Goldfarb–Shanno (L-BFGS) and Sequential Least Squares Programming (SLSQP) methods. To further decrease the training time, we skip the energy minimization of the molecules during the local optimization, since the parameters of the model are not likely to change drastically during a local search. JAX allows us to easily parallelize the error evaluation by compiling the code for CPUs, GPUs, or TPUs. With the help of modern accelerators and the gradient information, we are able to decrease the training time from days to minutes.