The experimental std::autodiff module, available on nightly behind #![feature(autodiff)], enables differentiable programming in Rust:

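```rust
#![feature(autodiff)]
use std::autodiff::*;

// f(x) = x * x, so f'(x) = 2.0 * x.
// Because the return value is Active, bar takes an extra seed argument
// (the adjoint of the output) and returns (f(x), seed * f'(x)).
// With seed = 1.0, bar therefore returns (x * x, 2.0 * x).
#[autodiff_reverse(bar, Active, Active)]
fn foo(x: f32) -> f32 { x * x }

fn main() {
    assert_eq!(bar(3.0, 1.0), (9.0, 6.0));
    assert_eq!(bar(4.0, 1.0), (16.0, 8.0));
}
```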
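To make the calling convention of the generated function concrete, the following stable-Rust sketch hand-writes a function with the same shape as bar. It assumes, as the assertions above suggest, that the second argument is the seed (the adjoint of the output) and that the returned derivative is scaled by it. The name bar_by_hand is hypothetical, and this is not the code the compiler or Enzyme actually generates.

```rust
/// Primal function, identical to `foo` above.
fn foo(x: f32) -> f32 {
    x * x
}

/// Hypothetical hand-written stand-in for the generated `bar`:
/// returns the primal value and the seed-scaled derivative.
fn bar_by_hand(x: f32, seed: f32) -> (f32, f32) {
    // Forward pass: compute the original result.
    let y = foo(x);
    // Reverse pass: dy/dx = 2x, multiplied by the incoming adjoint `seed`.
    let dx = seed * 2.0 * x;
    (y, dx)
}

fn main() {
    assert_eq!(bar_by_hand(3.0, 1.0), (9.0, 6.0));
    assert_eq!(bar_by_hand(4.0, 1.0), (16.0, 8.0));
    // A seed of 0.5 halves the returned derivative.
    assert_eq!(bar_by_hand(3.0, 0.5), (9.0, 3.0));
}
```

Seeding with 1.0 yields the plain derivative; other seeds are useful when foo is one stage of a larger computation and the adjoint of its output is already known.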

Detailed documentation is available in the API docs for the std::autodiff module.

Differentiable programming is used in fields such as numerical computing, solid mechanics, computational chemistry, fluid dynamics, and climate simulation, and in applications such as neural network training via backpropagation, ODE solvers, differentiable rendering, and quantum computing.

std::autodiff is currently based on Enzyme, an LLVM-based tool for automatic differentiation. There are three main reasons for relying on compiler-based autodiff:

  • Usability: Existing autodiff crates do not support ordinary Rust programs. They either enforce a custom DSL, require library-provided types (instead of, e.g., slices or arrays), or are limited to scalar functions. Compiler-based autodiff lets users write ordinary Rust code, including arrays, slices, user-defined structs and enums, control flow, and more.
  • Performance: Most existing Rust autodiff approaches have a constant overhead per operation. This overhead is easily amortized in ML applications, which perform few, expensive operations on large tensors, but it is often unacceptable in HPC and scientific computing. By working on (optimized) LLVM IR, compiler-based autodiff can achieve significantly better performance in those cases.
  • Features: By operating at such a low level and sharing the implementation with other LLVM-based languages, we can leverage the large body of work already done in the Enzyme project. For example, we can support Rust code that calls MPI routines, or GPU code, including libraries like cuBLAS.
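For contrast with the compiler-based approach, here is a minimal sketch of the operator-overloading, tape-based style of reverse-mode autodiff that library crates commonly use. It is not modeled on any particular crate; Tape, Var, Node, and their methods are hypothetical. It illustrates two points from the list above: the user must write the function against a library-provided Var type instead of plain f64, and every arithmetic operation records a node on a runtime tape, which is where a constant per-operation overhead comes from.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

/// One recorded operation: two parent indices and the local partial derivatives.
#[derive(Clone, Copy)]
struct Node {
    parents: [usize; 2],
    partials: [f64; 2],
}

/// Hypothetical tape that records every operation at runtime.
struct Tape {
    nodes: RefCell<Vec<Node>>,
}

/// Hypothetical library-provided scalar type tied to a tape.
#[derive(Clone, Copy)]
struct Var<'t> {
    tape: &'t Tape,
    index: usize,
    value: f64,
}

impl Tape {
    fn new() -> Self {
        Tape { nodes: RefCell::new(Vec::new()) }
    }

    /// Append a node and return its index (the per-operation cost).
    fn push(&self, node: Node) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(node);
        nodes.len() - 1
    }

    /// Create an input variable; zero partials mean nothing flows to its "parents".
    fn var(&self, value: f64) -> Var<'_> {
        let index = self.push(Node { parents: [0, 0], partials: [0.0, 0.0] });
        Var { tape: self, index, value }
    }

    fn len(&self) -> usize {
        self.nodes.borrow().len()
    }

    /// Reverse sweep: nodes were pushed parents-first, so walk them backwards.
    fn gradient(&self, output: Var<'_>) -> Vec<f64> {
        let nodes = self.nodes.borrow();
        let mut adjoints = vec![0.0; nodes.len()];
        adjoints[output.index] = 1.0;
        for i in (0..nodes.len()).rev() {
            let node = nodes[i];
            let adj = adjoints[i];
            for k in 0..2 {
                adjoints[node.parents[k]] += node.partials[k] * adj;
            }
        }
        adjoints
    }
}

impl<'t> Add for Var<'t> {
    type Output = Var<'t>;
    fn add(self, rhs: Var<'t>) -> Var<'t> {
        // d(a + b)/da = 1, d(a + b)/db = 1
        let index = self.tape.push(Node { parents: [self.index, rhs.index], partials: [1.0, 1.0] });
        Var { tape: self.tape, index, value: self.value + rhs.value }
    }
}

impl<'t> Mul for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, rhs: Var<'t>) -> Var<'t> {
        // d(a * b)/da = b, d(a * b)/db = a
        let index = self.tape.push(Node { parents: [self.index, rhs.index], partials: [rhs.value, self.value] });
        Var { tape: self.tape, index, value: self.value * rhs.value }
    }
}

/// f(x, y) = x * y + x, written against `Var` rather than plain `f64`.
fn f<'t>(x: Var<'t>, y: Var<'t>) -> Var<'t> {
    x * y + x
}

fn main() {
    let tape = Tape::new();
    let x = tape.var(3.0);
    let y = tape.var(4.0);
    let z = f(x, y);
    let grads = tape.gradient(z);
    assert_eq!(z.value, 15.0); // 3 * 4 + 3
    assert_eq!(grads[x.index], 5.0); // df/dx = y + 1
    assert_eq!(grads[y.index], 3.0); // df/dy = x
    // Two inputs plus one recorded node per arithmetic operation.
    assert_eq!(tape.len(), 4);
}
```

The tape push and the dynamic reverse sweep happen for every scalar operation, which is negligible next to a large tensor kernel but significant for scalar-heavy scientific code. A compiler-based tool instead differentiates the optimized IR of an ordinary f64 function, so no such runtime bookkeeping type is needed.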