24 October 2006

Updated Scheme Code for Derivatives

I've updated my functional differential geometry code to compute derivatives with the same techniques I've been writing about in OCaml these past few days. It should (though I haven't rigorously tested this) improve the performance of computing structure constants, which involve the commutator of two tangent vectors in the group. Commutators require second derivatives, and therefore multiple tags in the derivative objects, which the tree structure handles more efficiently. The same goes for anywhere else in the code where I need multiple derivatives. Issuing darcs get http://web.mit.edu/farr/www/SchemeCode/ should get the new version, or you can just download the code at that web address. Scheme can really be a breath of fresh air (compare the implementation of the macro define-binary-derivative in derivatives.ss to the repeated pattern in the OCaml code), but its performance leaves something to be desired for large-scale numerical work. (The OCaml code runs *much* faster than the Scheme code, and it still takes a couple of hours to measure the non-symplecticity in a 25-body simulation.)

As always, comments from my user(s?) are definitely welcome at farr@mit.edu.

23 October 2006

Better Jacobians in Computing Derivatives

Another fix to the code for computing derivatives: previously, I computed the Jacobian of a function with only one call to the function:

let jacobian f args = 
    let n = Array.length args in 
    let tags = Array.init n (fun _ -> new_tag ()) in 
    let dargs = Array.mapi (fun i arg -> D(tags.(i), arg, one)) args in 
    let result = f dargs in 
    Array.init n 
      (fun i -> 
        let fi = result.(i) in 
        Array.init n 
          (fun j -> 
            let tj = tags.(j) in 
            Array.fold_left 
              (fun res tag -> 
                if tag = tj then 
                  extract_deriv tag res
                else
                  drop_deriv tag res)
              fi
              tags))
That's efficient in the sense that it doesn't recompute results within f. Unfortunately, it introduces n tags, where n is the dimensionality of the argument to f. Each tag labels a fundamental ``incremental value'' dx, dy, .... A particular D(...) expression is a tree which represents an expression a0 + a1*dx + a2*dx*dy + ... that potentially contains one term for each of the 2^n different products of the dx, dy, .... This means that, though we save on repeated invocations of f, the cost of computing the jacobian becomes exponential in the dimension of the argument, because each operation (+, sin, etc.) has to traverse trees with up to 2^n terms.
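To see the 2^n growth concretely, here is a minimal standalone sketch. It is not the post's functor. It uses a hypothetical float-only `diff` type with only `add` and `mul`, following the heap-ordered case analysis of the 19 October code, and seeds some of the inputs x1, ..., x4 with their own tags before forming the product x1*x2*x3*x4. The names `leaves` (a rough proxy for the cost of each later operation), `value`, and `product` are illustrative.

```ocaml
(* Hypothetical float-only stand-in for the post's [diff] type. *)
type diff = C of float | D of int * diff * diff

(* Heap-preserving addition: the smaller tag always ends up on top. *)
let rec add d1 d2 =
  match d1, d2 with
  | C x, C y -> C (x +. y)
  | C _, D (t, y, dy) -> D (t, add d1 y, dy)
  | D (t, x, dx), C _ -> D (t, add x d2, dx)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty -> D (tx, add x y, add dx dy)
  | D (tx, x, dx), D (ty, _, _) when tx < ty -> D (tx, add x d2, dx)
  | D _, D (ty, y, dy) -> D (ty, add d1 y, dy)

(* Heap-preserving multiplication, same case analysis as the post's ( * ). *)
let rec mul d1 d2 =
  match d1, d2 with
  | C x, C y -> C (x *. y)
  | C _, D (t, y, dy) -> D (t, mul d1 y, mul d1 dy)
  | D (t, x, dx), C _ -> D (t, mul x d2, mul dx d2)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty ->
      D (tx, mul x y, add (mul x dy) (mul dx y))
  | D (tx, x, dx), D (ty, _, _) when tx < ty -> D (tx, mul x d2, mul dx d2)
  | D _, D (ty, y, dy) -> D (ty, mul d1 y, mul d1 dy)

(* Number of constant leaves in a derivative tree. *)
let rec leaves = function C _ -> 1 | D (_, x, dx) -> leaves x + leaves dx

(* Value part: follow the x branches down to the constant. *)
let rec value = function C v -> v | D (_, x, _) -> value x

(* Product of the inputs; input i gets tag (i + 1) and derivative one
   exactly when [seeded i] is true. *)
let product seeded xs =
  let args =
    List.mapi (fun i v -> if seeded i then D (i + 1, C v, C 1.0) else C v) xs
  in
  List.fold_left mul (C 1.0) args

let xs = [ 2.0; 3.0; 5.0; 7.0 ]
let all_at_once = product (fun _ -> true) xs    (* n tags at once *)
let one_at_a_time = product (fun i -> i = 0) xs (* a single tag *)

let () =
  Printf.printf "all tags: %d leaves, one tag: %d leaves\n"
    (leaves all_at_once) (leaves one_at_a_time)
```

Seeding all four inputs at once gives 16 = 2^4 constant leaves, while seeding a single input keeps the tree at 2 leaves; both products have the value 210.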

The following simple change wastes effort by evaluating f once per column (n times in all), but its cost scales only linearly with the dimension of args, which is a huge win. I've also added a function insert_tag which inserts a new D(tag, x, one) term within the derivative-computing wrappers (but does so in a way consistent with the heap property of the diff-tree). The modified functions are below:

  let rec insert_tag tag = function 
    | (C(x) as d) -> D(tag, d, one)
    | (D(tagd, x, dx) as d) when tag < tagd -> 
        D(tag, d, one)
    | D(tagd, x, dx) when tagd = tag -> 
        raise (Failure "Cannot insert tag which is already present.")
    | D(tagd, x, dx) -> 
        D(tagd, insert_tag tag x, dx)
          
  let d f = 
    fun x -> 
      let tag = new_tag () in 
      let res = f (insert_tag tag x) in
      extract_deriv tag res

  let jacobian f args = 
    let n = Array.length args in 
    let jac = Array.make_matrix n n zero in 
    for j = 0 to Pervasives.(-) n 1 do 
      let tag = new_tag () in 
      let dargs = Array.mapi 
          (fun i arg -> if Pervasives.(=) i j then insert_tag tag arg else arg)
          args in 
      let result = f dargs in 
      Array.iteri (fun i res -> 
        jac.(i).(j) <- extract_deriv tag res)
        result
    done;
    jac
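Here is a compact, self-contained sketch of the column-by-column scheme. It assumes a hypothetical float-only `diff` type, with just `add` and `mul` standing in for the functor's arithmetic. `insert_tag`, `extract_deriv`, and `jacobian` follow the code above. The test map `f` and `to_float` (a stand-in for `lower_jacobian`) are illustrative.

```ocaml
(* Hypothetical float-only stand-in for the functor's [diff] type. *)
type diff = C of float | D of int * diff * diff

let zero = C 0.0
let one = C 1.0

(* Fresh, increasing tags, as in the post's [new_tag]. *)
let new_tag =
  let count = ref 0 in
  fun () -> incr count; !count

(* Heap-preserving addition and multiplication (19 October case analysis). *)
let rec add d1 d2 =
  match d1, d2 with
  | C x, C y -> C (x +. y)
  | C _, D (t, y, dy) -> D (t, add d1 y, dy)
  | D (t, x, dx), C _ -> D (t, add x d2, dx)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty -> D (tx, add x y, add dx dy)
  | D (tx, x, dx), D (ty, _, _) when tx < ty -> D (tx, add x d2, dx)
  | D _, D (ty, y, dy) -> D (ty, add d1 y, dy)

let rec mul d1 d2 =
  match d1, d2 with
  | C x, C y -> C (x *. y)
  | C _, D (t, y, dy) -> D (t, mul d1 y, mul d1 dy)
  | D (t, x, dx), C _ -> D (t, mul x d2, mul dx d2)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty ->
      D (tx, mul x y, add (mul x dy) (mul dx y))
  | D (tx, x, dx), D (ty, _, _) when tx < ty -> D (tx, mul x d2, mul dx d2)
  | D _, D (ty, y, dy) -> D (ty, mul d1 y, mul d1 dy)

(* As in the post: put [tag] where the heap order allows it. *)
let rec insert_tag tag = function
  | C _ as d -> D (tag, d, one)
  | (D (tagd, _, _) as d) when tag < tagd -> D (tag, d, one)
  | D (tagd, _, _) when tagd = tag ->
      failwith "Cannot insert tag which is already present."
  | D (tagd, x, dx) -> D (tagd, insert_tag tag x, dx)

(* By the heap property, [tag] never appears below a node carrying it,
   so that node's dx is the derivative for its branch. *)
let rec extract_deriv tag = function
  | C _ -> zero
  | D (t, _, dx) when t = tag -> dx
  | D (t, x, dx) -> D (t, extract_deriv tag x, extract_deriv tag dx)

(* One call to [f] per column j, with a single fresh tag on args.(j). *)
let jacobian f args =
  let n = Array.length args in
  let jac = Array.make_matrix n n zero in
  for j = 0 to n - 1 do
    let tag = new_tag () in
    let dargs = Array.mapi (fun i a -> if i = j then insert_tag tag a else a) args in
    Array.iteri (fun i r -> jac.(i).(j) <- extract_deriv tag r) (f dargs)
  done;
  jac

(* Illustrative test map: f(x, y) = (x*y, x + y). *)
let f v = [| mul v.(0) v.(1); add v.(0) v.(1) |]

let to_float = function C v -> v | D _ -> failwith "derivative still tagged"

let jac = Array.map (Array.map to_float) (jacobian f [| C 2.0; C 3.0 |])
```

For f(x, y) = (x*y, x + y) at (2, 3) this gives [[3, 2], [1, 1]]. Each call to f sees only one fresh tag, so the intermediate trees do not grow with n.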

19 October 2006

Correction to Automatic Differentiation Code

I've uncovered a *large* inefficiency in the OCaml code I posted for Automatic Differentiation. The problem is that I don't take any care to avoid duplicating tags in a nested D(tag, x, dx) expression. That's not a problem in a computation with only one tag, but if you're trying to take the jacobian of some multi-dimensional function, you have derivatives with respect to each of the arguments floating through the computation. If you duplicate tags every time you construct a new D(...) expression, then the nesting depth of the expression grows *really* fast. The key to avoiding duplication is to think of a given D(tag, x, dx) expression as a tree with the heap property: all tags in the expressions x and dx are larger than tag. Then you just have to rewrite the binary operations (these are the only ones which can combine D(...) expressions with different tags) to preserve the heap property. Updated code is posted below (sorry for another long post, but this change makes a big difference)! I've added the jacobian and lower_jacobian functions because I needed them to test for symplecticity: a map is symplectic when its jacobian J satisfies J^T.S.J = S, where S = [[0, I], [-I, 0]] is the standard symplectic form, I is the identity matrix, and 0 is the zero matrix. As before, this code is released under the GPL.
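The heap property is easy to check mechanically. The following minimal sketch uses a hypothetical float-only `diff` type. `heap_ok` tests the property, `occurrences` counts how many nodes carry a given tag, and `add_naive` and `add_heap` are illustrative additions modeled on the 16 October and 19 October case analyses respectively.

```ocaml
(* Hypothetical float-only stand-in for the post's [diff] type. *)
type diff = C of float | D of int * diff * diff

(* Heap property: every tag below a D node is strictly larger than its tag. *)
let heap_ok d =
  let rec ok bound = function
    | C _ -> true
    | D (t, x, dx) -> t > bound && ok t x && ok t dx
  in
  ok min_int d

(* Number of nodes carrying [tag]; more than one means duplication. *)
let rec occurrences tag = function
  | C _ -> 0
  | D (t, x, dx) ->
      (if t = tag then 1 else 0) + occurrences tag x + occurrences tag dx

(* 16 October style: nest the second tag under the first, whatever the order. *)
let rec add_naive d1 d2 =
  match d1, d2 with
  | C x, C y -> C (x +. y)
  | C _, D (t, y, dy) -> D (t, add_naive d1 y, dy)
  | D (t, x, dx), C _ -> D (t, add_naive x d2, dx)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty ->
      D (tx, add_naive x y, add_naive dx dy)
  | D (tx, x, dx), D (ty, y, dy) -> D (tx, D (ty, add_naive x y, dy), dx)

(* 19 October style: the smaller tag always ends up on top. *)
let rec add_heap d1 d2 =
  match d1, d2 with
  | C x, C y -> C (x +. y)
  | C _, D (t, y, dy) -> D (t, add_heap d1 y, dy)
  | D (t, x, dx), C _ -> D (t, add_heap x d2, dx)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty ->
      D (tx, add_heap x y, add_heap dx dy)
  | D (tx, x, dx), D (ty, _, _) when tx < ty -> D (tx, add_heap x d2, dx)
  | D _, D (ty, y, dy) -> D (ty, add_heap d1 y, dy)

let x = D (1, C 1.0, C 1.0) (* 1 + dx, tag 1 *)
let y = D (2, C 2.0, C 1.0) (* 2 + dy, tag 2 *)
```

Forming y + x + x with the nesting style first puts tag 1 underneath tag 2 and then duplicates it. The heap-ordered version keeps a single tag-1 node at the root. Both trees represent the same value, 4 + 2 dx + dy.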

(** Code to implement differentiation on arbitrary types. 

   diff.ml: Library to compute exact derivatives of arbitrary
   mathematical functions.

   Copyright (C) 2006 Will M. Farr 
   
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

*)

(** Input signature for [Make].*)
module type In = sig
  type t

(** The multiplicative identity *)
  val one : t

(** The additive identity *)
  val zero : t

(** Comparisons *)
  val compare : t -> t -> int

  val (+) : t -> t -> t
  val (-) : t -> t -> t
  val ( * ) : t -> t -> t
  val (/) : t -> t -> t
(** Arithmetic operations *)


  val sin : t -> t
  val cos : t -> t
  val tan : t -> t 
(** Trigonometric functions *)

  val asin : t -> t
  val acos : t -> t
  val atan : t -> t
(** Inverse trigonometric functions *)

  val sqrt : t -> t
  val log : t -> t
  val exp : t -> t
  val ( ** ) : t -> t -> t
(** Powers, exponents and logs. *)

  val print_t : out_channel -> t -> unit
end

(** Output signature for [Make]. *)
module type Out = sig 
(** Input from [In] *)
  type t 

(** Opaque tags for variables being differentiated. *)
  type tag

(** Output diff type.  A [diff] is either a constant [t] value [C(x)]
   or a tuple of [D(tag, x, dx)] where [x] and [dx] can also be
   [diff].  [D(tag, x, dx)] represents a small increment [dx] in the
   value labeled by [tag] about the value [x]. *)
  type diff = 
    | C of t 
    | D of tag * diff * diff

  val zero : diff
  val one : diff
(** Additive and multiplicative identities in [diff]. *)

(** [lift f f'] takes [f : t -> t] to [diff -> diff] using [f'] to
   compute the derivative.  For example, if we had already defined
   [(/) : diff -> diff -> diff], and the input signature didn't
   provide [log], we could define it using [let log = lift log ((/)
   one)]. *)
  val lift : (t -> t) -> (diff -> diff) -> (diff -> diff)

(** [lower f] takes a function defined on [diff] to the equivalent one
   defined on [t].  It is an error if [f (C x)] does not evaluate to
   [C y] for some [y].  The function is called [lower] because it is
   the inverse of the function [lift] which lifts [f : t -> t] to
   [diff -> diff].  *)
  val lower : (diff -> diff) -> (t -> t)

(** [lower_multi f] lowers [f] from [diff array -> diff] to [t array ->
   t].  It is an error if [f \[|(C x0); ...; (C xn)|\]] does not
   evaluate to [C y] for some [y].*)
  val lower_multi : (diff array -> diff) -> (t array -> t)

(** [lower_multi_multi f] lowers [f] from [diff array -> diff array]
   to [t array -> t array].  *)
  val lower_multi_multi : (diff array -> diff array) -> 
    (t array -> t array)

(** [lower_jacobian j] lowers the jacobian [j] from [diff array ->
   diff array array] to [t array -> t array array]. *)
  val lower_jacobian : (diff array -> diff array array) -> 
    (t array -> t array array)

(** [d f] returns the function which computes the derivative of
   [f]. *)
  val d : (diff -> diff) -> (diff -> diff)

(** [partial i f] returns the function which computes the derivative
    of [f] with respect to its [i]th argument (counting from 0). *)
  val partial : int -> (diff array -> diff) -> (diff array -> diff)

(** [jacobian f] returns the jacobian matrix with elements [m.(i).(j)
   = d f.(i)/d x.(j)]. *)
  val jacobian : (diff array -> diff array) -> 
    (diff array -> diff array array)

(** [compare d1 d2] compares only the values stored in the
   derivative---that is either the [C x] or [D(_, x, _)].  [(<)],
   [(>)], ... are defined in terms of compare, which is defined in
   terms of [compare] from [In]. *)
  val compare : diff -> diff -> int

  val (<) : diff -> diff -> bool
  val (<=) : diff -> diff -> bool
  val (>) : diff -> diff -> bool
  val (>=) : diff -> diff -> bool
  val (=) : diff -> diff -> bool
  val (<>) : diff -> diff -> bool
(** Comparison functions *)

  val (+) : diff -> diff -> diff
  val (-) : diff -> diff -> diff
  val ( * ) : diff -> diff -> diff
  val (/) : diff -> diff -> diff
(** Algebra. *)

  val cos : diff -> diff
  val sin : diff -> diff
  val tan : diff -> diff
(** Trig *) 

  val acos : diff -> diff
  val asin : diff -> diff
  val atan : diff -> diff
(** Inverse trig *)

  val sqrt : diff -> diff
  val log : diff -> diff
  val exp : diff -> diff
  val ( ** ) : diff -> diff -> diff
(** Powers, exponents and logs *)

  val print_diff : out_channel -> diff -> unit
end

module Make(I : In) : Out with type t = I.t = struct
  type t = I.t

  type tag = int

(* Terms are either constant or of the form (x + dx), with the
   increment dx labeled by tag.  We are careful to maintain the
   invariant that all the tags in x and dx are larger than the tag of
   (x + dx).  (That is, you can think of a D(tag, x, dx) as a tree:

          tag
         /   \
        x     dx

   which satisfies the heap property tag < tagx and tag < tagdx.)
*)
  type diff = 
    | C of t
    | D of tag * diff * diff

  let rec print_diff out = function 
    | C(x) -> Printf.fprintf out "C(";
        I.print_t out x;
        Printf.fprintf out ")"
    | D(tag, x, dx) -> 
        Printf.fprintf out "D(%d," tag;
        print_diff out x;
        Printf.fprintf out ", ";
        print_diff out dx;
        Printf.fprintf out ")"

(* Additive and multiplicative identities in derivatives. *)
  let zero = C(I.zero)
  let one = C(I.one)

(* Unique tags *)
  let new_tag = 
    let count = ref 0 in 
    fun () -> 
      count := !count + 1;
      !count

(* Have to define the arithmetic operators first because they are used
   in [lift] and friends.  To maintain the heap property of the tags,
   we select the smallest of tagx and tagy when we're operating on two
   D(_,_,_) objects.  We know that we can directly construct
   D(smallest, _, _), where _ and _ can be any of the sub derivatives.
   But, we can only use the sub-derivatives from larger in any direct
   construction D(larger, _, _); there is no guarantee that the tags
   in the sub-derivatives from D(smaller, _, _) are in any relation to
   larger.  This is the reason for the somewhat obfuscated code
   below. *)
  let rec (+) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(+) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1 + y, dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x + d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x + y, dx + dy)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy ->
        D(tagx, x + d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, d1 + y, dy)

  let rec (-) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(-) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1 - y, zero - dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x - d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x - y, dx - dy)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy -> 
        D(tagx, x - d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, d1 - y, zero - dy)

  let rec ( * ) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.( * ) x y)
    | C(x), D(tag, y, dy) -> D(tag, d1 * y, d1 * dy)
    | D(tag, x, dx), C(y) -> D(tag, x * d2, dx * d2)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x * y, x*dy + dx*y)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy -> 
        D(tagx, x*d2, dx*d2)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, d1*y, d1*dy)

  let rec (/) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(/) x y)
    | D(tag, x, dx), C(y) -> D(tag, x/d2, dx/d2)
    | C(x), D(tag, y, dy) -> D(tag, d1/y, zero - d1*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x/y, dx/y - x*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy -> 
        let y2 = y * y in 
        let mdyoy2 = zero - dy/y2 and 
            ooy = one/y in 
        D(tagx, x * (D(tagy, ooy, mdyoy2)),
          dx*(D(tagy, ooy, mdyoy2)))
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, d1/y, zero - d1*dy/(y*y))

  let lift f f' = 
    let rec lf = function 
      | C(x) -> C(f x)
      | D(tag, x, dx) -> D(tag, lf x, (f' x)*dx) in 
    lf

(* Now that we have the algebra of derivatives worked out, we can
   define lower and the functions that extract derivatives. *)
  let lower f = 
    fun x -> 
      match f (C x) with 
      | C(y) -> y
      | _ -> raise (Failure "lower expects numerical output")

  let rec extract_deriv tag = function 
    | C(_) -> zero
    | D(tagx, x, dx) when tagx = tag -> 
        dx
    | D(tagx, x, dx) -> 
        D(tagx, extract_deriv tag x, extract_deriv tag dx) 
  and drop_deriv tag = function 
    | D(tagx, x, dx) when tagx = tag -> 
        x
    | D(tagx, x, dx) -> 
        D(tagx, drop_deriv tag x, drop_deriv tag dx)
    | x -> x (* Only matches C(_) *)
          
  let d f = 
    fun x -> 
      let tag = new_tag () in 
      let res = f (D(tag, x, one)) in 
      extract_deriv tag res

  let rec compare d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> I.compare x y
    | C(x), D(_, y, _) -> compare d1 y
    | D(_, x, _), C(y) -> compare x d2
    | D(_, x, _), D(_, y, _) -> compare x y

  let rec cos = function 
    | C(x) -> C(I.cos x)
    | D(tag, x, dx) -> D(tag, cos x, zero - dx * (sin x)) and 
      sin = function 
        | C(x) -> C(I.sin x)
        | D(tag, x, dx) -> D(tag, sin x, dx * (cos x))

  let rec tan = function 
    | C(x) -> C(I.tan x)
    | D(tag, x, dx) -> 
        let tx = tan x in 
        D(tag, tx, (one + tx*tx)*dx)

  let rec sqrt = function 
    | C(x) -> C(I.sqrt x)
    | D(tag, x, dx) -> 
        let sx = sqrt x in 
        D(tag, sx, dx / (sx + sx))

  let rec acos = function 
    | C(x) -> C(I.acos x)
    | D(tag, x, dx) -> 
        D(tag, acos x, zero - dx / (sqrt (one - x * x)))

  let rec asin = function 
    | C(x) -> C(I.asin x)
    | D(tag, x, dx) -> 
        D(tag, asin x, dx / (sqrt (one - x * x)))

  let rec atan = function 
    | C(x) -> C(I.atan x)
    | D(tag, x, dx) -> 
        D(tag, atan x, dx / (one + x*x))

  let rec log = function 
    | C(x) -> C(I.log x)
    | D(tag, x, dx) -> D(tag, log x, dx/x)

  let rec exp = function 
    | C(x) -> C(I.exp x)
    | D(tag, x, dx) -> 
        let ex = exp x in 
        D(tag, ex, dx*ex)

  let rec ( ** ) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.( ** ) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1**y, d1**y * (log d1) * dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x**d2, x**(d2 - one)*d2*dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x**y, x**(y - one) * (dx*y + x*(log x)*dy))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy -> 
        D(tagx, x**y*(one + (log x)*(D(tagy, zero, dy))),
          dx*x**(y - one)*(y + (one + y*(log x))*(D(tagy, zero, dy))))
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, x**y + y*x**(y-one)*(D(tagx, zero, dx)),
          dy*(x**y*(log x) + x**(y-one)*(one+y*(log x))*(D(tagx, zero, dx))))
                      

  let replace arr i x = 
    Array.mapi 
      (fun j y -> 
        if j <> i then y else x)
      arr

  let partial i f = 
    fun args -> 
      let x = args.(i) in 
      let one_d_f x = 
        f (replace args i x) in 
      (d one_d_f) x

  let jacobian f args = 
    let n = Array.length args in 
    let tags = Array.init n (fun _ -> new_tag ()) in 
    let dargs = Array.mapi (fun i arg -> D(tags.(i), arg, one)) args in 
    let result = f dargs in 
    Array.init n 
      (fun i -> 
        let fi = result.(i) in 
        Array.init n 
          (fun j -> 
            let tj = tags.(j) in 
            Array.fold_left 
              (fun res tag -> 
                if tag = tj then 
                  extract_deriv tag res
                else
                  drop_deriv tag res)
              fi
              tags))
   
  let c_ify arr = 
    Array.map (fun x -> (C x)) arr

  let de_c_ify arr = 
    Array.map 
      (function 
        | C(x) -> x
        | d -> 
            raise (Failure "cannot lower [|D(...); ...; D(...)|]"))
      arr

  let lower_multi f = 
    fun args -> 
      match (f (c_ify args)) with 
      | C(x) -> x
      | _ -> raise (Failure "cannot lower D(...)")

  let lower_multi_multi f = 
    fun args -> 
      de_c_ify (f (c_ify args))

  let lower_jacobian j args = 
     Array.map 
      (fun carr -> 
        de_c_ify carr)
      (j (c_ify args))

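(* Define these all at once (and last) so as not to spoil the
   comparison operator namespace.  Because the definitions are joined
   with [and] (not [let rec]), the comparisons used in their bodies are
   still the Pervasives ones, applied to the int returned by compare. *)
  let (<) d1 d2 = compare d1 d2 < 0
  and (<=) d1 d2 = not (compare d1 d2 > 0)
  and (>) d1 d2 = compare d1 d2 > 0
  and (>=) d1 d2 = not (compare d1 d2 < 0)
  and (=) d1 d2 = compare d1 d2 = 0
  and (<>) d1 d2 = not (compare d1 d2 = 0)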
end

module DFloat = Make(struct 
  type t = float 

  include Pervasives

  let (+) = (+.)
  let (-) = (-.)
  let (/) = (/.)
  let ( * ) = ( *. )

  let one = 1.0
  let zero = 0.0

  let print_t out = Printf.fprintf out "%g"
end)
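Once lower_jacobian has produced a plain float matrix, the condition J^T.S.J = S can be checked with ordinary float arithmetic. This is a minimal sketch. It assumes the matrix is row-major, with coordinates ordered (q1, ..., qn, p1, ..., pn). `symplectic_form`, `transpose`, `matmul`, and `symplectic_error` are illustrative helpers, not part of diff.ml. The two test matrices are hand-derived jacobians of one step of symplectic Euler and of explicit Euler for the harmonic oscillator H = (q^2 + p^2)/2.

```ocaml
(* S = [[0, I], [-I, 0]] for n degrees of freedom (a 2n x 2n matrix). *)
let symplectic_form n =
  Array.init (2 * n) (fun i ->
      Array.init (2 * n) (fun j ->
          if j = i + n then 1.0 else if i = j + n then -1.0 else 0.0))

let transpose m =
  let rows = Array.length m and cols = Array.length m.(0) in
  Array.init cols (fun i -> Array.init rows (fun j -> m.(j).(i)))

let matmul a b =
  let inner = Array.length b and cols = Array.length b.(0) in
  Array.map
    (fun row ->
      Array.init cols (fun j ->
          let s = ref 0.0 in
          for k = 0 to inner - 1 do
            s := !s +. (row.(k) *. b.(k).(j))
          done;
          !s))
    a

(* Largest entry of |J^T.S.J - S|; roundoff-sized for a symplectic map. *)
let symplectic_error jac =
  let s = symplectic_form (Array.length jac / 2) in
  let lhs = matmul (transpose jac) (matmul s jac) in
  let worst = ref 0.0 in
  Array.iteri
    (fun i row ->
      Array.iteri
        (fun j v -> worst := max !worst (abs_float (v -. s.(i).(j))))
        row)
    lhs;
  !worst

(* One step of size h for H = (q^2 + p^2)/2, coordinates ordered (q, p).
   Symplectic Euler: p' = p - h q, q' = q + h p'.
   Explicit Euler:   q' = q + h p,  p' = p - h q. *)
let h = 0.1
let symplectic_euler = [| [| 1.0 -. (h *. h); h |]; [| -. h; 1.0 |] |]
let explicit_euler = [| [| 1.0; h |]; [| -. h; 1.0 |] |]
```

With h = 0.1, the symplectic Euler matrix should give an error at roundoff level. Explicit Euler is off by about h^2 = 0.01, because its jacobian has determinant 1 + h^2.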

16 October 2006

``Automatic Differentiation'' in OCaml

Below I've posted an OCaml functor which lifts an input module of mathematical functions (+, -, *, /, sin, cos, ...) to equivalent functions which also support computing derivatives without any truncation error (i.e. (d sin) is (exactly) the function cos). This technique is commonly called automatic differentiation, and in C++ it is often implemented through template metaprogramming. I was first introduced to it by SICM, and you can also find examples in Haskell in the Functional Programming and Numerical Computation paper.

The basic idea is to compute derivatives using the chain rule---if you have a function f and you know its derivative Df and similarly for g and Dg, then you can compute D(f o g)(x) = Df(g(x)) * Dg(x). Computationally you represent a ``derivative value'' by a pair (x,dx); then f((x,dx)) = (f(x), Df(x)*dx). To compute the derivative of some unknown function h, you simply extract the second part of h((x,1)) (this makes more sense---to me, anyway---if you think of (x,1) as the derivative output of the identity function).
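The pair idea fits in a few lines. This is a minimal sketch with a hypothetical untagged `dual` record. It carries no tags, so it cannot be nested to take second derivatives; that is the job of the tags discussed next. `lift`, `dsin`, `dexp`, `deriv`, and `sin_of_exp` are illustrative names.

```ocaml
(* A ``derivative value'' (x, dx), with no tag: first derivatives only. *)
type dual = { x : float; dx : float }

(* f((x, dx)) = (f(x), Df(x) * dx): the chain rule, one function at a time. *)
let lift f df { x; dx } = { x = f x; dx = df x *. dx }

let dsin = lift sin cos
let dexp = lift exp exp

(* D h (x) is the second part of h((x, 1)). *)
let deriv h x = (h { x; dx = 1.0 }).dx

(* Composition needs no extra work: D(sin o exp)(x) = cos(exp x) * exp x. *)
let sin_of_exp v = dsin (dexp v)
```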

There's a bit of extra bookkeeping in the structure pasted in below, mostly because you may have derivatives with respect to several different arguments floating around in the computation at the same time---so each derivative object acquires a unique tag. The only tricky thing about this is that you have to keep some extra terms in the derivatives of binary operators: (x,dx)*(y,dy) = ((x*y, dx*y), (x*dy, dx*dy)). (Here the inner (,) groups refer to the ``x'' label and the outer group refers to the ``y'' label.) You would think that the dx*dy term is second-order and could be dropped, but it's first-order in dx and first-order in dy, and they are not the same variable, so you have to keep it. Gerry Sussman pointed this bug out to me the first time I wrote some code to do this---apparently it had tripped them up in SICM, too. It's a bit nasty, because the bug doesn't show up until you try to take second derivatives (so you have more than one tag floating around in the computation). If you write your own code to do this, consider yourself warned :).
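The cross-term bug is easy to reproduce. The following minimal sketch uses a hypothetical float-only `diff` type in the nesting style of the code below. It adds a `~keep_cross` flag to `mul` (an illustrative addition, not part of the functor) that drops the dx*dy term when false. The function `mixed` computes d/dx d/dy (x*y), which should be 1.

```ocaml
(* Hypothetical float-only stand-in for the post's [diff] type. *)
type diff = C of float | D of int * diff * diff

let new_tag =
  let count = ref 0 in
  fun () -> incr count; !count

(* 16 October style addition: distinct tags are simply nested. *)
let rec add a b =
  match a, b with
  | C x, C y -> C (x +. y)
  | C _, D (t, y, dy) -> D (t, add a y, dy)
  | D (t, x, dx), C _ -> D (t, add x b, dx)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty -> D (tx, add x y, add dx dy)
  | D (tx, x, dx), D (ty, y, dy) -> D (tx, D (ty, add x y, dy), dx)

(* Multiplication; [~keep_cross:false] reproduces the bug by dropping dx*dy. *)
let rec mul ~keep_cross a b =
  let ( * ) = mul ~keep_cross and ( + ) = add in
  match a, b with
  | C x, C y -> C (x *. y)
  | C _, D (t, y, dy) -> D (t, a * y, a * dy)
  | D (t, x, dx), C _ -> D (t, x * b, dx * b)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty ->
      D (tx, x * y, (x * dy) + (dx * y))
  | D (tx, x, dx), D (ty, y, dy) ->
      let cross = if keep_cross then dx * dy else C 0.0 in
      D (tx, D (ty, x * y, x * dy), D (ty, dx * y, cross))

(* Tags may be duplicated in this style, so both helpers recurse. *)
let rec extract tag = function
  | C _ -> C 0.0
  | D (t, x, dx) when t = tag -> add (extract tag x) (drop tag dx)
  | D (t, x, dx) -> D (t, extract tag x, extract tag dx)
and drop tag = function
  | C _ as c -> c
  | D (t, x, _) when t = tag -> drop tag x
  | D (t, x, dx) -> D (t, drop tag x, drop tag dx)

let d f x =
  let tag = new_tag () in
  extract tag (f (D (tag, x, C 1.0)))

(* d/dx d/dy (x*y), which is 1 for every a, b. *)
let mixed ~keep_cross a b =
  let ( * ) = mul ~keep_cross in
  d (fun x -> d (fun y -> x * y) (C b)) (C a)
```

With the cross term, the mixed derivative comes out as 1. Without it, the result is 0. A single derivative of x*x is 6 at x = 3 either way, which is why the bug stays hidden until more than one tag is in play.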

What do I use this for? I'm using it to show that my integrator algorithms are definitely symplectic maps. A (two-dimensional) map (q,p) |-> (Q(q,p), P(q,p)) is symplectic if (dQ/dq*dP/dp - dQ/dp*dP/dq)(q,p) = 1. A nice way to show this for a computational algorithm is to automatically differentiate it and evaluate the combination of derivatives above along a trajectory. The combination should be 1 within roundoff error. The module below computes the derivatives without truncation error (that is, without finite differencing), but it is still subject to roundoff in the finite precision of floating-point arithmetic if you use it on ordinary double-precision floats.
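Here is a minimal sketch of that test. It uses hand-rolled untagged dual numbers (first derivatives only) rather than the functor below. The map is an illustrative symplectic-Euler step for the pendulum H = p^2/2 - cos q, not one of the integrators from the post. `step`, `symplectic_det`, and `max_deviation` are hypothetical names. Each column of the jacobian comes from one forward pass, with either q or p seeded with derivative 1.

```ocaml
(* Plain (untagged) dual numbers suffice: only first derivatives are needed. *)
type dual = { v : float; dv : float }

let const v = { v; dv = 0.0 }
let add a b = { v = a.v +. b.v; dv = a.dv +. b.dv }
let sub a b = { v = a.v -. b.v; dv = a.dv -. b.dv }
let mul a b = { v = a.v *. b.v; dv = (a.dv *. b.v) +. (a.v *. b.dv) }
let dsin a = { v = sin a.v; dv = a.dv *. cos a.v }

(* Illustrative map: one symplectic-Euler step for H = p^2/2 - cos q. *)
let step h (q, p) =
  let h = const h in
  let p' = sub p (mul h (dsin q)) in
  let q' = add q (mul h p') in
  (q', p')

(* dQ/dq*dP/dp - dQ/dp*dP/dq at (q, p), one forward pass per input. *)
let symplectic_det h q p =
  let qq, pq = step h ({ v = q; dv = 1.0 }, const p) in (* seed q *)
  let qp, pp = step h (const q, { v = p; dv = 1.0 }) in (* seed p *)
  (qq.dv *. pp.dv) -. (qp.dv *. pq.dv)

(* Worst |det - 1| over [steps] points of the trajectory from (q0, p0). *)
let max_deviation h q0 p0 steps =
  let rec go q p k worst =
    if k = 0 then worst
    else
      let worst = max worst (abs_float (symplectic_det h q p -. 1.0)) in
      let q', p' = step h (const q, const p) in
      go q'.v p'.v (k - 1) worst
  in
  go q0 p0 steps 0.0
```

Analytically, the determinant is (1 - h^2 cos q) * 1 - h * (-h cos q) = 1 exactly, so any deviation along the trajectory is pure roundoff.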

Sorry for the long post, but at least they don't come too often :). By the way, since this code is pretty long and involved, I guess I should specify a license: I'm releasing it under the GPL.

(** Code to implement differentiation on arbitrary types. 

   diff.ml: Library to compute exact derivatives of arbitrary
   mathematical functions.

   Copyright (C) 2006 Will M. Farr 
   
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

*)

(** Input signature for [Make].*)
module type In = sig
  type t

(** The multiplicative identity *)
  val one : t

(** The additive identity *)
  val zero : t

(** Comparisons *)
  val compare : t -> t -> int

  val (+) : t -> t -> t
  val (-) : t -> t -> t
  val ( * ) : t -> t -> t
  val (/) : t -> t -> t
(** Arithmetic operations *)


  val sin : t -> t
  val cos : t -> t
  val tan : t -> t 
(** Trigonometric functions *)

  val asin : t -> t
  val acos : t -> t
  val atan : t -> t
(** Inverse trigonometric functions *)

  val sqrt : t -> t
  val log : t -> t
  val exp : t -> t
  val ( ** ) : t -> t -> t
(** Powers, exponents and logs. *)
end

(** Output signature for [Make]. *)
module type Out = sig 
(** Input from [In] *)
  type t 

(** Opaque tags for variables being differentiated. *)
  type tag

(** Output diff type.  A [diff] is either a constant [t] value [C(x)]
   or a tuple of [D(tag, x, dx)] where [x] and [dx] can also be
   [diff].  [D(tag, x, dx)] represents a small increment [dx] in the
   value labeled by [tag] about the value [x]. *)
  type diff = 
    | C of t 
    | D of tag * diff * diff

  val zero : diff
  val one : diff
(** Additive and multiplicative identities in [diff]. *)

(** [lift f f'] takes [f : t -> t] to [diff -> diff] using [f'] to
   compute the derivative.  For example, if we had already defined
   [(/) : diff -> diff -> diff], and the input signature didn't
   provide [log], we could define it using [let log = lift log ((/)
   one)]. *)
  val lift : (t -> t) -> (diff -> diff) -> (diff -> diff)

(** [lower f] takes a function defined on [diff] to the equivalent one
   defined on [t].  It is an error if [f (C x)] does not evaluate to
   [C y] for some [y].  The function is called [lower] because it is
   the inverse of the function [lift] which lifts [f : t -> t] to
   [diff -> diff].  *)
  val lower : (diff -> diff) -> (t -> t)

(** [lower_multi f] lowers [f] from [diff array -> diff] to [t array ->
   t].  It is an error if [f \[|(C x0); ...; (C xn)|\]] does not
   evaluate to [C y] for some [y].*)
  val lower_multi : (diff array -> diff) -> (t array -> t)

(** [lower_multi_multi f] lowers [f] from [diff array -> diff array]
   to [t array -> t array].  *)
  val lower_multi_multi : (diff array -> diff array) -> 
    (t array -> t array)

(** [d f] returns the function which computes the derivative of
   [f]. *)
  val d : (diff -> diff) -> (diff -> diff)

(** [partial i f] returns the function which computes the derivative
    of [f] with respect to its [i]th argument (counting from 0). *)
  val partial : int -> (diff array -> diff) -> (diff array -> diff)

(** [compare d1 d2] compares only the values stored in the
   derivative---that is either the [C x] or [D(_, x, _)].  [(<)],
   [(>)], ... are defined in terms of compare, which is defined in
   terms of [compare] from [In]. *)
  val compare : diff -> diff -> int

  val (<) : diff -> diff -> bool
  val (<=) : diff -> diff -> bool
  val (>) : diff -> diff -> bool
  val (>=) : diff -> diff -> bool
  val (=) : diff -> diff -> bool
  val (<>) : diff -> diff -> bool
(** Comparison functions *)

  val (+) : diff -> diff -> diff
  val (-) : diff -> diff -> diff
  val ( * ) : diff -> diff -> diff
  val (/) : diff -> diff -> diff
(** Algebra. *)

  val cos : diff -> diff
  val sin : diff -> diff
  val tan : diff -> diff
(** Trig *) 

  val acos : diff -> diff
  val asin : diff -> diff
  val atan : diff -> diff
(** Inverse trig *)

  val sqrt : diff -> diff
  val log : diff -> diff
  val exp : diff -> diff
  val ( ** ) : diff -> diff -> diff
(** Powers, exponents and logs *)

end

module Make(I : In) : Out with type t = I.t = struct
  type t = I.t

  type tag = int

(* Terms are either constant or of the form (x + dx), with the
   increment dx labeled by tag *)
  type diff = 
    | C of t
    | D of tag * diff * diff

(* Additive and multiplicative identities in derivatives. *)
  let zero = C(I.zero)
  let one = C(I.one)

(* Unique tags *)
  let new_tag = 
    let count = ref 0 in 
    fun () -> 
      count := !count + 1;
      !count

(* Have to define the arithmetic operators first because they are used
   in [lift] and friends. *)
  let rec (+) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(+) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1 + y, dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x + d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x + y, dx + dy)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagx, D(tagy, x + y, dy), dx)

  let rec (-) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(-) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1 - y, zero - dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x - d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x - y, dx - dy)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagx, D(tagy, x - y, zero - dy), dx)

  let rec ( * ) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.( * ) x y)
    | C(x), D(tag, y, dy) -> D(tag, d1 * y, d1 * dy)
    | D(tag, x, dx), C(y) -> D(tag, x * d2, dx * d2)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x * y, x*dy + dx*y)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagx, 
          D(tagy, x*y, x*dy),
          D(tagy, dx*y, dx*dy))

  let rec (/) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(/) x y)
    | D(tag, x, dx), C(y) -> D(tag, x/d2, dx/d2)
    | C(x), D(tag, y, dy) -> D(tag, d1/y, zero - d1*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x/y, dx/y - x*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagx,
          D(tagy, x/y, zero - x*dy/(y*y)),
          D(tagy, dx/y, zero - dx*dy/(y*y)))

  let lift f f' = 
    let rec lf = function 
      | C(x) -> C(f x)
      | D(tag, x, dx) -> D(tag, lf x, (f' x)*dx) in 
    lf

(* Now that we have the algebra of derivatives worked out, we can
   define lower and the functions that extract derivatives. *)
  let lower f = 
    fun x -> 
      match f (C x) with 
      | C(y) -> y
      | _ -> raise (Failure "lower expects numerical output")

  let rec extract_deriv tag = function 
    | C(_) -> zero
    | D(tagx, x, dx) when tagx = tag -> 
        extract_deriv tag x + drop_deriv tag dx 
    | D(tagx, x, dx) -> 
        D(tagx, extract_deriv tag x, extract_deriv tag dx) 
  and drop_deriv tag = function 
    | D(tagx, x, dx) when tagx = tag -> 
        (* x may itself contain tag, since this version allows duplicates. *)
        drop_deriv tag x
    | D(tagx, x, dx) -> 
        D(tagx, drop_deriv tag x, drop_deriv tag dx)
    | x -> x (* Only matches C(_) *)

  let d f = 
    fun x -> 
      let tag = new_tag () in 
      let res = f (D(tag, x, one)) in 
      extract_deriv tag res

  let rec compare d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> I.compare x y
    | C(x), D(_, y, _) -> compare d1 y
    | D(_, x, _), C(y) -> compare x d2
    | D(_, x, _), D(_, y, _) -> compare x y

  let rec cos = function 
    | C(x) -> C(I.cos x)
    | D(tag, x, dx) -> D(tag, cos x, zero - dx * (sin x)) 
  and sin = function 
    | C(x) -> C(I.sin x)
    | D(tag, x, dx) -> D(tag, sin x, dx * (cos x))

  let rec tan = function 
    | C(x) -> C(I.tan x)
    | D(tag, x, dx) -> 
        let tx = tan x in 
        D(tag, tx, (one + tx*tx)*dx)

  let rec sqrt = function 
    | C(x) -> C(I.sqrt x)
    | D(tag, x, dx) -> 
        let sx = sqrt x in 
        D(tag, sx, dx / (sx + sx))

  let rec acos = function 
    | C(x) -> C(I.acos x)
    | D(tag, x, dx) -> 
        D(tag, acos x, zero - dx / (sqrt (one - x * x)))

  let rec asin = function 
    | C(x) -> C(I.asin x)
    | D(tag, x, dx) -> 
        D(tag, asin x, dx / (sqrt (one - x * x)))

  let rec atan = function 
    | C(x) -> C(I.atan x)
    | D(tag, x, dx) -> 
        D(tag, atan x, dx / (one + x*x))

  let rec log = function 
    | C(x) -> C(I.log x)
    | D(tag, x, dx) -> D(tag, log x, dx/x)

  let rec exp = function 
    | C(x) -> C(I.exp x)
    | D(tag, x, dx) -> 
        let ex = exp x in 
        D(tag, ex, dx*ex)

  let rec ( ** ) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.( ** ) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1**y, d1**y * (log d1) * dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x**d2, x**(d2 - one)*d2*dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x**y, x**(y - one)*(dx*y + dy*x*(log x)))
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        (* The whole tagx increment is proportional to dx. *)
        D(tagx,
          D(tagy, x**y, x**y*(log x)*dy),
          D(tagy, dx*y*x**(y - one), dx*x**(y - one)*(one + y*(log x))*dy))

  let replace arr i x = 
    Array.mapi 
      (fun j y -> 
        if j <> i then y else x)
      arr

  let partial i f = 
    fun args -> 
      let x = args.(i) in 
      let one_d_f x = 
        f (replace args i x) in 
      (d one_d_f) x

  let c_ify arr = 
    Array.map (fun x -> (C x)) arr

  let de_c_ify arr = 
    Array.map 
      (function 
        | C(x) -> x
        | _ -> raise (Failure "cannot lower [|D(...); ...; D(...)|]"))
      arr

  let lower_multi f = 
    fun args -> 
      match (f (c_ify args)) with 
      | C(x) -> x
      | _ -> raise (Failure "cannot lower D(...)")

  let lower_multi_multi f = 
    fun args -> 
      de_c_ify (f (c_ify args))

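(* Define these all at once (and last) so as not to spoil the
   comparison operator namespace.  Because the definitions are joined
   with [and] (not [let rec]), the comparisons used in their bodies are
   still the Pervasives ones, applied to the int returned by compare. *)
  let (<) d1 d2 = compare d1 d2 < 0
  and (<=) d1 d2 = not (compare d1 d2 > 0)
  and (>) d1 d2 = compare d1 d2 > 0
  and (>=) d1 d2 = not (compare d1 d2 < 0)
  and (=) d1 d2 = compare d1 d2 = 0
  and (<>) d1 d2 = not (compare d1 d2 = 0)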
end
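The `lift` docstring suggests defining log from `(/)`. Here is a minimal standalone sketch of that idea on a hypothetical float-only `diff` type. It is restricted to a single tag, and mixed-tag cases raise instead of being handled as in the functor. `add`, `mul`, `recip`, `div`, `dlog`, `dexp`, and `to_float` are illustrative names.

```ocaml
(* Hypothetical float-only stand-in for the functor's [diff] type. *)
type diff = C of float | D of int * diff * diff

let one = C 1.0

let new_tag =
  let count = ref 0 in
  fun () -> incr count; !count

(* Only same-tag cases are handled in this sketch; mixed tags raise. *)
let rec add a b =
  match a, b with
  | C x, C y -> C (x +. y)
  | C _, D (t, y, dy) -> D (t, add a y, dy)
  | D (t, x, dx), C _ -> D (t, add x b, dx)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty -> D (tx, add x y, add dx dy)
  | D _, D _ -> invalid_arg "add: mixed tags not handled in this sketch"

let rec mul a b =
  match a, b with
  | C x, C y -> C (x *. y)
  | C _, D (t, y, dy) -> D (t, mul a y, mul a dy)
  | D (t, x, dx), C _ -> D (t, mul x b, mul dx b)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty ->
      D (tx, mul x y, add (mul x dy) (mul dx y))
  | D _, D _ -> invalid_arg "mul: mixed tags not handled in this sketch"

(* 1/(y + dy e) = 1/y - (dy/y^2) e *)
let rec recip = function
  | C y -> C (1.0 /. y)
  | D (t, y, dy) ->
      let r = recip y in
      D (t, r, mul (C (-1.0)) (mul dy (mul r r)))

let div a b = mul a (recip b)

(* Same shape as the functor's [lift f f']. *)
let lift f f' =
  let rec lf = function
    | C x -> C (f x)
    | D (t, x, dx) -> D (t, lf x, mul (f' x) dx)
  in
  lf

let rec extract tag = function
  | C _ -> C 0.0
  | D (t, _, dx) when t = tag -> dx
  | D (t, x, dx) -> D (t, extract tag x, extract tag dx)

let d f x =
  let tag = new_tag () in
  extract tag (f (D (tag, x, one)))

let to_float = function C v -> v | D _ -> failwith "still tagged"

(* The docstring's example: log lifted with derivative 1/x. *)
let dlog = lift log (fun x -> div one x)

(* exp is its own derivative, so it can be lifted in terms of itself. *)
let rec dexp z = lift exp dexp z
```

Because `lift` hands the derivative function a `diff` rather than a plain number, a function whose derivative is itself (exp) can be lifted in terms of itself.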