I've uncovered a *large* inefficiency in the OCaml code I posted for Automatic Differentiation. The problem is that I don't take any care to avoid duplicating tags in a nested `D(tag, x, dx)` expression. That's not a problem in a computation with only one tag, but if you're trying to take the Jacobian of a multi-dimensional function, you have derivatives with respect to each of the arguments floating through the computation. If you duplicate tags every time you construct a new `D(...)` expression, then the size of the nested expression grows *really* fast. The key to avoiding duplication is to think of a given `D(tag, x, dx)` expression as a tree which has the heap property: all tags in the expressions `x` and `dx` are larger than `tag`. Then you just have to rewrite the binary operations (these are the only ones which can combine `D(...)` expressions with different tags) to preserve the heap property. Updated code is posted below (sorry for another long post, but this change makes a big difference)! I've added the `jacobian` and `lower_jacobian` functions because I needed them to test for symplecticity: a map is symplectic when its Jacobian $J$ satisfies $J^T S J = S$, where $S = \begin{pmatrix} 0 & I \\ -I & 0 \end{pmatrix}$ is the standard symplectic form, $I$ is the identity matrix and $0$ is the zero matrix. As before, this code is released under the GPL.

(** Code to implement differentiation on arbitrary types.
diff.ml: Library to compute exact derivatives of arbitrary
mathematical functions.
Copyright (C) 2006 Will M. Farr
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*)
(** Input signature for [Make]: the scalar type over which derivatives
    are computed, together with its arithmetic and elementary
    functions. *)
module type In = sig
  (** The scalar type. *)
  type t

  (** The multiplicative identity. *)
  val one : t

  (** The additive identity. *)
  val zero : t

  (** [compare a b] is negative, zero or positive when [a] is less than,
      equal to or greater than [b]. *)
  val compare : t -> t -> int

  (** {2 Arithmetic operations} *)

  val ( + ) : t -> t -> t
  val ( - ) : t -> t -> t
  val ( * ) : t -> t -> t
  val ( / ) : t -> t -> t

  (** {2 Trigonometric functions} *)

  val sin : t -> t
  val cos : t -> t
  val tan : t -> t

  (** {2 Inverse trigonometric functions} *)

  val asin : t -> t
  val acos : t -> t
  val atan : t -> t

  (** {2 Powers, exponents and logs} *)

  val sqrt : t -> t
  val log : t -> t
  val exp : t -> t
  val ( ** ) : t -> t -> t

  (** {2 Printing} *)

  (** [print_t out x] prints [x] on [out]. *)
  val print_t : out_channel -> t -> unit
