This blog post originally appeared on Medium

Popular general-purpose auto-differentiation frameworks like PyTorch or TensorFlow are very capable, and, for the most part, there is little need for writing something more specialized.

Nevertheless, I have recently started writing my own autodiff package. This blog post describes what I’ve learned along the way. Think of this as a poor-man’s version of a Julia Evans blog post.

Note that there are many blog posts describing the mechanics of automatic differentiation much better than I could, so I skip the explanations here. Additionally, there are several other interesting posts and articles on building type-safe neural network constructs, so while my library follows very similar patterns (statically-typed graphs and dependent types), I don’t dwell on the type system angle too much.

Finally, in case you’d like to jump straight to the code, the end result is here, together with an obligatory neural-network-based FizzBuzz solution.


There are a couple of reasons why I wanted to have my own autodiff/backprop framework, rather than use PyTorch or TensorFlow.

Motivated by the desire for a lightweight solution that works well for recommender (and possibly NLP) models, I wrote down a list of design constraints.

There is also a short list of things I don’t want, or don’t care enough about to add for now:

Given the list of requirements (and non-requirements), some design decisions follow naturally.

When writing libraries, I often think of the API I want to be able to expose and work back from there. In this case, I want to write something like the following:

let slope = Parameter::new(1.0);
let intercept = Parameter::new(0.0);
let x = Input::new(3.0);
let y = Input::new(2.0 * 3.0 + 1.0);
let loss = (y - (slope * x + intercept)).square();

and have it just work.
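To make the target API concrete, here is a minimal scalar sketch of how operator overloading can make it work. It is not the library’s implementation: a single hypothetical `Var` type stands in for both `Parameter` and `Input`, values are plain `f64`s rather than matrices, and the node kinds live in an enum (one of the designs weighed below). Values are computed eagerly as the graph is built, and `backward` applies the chain rule recursively.

```rust
use std::cell::Cell;
use std::ops::{Add, Mul, Sub};
use std::rc::Rc;

// Hypothetical scalar sketch, not the library's actual API.
// How a node was produced; leaves (inputs and parameters) have no parents.
enum Op {
    Leaf,
    Add(Var, Var),
    Sub(Var, Var),
    Mul(Var, Var),
    Square(Var),
}

struct Node {
    op: Op,
    value: f64,      // computed eagerly when the node is built
    grad: Cell<f64>, // d(loss)/d(node), filled in by backward()
}

// A cheap, clonable handle; clones share the same node (and gradient).
#[derive(Clone)]
struct Var(Rc<Node>);

impl Var {
    fn new(value: f64) -> Var { Var::from_op(Op::Leaf, value) }

    fn from_op(op: Op, value: f64) -> Var {
        Var(Rc::new(Node { op, value, grad: Cell::new(0.0) }))
    }

    fn value(&self) -> f64 { self.0.value }

    fn grad(&self) -> f64 { self.0.grad.get() }

    fn square(&self) -> Var {
        Var::from_op(Op::Square(self.clone()), self.value() * self.value())
    }

    // Reverse mode: add `g` to this node's gradient, then use the chain
    // rule to push the appropriate share down to each parent.
    fn backward_from(&self, g: f64) {
        self.0.grad.set(self.grad() + g);
        match &self.0.op {
            Op::Leaf => {}
            Op::Add(a, b) => {
                a.backward_from(g);
                b.backward_from(g);
            }
            Op::Sub(a, b) => {
                a.backward_from(g);
                b.backward_from(-g);
            }
            Op::Mul(a, b) => {
                a.backward_from(g * b.value());
                b.backward_from(g * a.value());
            }
            Op::Square(a) => a.backward_from(2.0 * a.value() * g),
        }
    }

    fn backward(&self) { self.backward_from(1.0); }
}

impl Add for Var {
    type Output = Var;
    fn add(self, rhs: Var) -> Var {
        let v = self.value() + rhs.value();
        Var::from_op(Op::Add(self, rhs), v)
    }
}

impl Sub for Var {
    type Output = Var;
    fn sub(self, rhs: Var) -> Var {
        let v = self.value() - rhs.value();
        Var::from_op(Op::Sub(self, rhs), v)
    }
}

impl Mul for Var {
    type Output = Var;
    fn mul(self, rhs: Var) -> Var {
        let v = self.value() * rhs.value();
        Var::from_op(Op::Mul(self, rhs), v)
    }
}
```

Because the operators take their operands by value, reusing a variable means cloning its handle first, which only bumps a reference count. With the numbers above, `let loss = (y - (slope.clone() * x + intercept.clone())).square(); loss.backward();` gives a loss of 16, `slope.grad()` of -24 and `intercept.grad()` of -8.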

Preliminaries done, we can move on to the fun part: figuring out how to implement the graph.

Representing the graph

What sort of data structure do we choose to represent the graph? I looked at two alternatives.

     Vector-based                              Graph-based

   +---------------+                       +-----------------+
   |               |                       |                 |
+-->     A * B     <--+                +--->      A * B      <--+
|  |               |  |                |   |                 |  |
|  +---------------+  |                |   +-----------------+  |
|  |               |  |                |                        |
|  |       B       +--+                |                        |
|  |               |                   |                        |
|  +---------------+            +------+---------+    +---------+-------+
|  |               |            |                |    |                 |
+--+       A       |            |       A        |    |        B        |
   |               |            |                |    |                 |
   +---------------+            +----------------+    +-----------------+

There are a couple of advantages to the vector-based approach.

But there are also disadvantages.

It’s not clear what sort of object we are storing in the node vector. All of the nodes are different types (of different sizes), and vectors are homogeneously typed. Rust offers two solutions to this problem, but neither is fully satisfactory.

The first is enums (sum types; ADTs; tagged unions). We define a Node type to be the union of all possible node types, and store that in the node vector. This way, everything has the same type. We still need to dispatch the node’s methods from the enclosing Node type to the contained inner node. This can be done via pattern matching (a switch statement on the tags of the union type); with Rust’s support for pattern matching and macros, writing the necessary code is a breeze.
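A minimal sketch of the enum approach, assuming scalar values and a node vector in which parents are referred to by index (the node types here are hypothetical). Every call to `forward` goes through the `match`, which is exactly the dispatch whose cost is discussed next.

```rust
// Hypothetical concrete node types. Parents are indices into the node vector.
struct InputNode {
    value: f64,
}

struct AddNode {
    lhs: usize,
    rhs: usize,
}

struct MulNode {
    lhs: usize,
    rhs: usize,
}

// One sum type so that differently-typed nodes fit in a single Vec.
enum Node {
    Input(InputNode),
    Add(AddNode),
    Mul(MulNode),
}

impl Node {
    // Every call is dispatched through this match on the enum's tag.
    fn forward(&self, values: &[f64]) -> f64 {
        match self {
            Node::Input(n) => n.value,
            Node::Add(n) => values[n.lhs] + values[n.rhs],
            Node::Mul(n) => values[n.lhs] * values[n.rhs],
        }
    }
}

// Nodes are pushed in creation order, so parents always precede children
// and a single left-to-right pass evaluates the whole graph.
fn evaluate(graph: &[Node]) -> Vec<f64> {
    let mut values: Vec<f64> = Vec::with_capacity(graph.len());
    for node in graph {
        let v = node.forward(&values);
        values.push(v);
    }
    values
}
```

The vector doubles as a topological ordering, but supporting a new node kind means editing the `Node` enum itself.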

However, this imposes a runtime cost. Every time we use a node, we need to go through the switch statement to resolve the inner type. In principle, optimizing compilers can compile such code to jump tables. In practice, the assembly generated for the dispatch code in my experiments was simply a linear scan over all the possibilities, imposing a dispatch cost that is linear in the number of concrete node types the framework supports. Worse still, the compiler is reluctant to inline either the switch itself or the called functions. The former increases branch mispredictions; the latter adds function call overhead. (This problem is exacerbated by the recent branch-prediction attacks: it’s likely that compiler mitigations will make indirect jumps like these substantially more expensive.)

The final disadvantage of using sum types for the node vector is that it results in a closed system (akin to Scala’s sealed traits): downstream users of the library cannot add new node types.

The alternative is to use Rust’s runtime polymorphism mechanism, trait objects. Trait objects are a way of abstracting over the concrete type of an object: instead of storing structs inline, we hide them behind a pointer to their data and a pointer to a table of their methods (the vtable). When calling a method, we look the function up in the vtable and then call it. Using trait objects, we put these fat pointers into the node vector instead of the nodes themselves.
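The same hypothetical scalar graph with trait objects instead of an enum might look like this. `ScaleNode` is an illustrative stand-in for a node type that downstream code could add without touching the library.

```rust
// Hypothetical node interface; a real one would also have a backward method.
trait Node {
    fn forward(&self, values: &[f64]) -> f64;
}

struct InputNode {
    value: f64,
}

struct AddNode {
    lhs: usize,
    rhs: usize,
}

impl Node for InputNode {
    fn forward(&self, _values: &[f64]) -> f64 {
        self.value
    }
}

impl Node for AddNode {
    fn forward(&self, values: &[f64]) -> f64 {
        values[self.lhs] + values[self.rhs]
    }
}

// Downstream code can add node types just by implementing the trait.
struct ScaleNode {
    input: usize,
    factor: f64,
}

impl Node for ScaleNode {
    fn forward(&self, values: &[f64]) -> f64 {
        self.factor * values[self.input]
    }
}

// Each element is a fat pointer (data pointer + vtable pointer), so the
// concrete forward() is only known at runtime.
fn evaluate(graph: &[Box<dyn Node>]) -> Vec<f64> {
    let mut values: Vec<f64> = Vec::with_capacity(graph.len());
    for node in graph {
        let v = node.forward(&values);
        values.push(v);
    }
    values
}
```

Unlike the enum approach, this is an open system, but each `node.forward(...)` is a call through the vtable that the compiler cannot inline.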

This solution, however, introduces exactly the kind of indirection we set out to avoid in the first place. Additionally, it completely defeats the compiler’s efforts at inlining: the function to be called is not known until runtime.

What about the graph-based design? Here, each node is placed in its own location in memory and refers to its parents via pointers. Because each node can be re-used an arbitrary number of times, I use Rc<T>, Rust’s equivalent of C++’s shared_ptr.

One immediate disadvantage of this approach is that it blurs the ownership structure of the graph, making cloning and serialization/deserialization difficult: because nodes can be re-used, naive cloning/deserialization will result in multiple copies of the same nodes being created.
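The cloning problem is easy to reproduce. In this sketch (a hypothetical `Node` holding a value and `Rc` parents), `x` is used twice, so both parents of `sq` point at the same allocation. A naive recursive copy breaks that sharing; memoizing on the source node’s address, which is one possible fix rather than the library’s approach, preserves it.

```rust
use std::collections::HashMap;
use std::rc::Rc;

// Hypothetical minimal graph node: a value plus reference-counted parents.
struct Node {
    value: f64,
    parents: Vec<Rc<Node>>,
}

// Naive deep copy: every path to a node produces a fresh copy of it,
// so a node reused k times ends up copied k times.
fn deep_clone(node: &Rc<Node>) -> Rc<Node> {
    Rc::new(Node {
        value: node.value,
        parents: node.parents.iter().map(deep_clone).collect(),
    })
}

// Sharing-preserving copy: remember which source nodes (by address)
// have already been copied and hand out the existing copy.
fn clone_shared(node: &Rc<Node>, seen: &mut HashMap<*const Node, Rc<Node>>) -> Rc<Node> {
    let key = Rc::as_ptr(node);
    if let Some(copy) = seen.get(&key) {
        return Rc::clone(copy);
    }
    let parents: Vec<Rc<Node>> = node.parents.iter().map(|p| clone_shared(p, seen)).collect();
    let copy = Rc::new(Node { value: node.value, parents });
    seen.insert(key, Rc::clone(&copy));
    copy
}
```

Deserialization has the same shape: unless node identities are recorded somewhere, the reader has no way of knowing that two references were meant to be one node.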

The second disadvantage is the lack of a readily-available topological ordering: both forward and backward passes have to be done recursively, and care has to be taken to avoid re-computing the values of shared subgraphs.
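One way to recover a topological ordering from an `Rc` graph is a depth-first search that records each node, identified by its address, the first time it is reached. The sketch below (hypothetical scalar node types) then runs the forward pass in that order and the backward pass in reverse, summing each node’s incoming gradients before passing them on.

```rust
use std::collections::{HashMap, HashSet};
use std::rc::Rc;

// Hypothetical scalar graph: leaves hold values, Add/Mul hold shared parents.
enum Node {
    Leaf(f64),
    Add(Rc<Node>, Rc<Node>),
    Mul(Rc<Node>, Rc<Node>),
}

// A node's identity is its address, so shared nodes compare equal.
type Key = *const Node;

// Depth-first post-order: every node comes after all of its parents, and a
// node reachable along several paths is recorded only once.
fn topo_sort(root: &Rc<Node>) -> Vec<Rc<Node>> {
    fn visit(node: &Rc<Node>, seen: &mut HashSet<Key>, order: &mut Vec<Rc<Node>>) {
        if !seen.insert(Rc::as_ptr(node)) {
            return; // already recorded via another path
        }
        if let Node::Add(a, b) | Node::Mul(a, b) = &**node {
            visit(a, seen, order);
            visit(b, seen, order);
        }
        order.push(Rc::clone(node));
    }
    let mut order = Vec::new();
    visit(root, &mut HashSet::new(), &mut order);
    order
}

// Forward pass in topological order, backward pass in reverse order.
// Each node's gradient is fully accumulated before it is pushed further down.
fn gradients(root: &Rc<Node>) -> (f64, HashMap<Key, f64>) {
    let order = topo_sort(root);
    let mut value: HashMap<Key, f64> = HashMap::new();
    for n in &order {
        let v = match &**n {
            Node::Leaf(x) => *x,
            Node::Add(a, b) => value[&Rc::as_ptr(a)] + value[&Rc::as_ptr(b)],
            Node::Mul(a, b) => value[&Rc::as_ptr(a)] * value[&Rc::as_ptr(b)],
        };
        value.insert(Rc::as_ptr(n), v);
    }
    let mut grad: HashMap<Key, f64> = HashMap::new();
    grad.insert(Rc::as_ptr(root), 1.0);
    for n in order.iter().rev() {
        let g = *grad.get(&Rc::as_ptr(n)).unwrap_or(&0.0);
        match &**n {
            Node::Leaf(_) => {}
            Node::Add(a, b) => {
                *grad.entry(Rc::as_ptr(a)).or_insert(0.0) += g;
                *grad.entry(Rc::as_ptr(b)).or_insert(0.0) += g;
            }
            Node::Mul(a, b) => {
                *grad.entry(Rc::as_ptr(a)).or_insert(0.0) += g * value[&Rc::as_ptr(b)];
                *grad.entry(Rc::as_ptr(b)).or_insert(0.0) += g * value[&Rc::as_ptr(a)];
            }
        }
    }
    (value[&Rc::as_ptr(root)], grad)
}
```

For f = x·x + x·y at x = 3, y = 4, `gradients` returns a value of 21 and gradients of 10 for `x` and 3 for `y`, with each of the five nodes visited once per pass.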

The advantage of using the graph representation is that the types of any node’s parents are known at compile time. Every node is (recursively) generic over the types of its parents: adding two InputNodes will produce an AddNode<InputNode, InputNode>. Adding that to another input node will produce an AddNode<AddNode<InputNode, InputNode>, InputNode> and so on. This gives me static method dispatch and the potential for inlining, in addition to a design that plays much more nicely with the type system.
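A stripped-down sketch of parent-generic nodes. For brevity the parents are owned by value rather than held in `Rc`, only addition is implemented, and the trait and type names are illustrative. Each `+` wraps the previous type in another `AddNode`, and because the concrete types are known, the `value()` calls are statically dispatched.

```rust
use std::ops::Add;

// Hypothetical node interface (forward value only, scalars only).
trait Node {
    fn value(&self) -> f64;
}

struct InputNode(f64);

// Generic over the concrete types of both parents, so lhs.value() and
// rhs.value() are resolved at compile time and can be inlined.
struct AddNode<L, R> {
    lhs: L,
    rhs: R,
}

impl Node for InputNode {
    fn value(&self) -> f64 {
        self.0
    }
}

impl<L: Node, R: Node> Node for AddNode<L, R> {
    fn value(&self) -> f64 {
        self.lhs.value() + self.rhs.value()
    }
}

// Overloading `+` builds the nested node type automatically.
impl<R: Node> Add<R> for InputNode {
    type Output = AddNode<InputNode, R>;
    fn add(self, rhs: R) -> Self::Output {
        AddNode { lhs: self, rhs }
    }
}

impl<L: Node, R: Node, T: Node> Add<T> for AddNode<L, R> {
    type Output = AddNode<AddNode<L, R>, T>;
    fn add(self, rhs: T) -> Self::Output {
        AddNode { lhs: self, rhs }
    }
}
```

Here `InputNode(1.0) + InputNode(2.0) + InputNode(3.0)` has type `AddNode<AddNode<InputNode, InputNode>, InputNode>`, mirroring the nesting described above.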


In some informal benchmarks, the graph-based approach was approximately 30% faster than the vector-based approach. The end result can run a full epoch of a BPR learning-to-rank factorization model on the MovieLens 100K dataset (code) in under 20 milliseconds on my puny dual-core laptop, and should scale linearly with more cores.

This takes advantage of a number of optimizations in addition to the underlying graph structure.

There are a number of ways to make the computation faster still.

  1. At the moment, the code doesn’t do any subgraph result caching in the forward pass: if a node is used twice in the forward pass, all of the computations it depends on will be done twice. This can easily be solved via a simple topological sort algorithm, marking nodes as evaluated once their value has been computed. (Addendum: this turns out to be incredibly important for recurrent neural networks, so it is now implemented.)
  2. Similarly, gradients are passed straight to parameter nodes in the backward pass. If a node is used more than once, this means that unnecessary work is done in passing its gradients down one at a time. Accumulating all the gradients and only recursing once will save on that work. (Addendum: as above.)
  3. There is some unnecessary copying of inputs; making better use of references when possible should yield some small performance gains.
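The first item, caching values within a forward pass, can be sketched with a per-node `Cell<Option<f64>>` that is filled in on first evaluation and cleared before the next pass. This is a hypothetical scalar illustration rather than the library’s code; the `evals` counter exists only to show how much work is saved.

```rust
use std::cell::Cell;
use std::rc::Rc;

// Hypothetical scalar node kinds.
enum Op {
    Leaf(f64),
    Add(Rc<Node>, Rc<Node>),
    Mul(Rc<Node>, Rc<Node>),
}

struct Node {
    op: Op,
    cached: Cell<Option<f64>>, // Some(v) once this node has been evaluated
}

impl Node {
    fn new(op: Op) -> Rc<Node> {
        Rc::new(Node { op, cached: Cell::new(None) })
    }

    // Recursive forward pass that marks nodes as evaluated. `evals` counts
    // how many node computations actually run.
    fn forward(&self, evals: &Cell<usize>) -> f64 {
        if let Some(v) = self.cached.get() {
            return v; // already evaluated in this pass
        }
        evals.set(evals.get() + 1);
        let v = match &self.op {
            Op::Leaf(x) => *x,
            Op::Add(a, b) => a.forward(evals) + b.forward(evals),
            Op::Mul(a, b) => a.forward(evals) * b.forward(evals),
        };
        self.cached.set(Some(v));
        v
    }

    // Clear the cache before the next pass (e.g. after a parameter update).
    // A cached node's parents are cached too, so recursion stops at the
    // first node that is already clear.
    fn reset(&self) {
        if self.cached.take().is_some() {
            if let Op::Add(a, b) | Op::Mul(a, b) = &self.op {
                a.reset();
                b.reset();
            }
        }
    }
}
```

For f = s·s with s = x + x, the uncached recursion would perform seven node evaluations; with the cache, `f`, `s`, and `x` are each evaluated once.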

What’s next

I have written (and continue to maintain) a number of open-source Python ML packages. The models are written by hand in Cython, and while they perform well, extending them is tricky. This is due partly to Cython’s limitations and partly to the effort required to derive update rules by hand.

I hope that this library (or some variation thereof) will make that task easier, and allow me to more easily implement complex models and release them as standalone Python packages. I’ll report back on how I fare.


It turns out that the graph representation is a little bit problematic when applied to recurrent neural networks: at every step of the recurrence, the complexity of the resulting types increases, leading to rather baroque types:

ate<layers::recurrent::LSTMCellState<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>,
layers::recurrent::LSTMCellHidden<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>,
layers::recurrent::LSTMCellHidden<layers::recurrent::LSTMCellState<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>,
layers::recurrent::LSTMCellHidden<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>,
layers::recurrent::LSTMCellHidden<layers::recurrent::LSTMCellState<layers::recurrent::LSTMCellState<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>,
layers::recurrent::LSTMCellHidden<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>,
layers::recurrent::LSTMCellHidden<layers::recurrent::LSTMCellState<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>,
layers::recurrent::LSTMCellHidden<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, nodes::ParameterNode>>>>

Needless to say, after a couple of recurrent steps the compiler gives up. This can be resolved by implementing a fused LSTM cell, rather than assembling it from simpler operations, or opting for selective type erasure via trait objects. So far, I’ve used the second solution: the output values of each LSTM cell have their concrete types erased by boxing them up in a trait object. Still, it illustrates the dangers of relying on complex type system constructs.
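A sketch of selective type erasure, using a toy scalar recurrence h ← w·h + x in place of an LSTM cell (all names hypothetical). Inside a step the nodes keep their concrete generic types; only the step’s output is boxed into an `Rc<dyn Node>`, so the state type stays the same no matter how many steps are unrolled.

```rust
use std::rc::Rc;

// Hypothetical node trait, scalar values only.
trait Node {
    fn value(&self) -> f64;
}

struct InputNode(f64);

struct AddNode<L, R> {
    lhs: L,
    rhs: R,
}

struct MulNode<L, R> {
    lhs: L,
    rhs: R,
}

impl Node for InputNode {
    fn value(&self) -> f64 {
        self.0
    }
}

impl<L: Node, R: Node> Node for AddNode<L, R> {
    fn value(&self) -> f64 {
        self.lhs.value() + self.rhs.value()
    }
}

impl<L: Node, R: Node> Node for MulNode<L, R> {
    fn value(&self) -> f64 {
        self.lhs.value() * self.rhs.value()
    }
}

// A type-erased node is itself a Node, so it can be a parent of generic nodes.
impl Node for Rc<dyn Node> {
    fn value(&self) -> f64 {
        (**self).value() // dynamic dispatch through the vtable
    }
}

// One step of the toy recurrence h <- w * h + x. The concrete type of `next`
// is AddNode<MulNode<InputNode, Rc<dyn Node>>, InputNode>; erasing it keeps
// the state type fixed instead of nesting one level deeper per step.
fn step(h: Rc<dyn Node>, w: f64, x: f64) -> Rc<dyn Node> {
    let next = AddNode {
        lhs: MulNode { lhs: InputNode(w), rhs: h },
        rhs: InputNode(x),
    };
    Rc::new(next)
}

fn unroll(inputs: &[f64], w: f64) -> Rc<dyn Node> {
    let mut h: Rc<dyn Node> = Rc::new(InputNode(0.0));
    for &x in inputs {
        h = step(h, w, x);
    }
    h
}
```

Each erased step costs one dynamic call when its value is requested, but the types no longer grow with the length of the sequence.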