Thinking About Automatic Differentiation in Fun New Ways

Recently Tadej Ciglarič came up with a new pattern for doing automatic differentiation in Stan Math's C++ library, and I thought it would be nice to have a little blog post talking about it. For the purposes of blogging I'm going to be simplifying a few things, but the main takeaway here is that it's a lot simpler to write new functions that utilize Stan's reverse mode autodiff (and this pattern can be a good bit faster in some cases!).

In Stan, when we need to calculate gradients we use the Stan Math library's automatic differentiation functionality. Automatic differentiation can seem kind of scary, but it works by combining a pretty simple set of patterns. What it boils down to is

  1. Calculating a function z = f(x, y)‘s output (forward pass)
  2. Storing the input and output into a memory arena
  3. Adding a callback to a callback stack that calculates the derivative of the function
  4. Later, when we want to take the gradient, calling stan::math::grad(z), which goes back through all the callbacks in the callback stack and accumulates the gradients through all the outputs and inputs (reverse pass)
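
For example, if the function were z = x * y, the callback registered in step 3 would, during the reverse pass, add y's value times z's adjoint to x's adjoint and x's value times z's adjoint to y's adjoint. That is just the chain rule applied to each input.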

For a much more detailed write-up on autodiff, check out the Stan math paper, which goes over this scheme in great detail. But for this post we just need to know that there's a forward pass that calculates the function, stores data, and registers a callback, and then later we can call a separate function to do the reverse pass and accumulate the gradient calculations for all of our variables.

Under the Hood

For what we are talking about in this blog post, we need to lay out a few types and classes we will be working with. As an example, say we have two inputs we want to add together and then take gradients of:

#include <stan/math.hpp>
#include <iostream>
using stan::math::var;
using stan::math::grad;
var b(10);
var c(12);
var a = b + c;
grad(a);
std::cout << "\na adjoint: " << a.adj();
std::cout << "\nb adjoint: " << b.adj();
std::cout << "\nc adjoint: " << c.adj();
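
Since d(b + c)/db = d(b + c)/dc = 1, both b and c should end up with adjoints of 1 here, and a's adjoint is 1 because grad() seeds the output's adjoint to 1 before running the callbacks.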

The var type is Stan's smart pointer class that we use to store a pointer to the information used in both the forward and reverse passes.

struct var {
  vari* vi_;
  double val() {
    return vi_->value;
  }
  double adj() {
    return vi_->adjoint;
  }
};

The vari is the actual object that holds the value from the forward pass, the adjoint (gradient calculation) used in the reverse pass, and a virtual method we override later to add new gradient calculations.

struct vari {
  double value; // forward pass value
  double adjoint{0}; // reverse pass value (adjoint)
  virtual void chain() = 0; // virtual method for gradient calculations
  vari(double x) : value(x) {
    // Add this pointer to the callback stack
    stack_allocator.push_back(this);
  }
  // Allocates itself onto the stack allocator
  static inline void* operator new(size_t nbytes) noexcept {
    return stack_allocator.alloc(nbytes);
  }
};
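
One thing the snippets above leave implicit is how a statement like var b(10) gets its vari. Below is a minimal sketch of the idea, assuming chain() has an empty default body rather than being pure virtual so that a plain vari can be constructed; the names and details are illustrative rather than the actual Stan source.

struct var {
  vari* vi_; // pointer to the arena-allocated vari
  // Constructing a var from a double allocates a vari on the arena
  // (via vari's operator new) and keeps a pointer to it
  var(double x) : vi_(new vari(x)) {}
  // Constructing a var from an existing vari just stores the pointer,
  // which is what operator+ below does with its freshly allocated vari
  var(vari* vi) : vi_(vi) {}
  double val() { return vi_->value; }
  double adj() { return vi_->adjoint; }
};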

You can see that the memory for that vari sits in a little allocator called stack_allocator. Every time we make a new vari, a pointer to that instance of a vari is pushed onto a vector. I'm not going to write out the full code for stack_allocator, but below is a pseudocode example.

struct stack_allocator {
  std::vector<vari*> vari_vec; // holds callbacks
  memory_arena mem_; // manages memory
  // allocate new memory
  void* alloc(size_t nbytes) {
    return mem_.alloc(nbytes);
  }
  // push a vari onto the stack
  void push_back(vari* x) {
    vari_vec.push_back(x);
  }
};
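
We keep referring to grad(), so here is roughly what the reverse pass looks like against these simplified types. This is a sketch of the idea rather than the actual Stan implementation: seed the output's adjoint and then run every callback, newest first.

inline void grad(var& output) {
  // Seed the output: d(output)/d(output) = 1
  output.vi_->adjoint = 1;
  // Walk the callbacks in reverse order of creation so each vari's
  // adjoint is fully accumulated before its inputs' callbacks run
  auto& stack = stack_allocator.vari_vec;
  for (auto it = stack.rbegin(); it != stack.rend(); ++it) {
    (*it)->chain();
  }
}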

Example of Previous Pattern

Let’s take all the above and write out the code for our example where we want to do a = b + c. Previously the code for operator+() looked like this:

class add_vv_vari : public op_vv_vari {
 public:
  // Setup the forward pass
  add_vv_vari(vari* avi, vari* bvi)
      : op_vv_vari(avi->value + bvi->value, avi, bvi) {}
  // Setup the callback 
  void chain() override {
    if (unlikely(is_nan(this->value))) {
      avi_->adjoint = NOT_A_NUMBER;
      bvi_->adjoint = NOT_A_NUMBER;
    } else {
      avi_->adjoint += this->adjoint;
      bvi_->adjoint += this->adjoint;
    }
  }
};

inline var operator+(var a, var b) {
  return var(new internal::add_vv_vari(a.vi_, b.vi_));
}

—————————————–
Note: I’m just going to be going through (complaining about) the above code in this section so if you want to skip to the new stuff you can go to the next section.
—————————————–
There’s a lot going on here that can look confusing. The class add_vv_vari inherits from an op_vv_vari that _then_ inherits from vari. This allows add_vv_vari to inherit the vari constructor to put itself on the callback stack and it’s operator new to get new memory from the stack allocator. It’s chain() method overrides the vari classes virtual chain() function so that when we call grad() we will actually be going through and calling the appropriate gradient calculation. Then in the actual operator+ function all we do is return a var that allocates a new instance of add_vv_vari with a and b‘s inner vari pointer as inputs.

To be honest, it's kind of confusing. There are a lot of interconnected pieces, we have to make a new class, we're dealing with a bit of memory management, and this is only _one_ operator+()! We still need to write an add_vd_vari() for when the left-hand side or right-hand side is a plain double rather than a var, since that case has a separate derivative calculation to make (sketched below)!
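
For illustration, the var/double version under the old pattern would look roughly like the following, assuming an op_vd_vari base analogous to op_vv_vari. This is a sketch rather than the real Stan class; only the var side accumulates an adjoint since the plain double carries no gradient.

class add_vd_vari : public op_vd_vari {
 public:
  // Forward pass: avi is the var's vari, b is a plain double
  add_vd_vari(vari* avi, double b)
      : op_vd_vari(avi->value + b, avi, b) {}
  // Reverse pass: only the var input's adjoint is accumulated
  void chain() override {
    if (unlikely(is_nan(this->value))) {
      avi_->adjoint = NOT_A_NUMBER;
    } else {
      avi_->adjoint += this->adjoint;
    }
  }
};

inline var operator+(var a, double b) {
  return var(new internal::add_vd_vari(a.vi_, b));
}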

The main things that make the above kind of weird are working directly with pointers and using inheritance to combine both the data and the callback in one class. What we want to do is break up the data and the callback for calculating the reverse pass. It would be nice to do some template magic to get rid of a lot of this, but sadly the stack_allocator class has to store all the callbacks as a std::vector<base_class*>. Since the compiler does not know at compile time what is going to go into that vector, we have to use dynamic polymorphism (virtual functions) to call the appropriate callback.

Creating A New Autodiff Scheme

One important fact is that since a var is just a smart pointer wrapper around a pointer to a vari, it is safe to take copies of it directly, since we'll only be copying a pointer to memory that's already been allocated elsewhere. So it actually would be safe to write the following:

class add_vv_vari : public op_vv_vari {
 public:
  // a and b are assumed to be stored as var members by op_vv_vari
  add_vv_vari(var a, var b)
      : op_vv_vari(a.val() + b.val(), a, b) {}
  void chain() override {
    if (unlikely(is_nan(this->value))) {
      a.adj() = NOT_A_NUMBER;
      b.adj() = NOT_A_NUMBER;
    } else {
      a.adj() += this->adjoint;
      b.adj() += this->adjoint;
    }
  }
};

Though now we have an add_vv_vari that inherits from vari while storing vars that themselves hold other vari pointers, and that's still kind of confusing! The issue is still the inheritance pattern, so let's see if we can rewrite vari to break up the callback from the data.

/**
 * Any class that inherits from this will
 * allocate itself onto our stack allocator
 */
struct stack_allocated_base {
 public:
  static inline void* operator new(size_t nbytes) noexcept {
    return stack_allocator.alloc(nbytes);
  }
};

struct vari : stack_allocated_base {
  double value;
  double adjoint;
  /**
   *  This no longer goes on our callback stack
   *   since it only holds data.
   */
  vari(double x) : value(x) {}
};

// This class will put itself on the callback stack
struct callback_base : stack_allocated_base {
  virtual void chain() = 0;
};

So we’ve separated out the data from the actual callback! We will also change the stack allocator to reflect this

struct stack_allocator {
  // holds callbacks
  std::vector<callback_base*> callback_vec;
  // manages all our memory
  memory_arena mem_;
  // allocate new memory
  void* alloc(size_t nbytes) {
    return mem_.alloc(nbytes);
  }
  // push a callback onto the stack
  void push_back(callback_base* x) {
    callback_vec.push_back(x);
  }
};
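
With this split in place, the grad() sketch from earlier changes only in which vector it walks (again, just a sketch of the idea):

inline void grad(var& output) {
  output.vi_->adjoint = 1;
  // The callbacks still run newest-first; they just live separately from the data
  auto& callbacks = stack_allocator.callback_vec;
  for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
    (*it)->chain();
  }
}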

Now we need some nice way to add callbacks to our callback stack when we want them. We can use a helper class and C++ lambdas to make adding these callbacks easier:

// F is a functor we call when the chain method 
//  is called for the reverse pass
template <typename F>
struct reverse_pass_callback_impl : public callback_base {
  F rev_functor_;
  explicit reverse_pass_callback_impl(F&& rev_functor)
      : rev_functor_(std::forward<F>(rev_functor)) {
    stack_allocator.push_back(this);
  }

  inline void chain() final { rev_functor_(); }
};

template <typename F>
inline void reverse_pass_callback(F&& functor) {
  new reverse_pass_callback_impl<F>(std::forward<F>(functor));
}

In the above we have a struct that inherits from callback_base with a template parameter that is a lambda. So when chain() is called it actually diverts to the lambda we've specified. The reverse_pass_callback() function then hides the operator new call so we don't have to deal with memory directly. With reverse_pass_callback() we can make the `operator+` function a lot simpler!

inline var operator+(var a, var b) {
  // make the return value
  var return_val(a.val() + b.val());
  // Add a callback with a lambda holding the vars
  reverse_pass_callback([a, b, return_val]() mutable {
    if (unlikely(is_nan(return_val.val()))) {
      a.adj() = NOT_A_NUMBER;
      b.adj() = NOT_A_NUMBER;
    } else {
      a.adj() += return_val.adj();
      b.adj() += return_val.adj();
    }
  });
  return return_val;
}

So that’s nice! In the above we can separate out the forward and reverse pass for autodiff where the return value is just calculated very normally. Then we pass a lambda to reverse_pass_callback() which will add the gradient calculation to our stack of callbacks so when we make the call to calculate the gradient it knows how to calculate the gradient for this function. Though there’s one issue here, we have to support the new and old style of autodiff so under the hood the callback stack is actually a std::vector<vari*> and reverse_pass_callback() inherits from vari as well. So when we use the reverse_pass_callback() pattern we can end up making two vari, one for the return and another for the callback. Making two of these is not a big deal for medium to large N size problems, but can actually matter for scalar operations. For this we have a make_callback_vari() that is more similar to the old style. Though it does remove the nicety of the return var where we instead need to work with a vari, it can give speedups over reverse_pass_callback() in scalar and small N problems.

inline var operator+(const var& a, const var& b) {
  return make_callback_vari(a.val() + b.val(),
    [a, b](const auto& vi) mutable {
      if (unlikely(std::isnan(vi.value))) {
        a.adj() = NOT_A_NUMBER;
        b.adj() = NOT_A_NUMBER;
      } else {
        a.adj() += vi.adjoint;
        b.adj() += vi.adjoint;
      }
    });
}
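
For reference, here is roughly what a make_callback_vari() helper could look like. This is a sketch of the idea rather than the actual Stan implementation; it assumes the original vari from the top of the post, whose constructor registers it on the callback stack and which still carries the virtual chain() method (as noted above, the real callback stack is still a std::vector<vari*>). The value, the adjoint, and the reverse pass functor then all live in one arena-allocated object, so there is only a single allocation per operation.

// Sketch: one object holds both the data and the callback.
// Names and details differ in the real library
template <typename F>
struct callback_vari final : vari {
  F rev_functor_;
  callback_vari(double value, F&& rev_functor)
      : vari(value), rev_functor_(std::move(rev_functor)) {}
  // Hand ourselves to the functor so it can read the value and adjoint
  void chain() final { rev_functor_(*this); }
};

template <typename F>
inline var make_callback_vari(double value, F&& functor) {
  return var(new callback_vari<std::decay_t<F>>(value, std::forward<F>(functor)));
}

Since callback_vari inherits from vari in this sketch, operator+ above can return the result wrapped in a var directly, and there is no second allocation for a separate callback object.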

So, pretty nice! Personally I think this makes it a lot easier to see which parts of the calculation are done in the forward sweep and which are done later in the reverse sweep. And the main optimization here is that by separating out the data from the callback we can remove the virtual function from the vari, which makes the vari class smaller and more cache friendly. There are other optimizations we get here as well when working with matrices, but this has gotten a bit long so I think that can go into a separate blog post. We are currently adding new functions and rewriting old ones to use this pattern. If you are interested in contributing to Stan math, feel free to message me on Twitter @BreveStonder or check out the Stan math issue page. If you have any other questions feel free to leave a comment below!

6 thoughts on “Thinking About Automatic Differentiation in Fun New Ways”

  1. I really like the second approach (the one currently based on vari). That provides the same level of memory efficiency as the original implementation.

    Some suggestions

    I have a couple suggestions:

    1. Implementing .adj() and .val() methods for vari would allow the body code to look exactly the same. It’s always possible to wrap a vari in a var (the var is just a pointer to the vari implementation allowing the RAII pattern to be used).

    2. Rather than a make_callback_vari function, var could be given a two-argument constructor of the same signature and then the return can just be a braces initializer return { value, [](…) { … }};

    Memory and virtual function calls

    The first formulation, while simpler, introduces a redundant vari object (which is expensive given that there's a virtual function and hence an 8-byte vtable pointer), which the post mentions. It also introduces a redundant copy of the value (variable return_val), which I didn't see called out in the post. I tried very hard in the original implementation to avoid redundant copies or even have to create values eagerly. That's why the chain() method is implemented as a lazy function directly on the vari (where it can access the value).

    The virtual method chain() could be completely eliminated by following this approach everywhere. We'd still need a type with a virtual function to wrap the closures—otherwise there's no way to collect them. Does that introduce a redundant vtable pointer? I haven't worked through the details. The role of vari after removing chain() becomes nothing more than providing a pointer into memory. We could probably even get rid of vari and just put that pointer directly into var without wrapping it in a class. That'd bring the first version to near the original version, but there's still that redundant copy of the value. In the second suggestion in this post, the value is just extracted from the vari. I don't see how a similar move could be made with the first version.

    We only pass vari* around within a var, so cache properties of vari aren't so relevant. The vtable pointer now goes into the closure created by the lambda for the operator. That also gets passed by pointer. The speed bottleneck for both is not being able to optimize the virtual call statically, which means the reverse pass behaves more like interpreted code than compiled code. Luckily, C++ compilers keep getting better at virtual function calls.

    More background

    A more complete introduction to reverse-mode autodiff using this continuation-based pattern implemented through closures can be found in my Discourse post, A new continuation-based autodiff by refactoring. It is a complete standalone reverse-mode autodiff implementation that I put together for pedagogical purposes.

    It’s also the style I adopted for the Automatic Differentiation Handbook, which has an appendix with runnable C++ code. The link is to the GitHub repo where you’ll find a precompiled pdf of the current draft plus the C++ code.

    P.S. Here goes nothing. WordPress is tricky because you can't edit comments and I'm not an editor of this blog.

    1. Thanks Bob!

      > 1. Implementing .adj() and .val() methods for vari would allow the body code to look exactly the same. It’s always possible to wrap a vari in a var (the var is just a pointer to the vari implementation allowing the RAII pattern to be used).

      Yes! Ben just added those in a recent PR so that should pretty things up

      > 2. Rather than a make_callback_vari function, var could be given a two-argument constructor of the same signature and then the return can just be a braces initializer return { value, [](…) { … }};

      I think we need to wait until we move to C++17 for this. Right now a `var` is an alias for a `var_value`. So we would need to be able to write something like

      ```cpp
      const auto chainer = [a, b](const auto& vi) { a.adj() += vi.adj(); };
      var_value ret(a.val() + b.val(), chainer);
      ```

      With C++17 the templates can be deduced from the constructor so we can just write

      ```cpp
      var_value ret(a.val() + b.val(), [a](const auto& vi) { a.adj() += vi.adj(); });
      ```

      Though reading (2) again, am I getting what you're saying right? I think we could do this in C++14 with a `make_callback_var()` function that does this stuff under the hood.

      > The first formulation, while simpler, introduces a redundant vari object (which is expensive given that there’s a virtual function and hence an 8-byte vtable pointer), which the post mentions. It also introduces a redundant copy of the value (variable return_val), which I didn’t see called out in the post.

      Yes that’s a good point as well. I wonder if we could have a vari that has no chain that can be used in `make_callback_vari()` to get rid of the extra vtable pointer. If we had a pure data `vari` and `callback_vari` like in this post I think the extra copy (+ the alloc of the data and callback separately) would be worth it to have the more compact data vari.

      > I tried very hard in the original implementation to avoid redundant copies or even have to create values eagerly. That’s why the chain() method is implemented as a lazy function directly on the vari (where it can access the value).

      Yes, I should have opined more on how nice the current system's memory layout is! I always found it wonderfully nice that a vari is 24 bytes, so reading the vari pointer in var plus the underlying vari is 32 bytes, which fits snugly in a 64-byte cache line.

      > The virtual method chain() could be completely eliminated by following this approach everywhere. We’d still need a type with a virtual function to wrap the closures—otherwise there’s no way to collect them. Does that introduce a redundant vtable pointer?

      I _think_ we could write it to avoid an extra vtable pointer. The only solutions I have right now feel hacky and I’m playing with a couple ways to do that.

      > We only pass vari* around within a var, so cache properties of vari aren’t so relevant.

      What’s the intuition here? My thought process is that when we call `a.adj() + b.adj();` that call for the `a` adjoint causes the cpu to fetch the pointer and then the items the pointer is pointing to. If the `vari` were allocated next to each other then the fetch for `a`’s vari is also going to fetch `b`’s vari since we have to get things in 64 bytes. Is my intuition not right here?

      > The vtable pointer now goes into the closure created by the lambda for the operator. That also gets passed by pointer. The speed bottleneck for both is not being able to optimize the virtual call statically, which means the reverse pass behaves more like interpreted code than compiled code. Luckily, C++ compilers keep getting better at virtual function calls.

      Is it possible right now for the compiler to statically analyze the chain calls? My intuition here is that since we put all of these onto a `std::vector<vari*> chain_stack;` and call them via `chain_stack[i]->chain()`, the compiler can't really be aware of which vari it's calling. It's sort of like the question of how to optimize calls to an array of function pointers, which I don't think can be inlined or statically analyzed beyond anything within the method being called.

  2. Thanks for the suggestions!

    I agree 1. and 2. would be nice quality of life improvements, but I don’t think they are very important.

    About “Memory and virtual function calls”:
    Yeah, the `reverse_pass_callback` introduces an extra vari. That is why it should not be used for scalar operations; scalar operations should use `callback_vari` instead. However, for vector or matrix operations we can't use `make_callback_vari`, and the old approach also creates an extra vari; this is what `reverse_pass_callback` is intended for. Also, I think the copy of the value does not matter that much, as it is just a var, which might be optimized away by the compiler anyway.

    I doubt we can avoid the cost of the virtual call for chain unless we switch to a static expression graph (like what TensorFlow uses). I mean, we can avoid the virtual call, but then we need some other mechanism that will be just as expensive.

    I skimmed over the linked Discourse post. I like the idea of storing the value within the var if it is not used in the reverse pass. To implement that we would probably need to change the var class so that all accesses to its internals (value etc.) go through functions. I would like to do that anyway, as it would enable some other optimizations.

    1. I agree that we can’t avoid that virtual function call cost—it’s either literally a virtual function call or a function pointer. Does TensorFlow code generate to statically resolve function calls or does it just implement the call graph? I wouldn’t have thought they’d bother code generating because of the structure of their matrix calculations.

      You’re right—it’s a huge amount of work to do these refactorings. I’m amazed that you’ve had so much patience in working through all the functions for things like matrix return types. Daniel Lee and I did that a few times earlier when there were fewer, less complicated functions.

      The binding is reverse_pass_callback([a, b, return_val]()…); Doesn’t that store a copy of return_val in the closure? Now that I’m looking at this, how do we make sure the closure is stored in our memory arena? Or how do you make sure it’s recovered if it’s not in the arena?

      1. I don’t know all the details of Tensorflow implementation. I just meant statically constructing of compute graph.

        > Doesn’t that store a copy of return_val in the closure?

        It does. But the compiler can likely optimize it so it is stored only in the closure.

        >Now that I’m looking at this, how do we make sure the closure is stored in our memory arena?

        The closure is part of the lambda object. The lambda is stored in a `reverse_pass_callback_vari`, which is allocated using our allocator.