end
(** Output signature for [Make]. *)
module type Out = sig
  (** The scalar type, shared with the [In] argument of [Make]. *)
  type t

  (** Opaque tags labelling the variables being differentiated. *)
  type tag

  (** A [diff] is either a constant [C x] or a tuple [D(tag, x, dx)],
      where [x] and [dx] are themselves [diff]s.  [D(tag, x, dx)]
      represents a small increment [dx] in the variable labelled by
      [tag] about the value [x]. *)
  type diff =
    | C of t
    | D of tag * diff * diff

  (** {2 Identities} *)

  (** The additive identity in [diff]. *)
  val zero : diff

  (** The multiplicative identity in [diff]. *)
  val one : diff

  (** {2 Lifting and lowering} *)

  (** [lift f f'] takes [f : t -> t] to [diff -> diff], using [f'] to
      compute the derivative.  For example, if we had already defined
      [(/) : diff -> diff -> diff] and the input signature did not
      provide [log], we could define it using
      [let log = lift log ((/) one)]. *)
  val lift : (t -> t) -> (diff -> diff) -> (diff -> diff)

  (** [lower f] takes a function defined on [diff] to the equivalent one
      defined on [t].  The function is called [lower] because it is the
      inverse of [lift], which lifts [f : t -> t] to [diff -> diff].
      @raise Failure if [f (C x)] does not evaluate to [C y] for some
      [y]. *)
  val lower : (diff -> diff) -> (t -> t)

  (** [lower_multi f] lowers [f] from [diff array -> diff] to
      [t array -> t].
      @raise Failure if [f \[|C x0; ...; C xn|\]] does not evaluate to
      [C y] for some [y]. *)
  val lower_multi : (diff array -> diff) -> (t array -> t)

  (** [lower_multi_multi f] lowers [f] from [diff array -> diff array]
      to [t array -> t array].
      @raise Failure if some element of the output is not a constant
      [C y]. *)
  val lower_multi_multi : (diff array -> diff array) ->
    (t array -> t array)

  (** [lower_jacobian j] lowers the jacobian [j] from
      [diff array -> diff array array] to [t array -> t array array].
      @raise Failure if some entry of the output is not a constant
      [C y]. *)
  val lower_jacobian : (diff array -> diff array array) ->
    (t array -> t array array)

  (** {2 Derivatives} *)

  (** [d f] returns the function which computes the derivative of
      [f]. *)
  val d : (diff -> diff) -> (diff -> diff)

  (** [partial i f] returns the function which computes the derivative
      of [f] with respect to its [i]th argument (counting from 0). *)
  val partial : int -> (diff array -> diff) -> (diff array -> diff)

  (** [jacobian f] returns the jacobian matrix with elements
      [m.(i).(j) = d f.(i)/d x.(j)]. *)
  val jacobian : (diff array -> diff array) ->
    (diff array -> diff array array)

  (** {2 Comparisons} *)

  (** [compare d1 d2] compares only the values stored in the
      derivatives, that is the [x] in either [C x] or [D(_, x, _)].
      [(<)], [(>)], ... are defined in terms of this [compare], which
      is itself defined in terms of [compare] from [In]. *)
  val compare : diff -> diff -> int

  val ( < ) : diff -> diff -> bool
  val ( <= ) : diff -> diff -> bool
  val ( > ) : diff -> diff -> bool
  val ( >= ) : diff -> diff -> bool
  val ( = ) : diff -> diff -> bool
  val ( <> ) : diff -> diff -> bool

  (** {2 Algebra} *)

  val ( + ) : diff -> diff -> diff
  val ( - ) : diff -> diff -> diff
  val ( * ) : diff -> diff -> diff
  val ( / ) : diff -> diff -> diff

  (** {2 Trigonometric functions} *)

  val cos : diff -> diff
  val sin : diff -> diff
  val tan : diff -> diff

  (** {2 Inverse trigonometric functions} *)

  val acos : diff -> diff
  val asin : diff -> diff
  val atan : diff -> diff

  (** {2 Powers, exponents and logs} *)

  val sqrt : diff -> diff
  val log : diff -> diff
  val exp : diff -> diff
  val ( ** ) : diff -> diff -> diff

  (** {2 Printing} *)

  (** [print_diff out d] prints the structure of [d] on [out]. *)
  val print_diff : out_channel -> diff -> unit
