TypeTrees for Autodiff
What are TypeTrees?
TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are structured in memory so that it can compute derivatives efficiently.
Structure
struct TypeTree(Vec<Type>);

struct Type {
    offset: isize,   // byte offset (-1 = everywhere)
    size: usize,     // size in bytes
    kind: Kind,      // Float, Integer, Pointer, etc.
    child: TypeTree, // nested structure
}
Example: fn compute(x: &f32, data: &[f32]) -> f32
Input 0: x: &f32
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: 0, size: 4, kind: Float, // Single value: use offset 0
        child: TypeTree::new()
    }])
}])
Input 1: data: &[f32]
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float, // -1 = all elements
        child: TypeTree::new()
    }])
}])
Output: f32
TypeTree(vec![Type {
    offset: 0, size: 4, kind: Float, // Single scalar: use offset 0
    child: TypeTree::new()
}])
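The three trees above share one shape, so they can be built from two small helpers. The following is a minimal sketch, not the compiler's actual code: the `Kind` variants, the derives, and the helpers `f32_at`, `pointer_to`, and `compute_trees` are all hypothetical names introduced for illustration, and the 8-byte pointer size assumes a 64-bit target.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Kind {
    Float,
    Pointer,
}

#[derive(Debug, Clone, PartialEq)]
struct Type {
    offset: isize,
    size: usize,
    kind: Kind,
    child: TypeTree,
}

#[derive(Debug, Clone, PartialEq)]
struct TypeTree(Vec<Type>);

impl TypeTree {
    fn new() -> Self {
        TypeTree(Vec::new())
    }
}

/// Hypothetical helper: a single f32 entry at `offset`
/// (0 = one value, -1 = every element).
fn f32_at(offset: isize) -> TypeTree {
    TypeTree(vec![Type { offset, size: 4, kind: Kind::Float, child: TypeTree::new() }])
}

/// Hypothetical helper: an 8-byte pointer whose target is described by `pointee`.
fn pointer_to(pointee: TypeTree) -> TypeTree {
    TypeTree(vec![Type { offset: -1, size: 8, kind: Kind::Pointer, child: pointee }])
}

/// The three trees for `fn compute(x: &f32, data: &[f32]) -> f32`.
fn compute_trees() -> (TypeTree, TypeTree, TypeTree) {
    let x = pointer_to(f32_at(0)); // &f32: one value at offset 0
    let data = pointer_to(f32_at(-1)); // &[f32]: same entry for all elements
    let output = f32_at(0); // f32 returned by value
    (x, data, output)
}
```

The only difference between the trees for `x` and `data` is the offset of the inner entry: 0 for the single value, -1 for the slice elements.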
Why Needed?
- Enzyme can’t deduce complex type layouts from LLVM IR
- Prevents slow memory pattern analysis
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes are differentiable vs metadata
What Enzyme Does With This Information:
Without TypeTrees:
; Enzyme sees generic LLVM IR:
define float @distance(ptr %p1, ptr %p2) {
; Has to guess what these pointers point to
; Slow analysis of all memory operations
; May miss optimization opportunities
}
With TypeTrees:
define "enzyme_type"="{[-1]:Float@float}" float @distance(
ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p1,
ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p2
) {
; Enzyme knows exact type layout
; Can generate efficient derivative code directly
}
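To make the link between a TypeTree and the attribute strings above concrete, here is a rough sketch of a renderer. It is an illustration only: the functions `collect` and `enzyme_type` are hypothetical, the real encoding is done inside the compiler and may differ, and the sketch covers just the pointer and 4-byte float entries used in this example.

```rust
#[derive(Clone, Copy, PartialEq)]
enum Kind {
    Float,
    Pointer,
}

struct Type {
    offset: isize,
    size: usize,
    kind: Kind,
    child: TypeTree,
}

struct TypeTree(Vec<Type>);

/// Collect "[path]:Kind" entries, where `path` is the chain of offsets
/// leading from the root of the tree to the entry.
fn collect(tree: &TypeTree, prefix: &[isize], out: &mut Vec<String>) {
    for ty in &tree.0 {
        let mut path = prefix.to_vec();
        path.push(ty.offset);
        let idx = path.iter().map(|o| o.to_string()).collect::<Vec<_>>().join(",");
        // Assumption: only 4-byte floats get the "@float" suffix in this sketch.
        let label = match (ty.kind, ty.size) {
            (Kind::Float, 4) => "Float@float",
            (Kind::Float, _) => "Float",
            (Kind::Pointer, _) => "Pointer",
        };
        out.push(format!("[{idx}]:{label}"));
        collect(&ty.child, &path, out);
    }
}

/// Render a tree in the `{[..]:Kind, ...}` form shown in the attributes above.
fn enzyme_type(tree: &TypeTree) -> String {
    let mut parts = Vec::new();
    collect(tree, &[], &mut parts);
    format!("{{{}}}", parts.join(", "))
}
```

Each nesting level of the tree adds one index to the bracketed path, which is why the pointee of a pointer at `[-1]` appears as `[-1,0]` or `[-1,-1]`.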
TypeTrees - Offset and -1 Explained
Type Structure
struct Type {
    offset: isize,   // WHERE this type starts
    size: usize,     // HOW BIG this type is
    kind: Kind,      // WHAT KIND of data (Float, Integer, Pointer)
    child: TypeTree, // WHAT'S INSIDE (for pointers/containers)
}
Offset Values
Regular Offset (0, 4, 8, etc.)
Specific byte position within a structure
#[repr(C)] // fixes field order; the default Rust layout does not guarantee it
struct Point {
    x: f32,  // offset 0, size 4
    y: f32,  // offset 4, size 4
    id: i32, // offset 8, size 4
}
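The offsets in the comments above can be checked directly with `std::mem::offset_of!` (stable since Rust 1.77). This is a small sketch; the `#[repr(C)]` attribute and the helper `point_layout` are assumptions added here so that the field order is guaranteed.

```rust
use std::mem::{offset_of, size_of};

// `#[repr(C)]` fixes the field order; the default Rust layout does not promise one.
#[repr(C)]
struct Point {
    x: f32,
    y: f32,
    id: i32,
}

/// (offset, size) of each field of `Point`, in declaration order.
fn point_layout() -> [(usize, usize); 3] {
    [
        (offset_of!(Point, x), size_of::<f32>()),
        (offset_of!(Point, y), size_of::<f32>()),
        (offset_of!(Point, id), size_of::<i32>()),
    ]
}
```

These (offset, size) pairs are the numbers that end up in the `offset` and `size` fields of the corresponding `Type` entries.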
TypeTree for &Point (internal representation):
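TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer, // the reference itself
    child: TypeTree(vec![
        Type { offset: 0, size: 4, kind: Float, child: TypeTree::new() },   // x at byte 0
        Type { offset: 4, size: 4, kind: Float, child: TypeTree::new() },   // y at byte 4
        Type { offset: 8, size: 4, kind: Integer, child: TypeTree::new() }, // id at byte 8
    ])
}])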
Generates LLVM:
"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer}"
Offset -1 (Special: “Everywhere”)
Means “this pattern repeats for ALL elements”
Example 1: Direct Array [f32; 100] (no pointer indirection)
TypeTree(vec![Type {
    offset: -1,  // ALL positions
    size: 4,     // each f32 is 4 bytes
    kind: Float, // every element is float
    child: TypeTree::new()
}])
Generates LLVM: "enzyme_type"="{[-1]:Float@float}"
Example 1b: Array Reference &[f32; 100] (with pointer indirection)
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1,  // ALL array elements
        size: 4,     // each f32 is 4 bytes
        kind: Float, // every element is float
        child: TypeTree::new()
    }])
}])
Generates LLVM: "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"
This single entry replaces 100 separate Types at offsets 0, 4, 8, 12, ..., 396.
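The saving can be illustrated with a small lookup sketch. Everything here is hypothetical and simplified (flat entries without children; the helpers `covers`, `kind_at`, `explicit_f32_array`, and `wildcard_f32_array` are invented for this example): it compares the long form, with one entry per element, against the single -1 entry.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Kind {
    Float,
}

struct Type {
    offset: isize, // assumed to be -1 or non-negative
    size: usize,
    kind: Kind,
}

/// Does this entry describe the byte at position `byte`?
fn covers(t: &Type, byte: usize) -> bool {
    if t.offset == -1 {
        return true; // -1 applies everywhere
    }
    let start = t.offset as usize;
    (start..start + t.size).contains(&byte)
}

/// Kind recorded for `byte`, if any entry covers it.
fn kind_at(entries: &[Type], byte: usize) -> Option<Kind> {
    entries.iter().find(|t| covers(t, byte)).map(|t| t.kind)
}

/// The long form: one entry per element at offsets 0, 4, 8, ...
fn explicit_f32_array(len: usize) -> Vec<Type> {
    (0..len)
        .map(|i| Type { offset: (i * 4) as isize, size: 4, kind: Kind::Float })
        .collect()
}

/// The short form: a single entry at offset -1.
fn wildcard_f32_array() -> Vec<Type> {
    vec![Type { offset: -1, size: 4, kind: Kind::Float }]
}
```

For every byte inside the 400-byte array both forms give the same answer. In this sketch the -1 entry carries no length, so it also matches bytes past the end, which the explicit form does not.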
Example 2: Slice &[i32]
// Pointer to slice data
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, // ALL slice elements
        size: 4,    // each i32 is 4 bytes
        kind: Integer,
        child: TypeTree::new()
    }])
}])
Generates LLVM: "enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"
Example 3: Mixed Structure
#[repr(C)] // fixes field order; the default Rust layout does not guarantee it
struct Container {
    header: i64,       // offset 0
    data: [f32; 1000], // offset 8, but elements use -1
}
TypeTree(vec![
    Type { offset: 0, size: 8, kind: Integer, child: TypeTree::new() }, // header
    Type { offset: 8, size: 4000, kind: Pointer,
        child: TypeTree(vec![Type {
            offset: -1, size: 4, kind: Float, // ALL array elements
            child: TypeTree::new()
        }])
    }
])
Key Distinction: Single Values vs Arrays
Single Values use offset 0 for precision:
- &f32 has exactly one f32 value at offset 0
- More precise than using -1 (“everywhere”)
- Generates: {[-1]:Pointer, [-1,0]:Float@float}
Arrays use offset -1 for efficiency:
- &[f32; 100] has the same pattern repeated 100 times
- Using -1 avoids listing 100 separate offsets
- Generates: {[-1]:Pointer, [-1,-1]:Float@float}