end
module Make(I : In) : Out with type t = I.t = struct
  type t = I.t
  type tag = int

  (* A [diff] is either a constant [C x] or [D(tag, x, dx)], which stands
     for [x + dx*e], where [e] is the infinitesimal labelled by [tag]
     (with [e*e = 0]; see the equal-tag case of [( * )]).

     We maintain the invariant that every tag occurring in [x] and [dx]
     is strictly larger than [tag].  Thinking of [D(tag, x, dx)] as the
     tree

           tag
          /   \
         x     dx

     this is the heap property tag < tags(x) and tag < tags(dx).  The
     arithmetic operators preserve it, [d] and [jacobian] establish it
     when seeding fresh tags, and the binary operators, [extract_deriv]
     and [drop_deriv] all rely on it for correctness. *)
  type diff =
    | C of t
    | D of tag * diff * diff

  (* Print the constructor structure of a [diff] on [out]. *)
  let rec print_diff out = function
    | C x ->
      Printf.fprintf out "C(";
      I.print_t out x;
      Printf.fprintf out ")"
    | D (tag, x, dx) ->
      Printf.fprintf out "D(%d," tag;
      print_diff out x;
      Printf.fprintf out ", ";
      print_diff out dx;
      Printf.fprintf out ")"

  (* Additive and multiplicative identities in derivatives. *)
  let zero = C I.zero
  let one = C I.one

  (* Fresh tags.  Each call returns an integer strictly larger than every
     earlier one, so a new tag is larger than any tag already present in
     an existing [diff]. *)
  let new_tag =
    let count = ref 0 in
    fun () ->
      incr count;
      !count

  (* The arithmetic operators come first because they are used in [lift]
     and friends.  To maintain the heap property when combining
     [D(tagx, x, dx)] with [D(tagy, y, dy)] for different tags, the
     result is built on the smaller tag.  Any sub-derivative may appear
     below the smaller tag, but below the larger tag we may only use the
     components of the larger-tagged operand directly: the tags inside
     the smaller-tagged operand bear no known relation to the larger
     tag.  This is why the mixed-tag cases below look asymmetric.

     Note: the comparisons on tags ([=], [<]) in this module are the
     integer ones; the [diff] comparisons are only defined at the end. *)
  let rec ( + ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( + ) x y)
    | C _, D (tag, y, dy) -> D (tag, d1 + y, dy)
    | D (tag, x, dx), C _ -> D (tag, x + d2, dx)
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
      D (tagx, x + y, dx + dy)
    | D (tagx, x, dx), D (tagy, _, _) when tagx < tagy ->
      D (tagx, x + d2, dx)
    | D _, D (tagy, y, dy) -> D (tagy, d1 + y, dy)

  let rec ( - ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( - ) x y)
    | C _, D (tag, y, dy) -> D (tag, d1 - y, zero - dy)
    | D (tag, x, dx), C _ -> D (tag, x - d2, dx)
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
      D (tagx, x - y, dx - dy)
    | D (tagx, x, dx), D (tagy, _, _) when tagx < tagy ->
      D (tagx, x - d2, dx)
    | D _, D (tagy, y, dy) -> D (tagy, d1 - y, zero - dy)

  let rec ( * ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( * ) x y)
    | C _, D (tag, y, dy) -> D (tag, d1 * y, d1 * dy)
    | D (tag, x, dx), C _ -> D (tag, x * d2, dx * d2)
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
      (* e*e = 0, so the dx*dy term vanishes. *)
      D (tagx, x * y, (x * dy) + (dx * y))
    | D (tagx, x, dx), D (tagy, _, _) when tagx < tagy ->
      D (tagx, x * d2, dx * d2)
    | D _, D (tagy, y, dy) -> D (tagy, d1 * y, d1 * dy)

  let rec ( / ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( / ) x y)
    | D (tag, x, dx), C _ -> D (tag, x / d2, dx / d2)
    | C _, D (tag, y, dy) -> D (tag, d1 / y, zero - (d1 * dy / (y * y)))
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
      D (tagx, x / y, (dx / y) - (x * dy / (y * y)))
    | D (tagx, x, dx), D (tagy, y, dy) when tagx < tagy ->
      (* Multiply by 1/d2 = D(tagy, 1/y, -dy/y^2), which only contains
         tags >= tagy and so may sit below [tagx]. *)
      let inv = D (tagy, one / y, zero - (dy / (y * y))) in
      D (tagx, x * inv, dx * inv)
    | D _, D (tagy, y, dy) -> D (tagy, d1 / y, zero - (d1 * dy / (y * y)))

  (* [lift f f'] extends [f] to [diff] by the chain rule:
     f (x + dx*e) = f x + (f' x)*dx*e. *)
  let lift f f' =
    let rec lf = function
      | C x -> C (f x)
      | D (tag, x, dx) -> D (tag, lf x, f' x * dx)
    in
    lf

  (* Now that the algebra of derivatives is worked out, we can define
     [lower], the derivative extractors and [d]. *)

  (* Run [f] on a constant and return the constant result; raise
     [Failure] if the result still carries a derivative. *)
  let lower f x =
    match f (C x) with
    | C y -> y
    | D _ -> failwith "lower expects numerical output"

  (* [extract_deriv tag d] is the coefficient of the infinitesimal
     labelled [tag] in [d].  By the heap property, [tag] cannot occur
     below a node whose tag is larger, so such nodes contribute zero. *)
  let rec extract_deriv tag = function
    | C _ -> zero
    | D (tagx, _, dx) when tagx = tag -> dx
    | D (tagx, _, _) when tagx > tag -> zero
    | D (tagx, x, dx) -> D (tagx, extract_deriv tag x, extract_deriv tag dx)

  (* [drop_deriv tag d] sets the infinitesimal labelled [tag] to zero in
     [d].  Nodes with a larger tag cannot contain [tag] and are returned
     unchanged. *)
  let rec drop_deriv tag = function
    | D (tagx, x, _) when tagx = tag -> x
    | D (tagx, _, _) as d when tagx > tag -> d
    | D (tagx, x, dx) -> D (tagx, drop_deriv tag x, drop_deriv tag dx)
    | C _ as c -> c

  (* Derivative of [f] at [x].  The fresh tag is larger than every tag
     already in [x], so [D(tag, x, one)] would violate the heap property
     whenever [x] itself carries tags (nested derivatives); adding the
     perturbation with [( + )] keeps the invariant. *)
  let d f x =
    let tag = new_tag () in
    extract_deriv tag (f (x + D (tag, zero, one)))

  (* Compare the underlying values only, ignoring derivative parts. *)
  let rec compare d1 d2 =
    match d1, d2 with
    | C x, C y -> I.compare x y
    | C _, D (_, y, _) -> compare d1 y
    | D (_, x, _), C _ -> compare x d2
    | D (_, x, _), D (_, y, _) -> compare x y

  let rec cos = function
    | C x -> C (I.cos x)
    | D (tag, x, dx) -> D (tag, cos x, zero - (dx * sin x))
  and sin = function
    | C x -> C (I.sin x)
    | D (tag, x, dx) -> D (tag, sin x, dx * cos x)

  let rec tan = function
    | C x -> C (I.tan x)
    | D (tag, x, dx) ->
      let tx = tan x in
      D (tag, tx, (one + (tx * tx)) * dx)

  let rec sqrt = function
    | C x -> C (I.sqrt x)
    | D (tag, x, dx) ->
      let sx = sqrt x in
      D (tag, sx, dx / (sx + sx))

  let rec acos = function
    | C x -> C (I.acos x)
    | D (tag, x, dx) ->
      D (tag, acos x, zero - (dx / sqrt (one - (x * x))))

  let rec asin = function
    | C x -> C (I.asin x)
    | D (tag, x, dx) -> D (tag, asin x, dx / sqrt (one - (x * x)))

  let rec atan = function
    | C x -> C (I.atan x)
    | D (tag, x, dx) -> D (tag, atan x, dx / (one + (x * x)))

  let rec log = function
    | C x -> C (I.log x)
    | D (tag, x, dx) -> D (tag, log x, dx / x)

  let rec exp = function
    | C x -> C (I.exp x)
    | D (tag, x, dx) ->
      let ex = exp x in
      D (tag, ex, dx * ex)

  let rec ( ** ) d1 d2 =
    match d1, d2 with
    | C x, C y -> C (I.( ** ) x y)
    | C _, D (tag, y, dy) ->
      (* d(c^y) = c^y log c dy *)
      D (tag, d1 ** y, (d1 ** y) * log d1 * dy)
    | D (tag, x, dx), C _ ->
      (* d(x^c) = x^(c-1) c dx *)
      D (tag, x ** d2, (x ** (d2 - one)) * d2 * dx)
    | D (tagx, x, dx), D (tagy, y, dy) when tagx = tagy ->
      D (tagx, x ** y, (x ** (y - one)) * ((dx * y) + (x * log x * dy)))
    | D (tagx, x, dx), D (tagy, y, dy) when tagx < tagy ->
      (* Expand d2 = y + dy*e_y to first order; [ey] is dy*e_y. *)
      let ey = D (tagy, zero, dy) in
      D (tagx,
         (x ** y) * (one + (log x * ey)),
         (dx * (x ** (y - one))) * (y + ((one + (y * log x)) * ey)))
    | D (tagx, x, dx), D (tagy, y, dy) ->
      (* Here tagy < tagx: expand d1 = x + dx*e_x to first order; [ex]
         is dx*e_x. *)
      let ex = D (tagx, zero, dx) in
      D (tagy,
         (x ** y) + ((y * (x ** (y - one))) * ex),
         dy * (((x ** y) * log x)
               + (((x ** (y - one)) * (one + (y * log x))) * ex)))

  (* Copy of [arr] with element [i] replaced by [x]. *)
  let replace arr i x = Array.mapi (fun j y -> if j = i then x else y) arr

  let partial i f args =
    let one_d_f x = f (replace args i x) in
    d one_d_f args.(i)

  (* One fresh tag per argument; entry (i, j) keeps only the tag of
     argument j (extracting its derivative) and drops all the others.
     Rows follow the outputs of [f], so non-square maps are supported. *)
  let jacobian f args =
    let tags = Array.map (fun _ -> new_tag ()) args in
    (* Seed with [( + )] rather than [D(tag, arg, one)] so that arguments
       which already carry (smaller) tags keep the heap property. *)
    let dargs = Array.mapi (fun i arg -> arg + D (tags.(i), zero, one)) args in
    let result = f dargs in
    Array.map
      (fun fi ->
         Array.map
           (fun tj ->
              Array.fold_left
                (fun res tag ->
                   if tag = tj then extract_deriv tag res
                   else drop_deriv tag res)
                fi tags)
           tags)
      result

  let c_ify arr = Array.map (fun x -> C x) arr

  (* Strip constants; raise [Failure] on any remaining derivative. *)
  let de_c_ify arr =
    Array.map
      (function
        | C x -> x
        | D _ -> failwith "cannot lower [|D(...); ...; D(...)|]")
      arr

  let lower_multi f args =
    match f (c_ify args) with
    | C x -> x
    | D _ -> failwith "cannot lower D(...)"

  let lower_multi_multi f args = de_c_ify (f (c_ify args))

  let lower_jacobian j args = Array.map de_c_ify (j (c_ify args))

  (* Defined all at once (with [and]) and last, so that the bodies still
     use the integer comparisons and the rest of the module is not
     affected by the shadowing. *)
  let ( < ) d1 d2 = compare d1 d2 < 0
  and ( <= ) d1 d2 = compare d1 d2 <= 0
  and ( > ) d1 d2 = compare d1 d2 > 0
  and ( >= ) d1 d2 = compare d1 d2 >= 0
  and ( = ) d1 d2 = compare d1 d2 = 0
  and ( <> ) d1 d2 = compare d1 d2 <> 0
end
(* [Make] instantiated at [float]. *)
module DFloat = Make(struct
  type t = float

  (* Bind the required float operations explicitly instead of
     [include Pervasives]: [Pervasives] was deprecated in OCaml 4.08 and
     removed in 5.0.  The arithmetic operators are bound last so the
     definitions above still refer to the standard functions. *)
  let one = 1.0
  let zero = 0.0
  let compare (x : float) y = compare x y
  let sin = sin
  let cos = cos
  let tan = tan
  let asin = asin
  let acos = acos
  let atan = atan
  let sqrt = sqrt
  let log = log
  let exp = exp
  let ( ** ) = ( ** )
  let print_t out = Printf.fprintf out "%g"
  let ( + ) = ( +. )
  let ( - ) = ( -. )
  let ( * ) = ( *. )
  let ( / ) = ( /. )
end)