I've uncovered a *large* inefficiency in the OCaml code I posted for Automatic Differentiation. The problem is that I don't take any care to avoid duplicating tags in a nested D(tag, x, dx) expression. That's not a problem in a computation with only one tag, but if you're trying to take the Jacobian of some multi-dimensional function, you have derivatives with respect to each of the arguments floating through the computation. If you duplicate tags every time you construct a new D(...) expression, then the size of the nested expression grows *really* fast. The key to avoiding duplication is to think of a given D(tag, x, dx) expression as a tree that has the heap property: all tags in the expressions x and dx are larger than tag. Then you just have to rewrite the binary operations (these are the only ones that can combine D(...) expressions with different tags) to preserve the heap property. Updated code is posted below (sorry for another long post, but this change makes a big difference)! I've added the jacobian and lower_jacobian functions because I needed them to test for symplecticity: a map is symplectic when its Jacobian $J$ satisfies $J^T S J = S$, where $S = \begin{pmatrix} 0 & I \\ -I & 0 \end{pmatrix}$ is the standard symplectic form, $I$ is the identity matrix, and $0$ is the zero matrix. As before, this code is released under the GPL.
(** Code to implement differentiation on arbitrary types.
diff.ml: Library to compute exact derivatives of arbitrary
mathematical functions.
Copyright (C) 2006 Will M. Farr
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*)
(** Input signature for [Make].*)
module type In = sig
  (** The scalar type over which derivatives are computed. *)
  type t

  (** The multiplicative identity. *)
  val one : t

  (** The additive identity. *)
  val zero : t

  (** Ordering on [t].  [Make] uses it to order [diff] values by the
      scalar value they carry. *)
  val compare : t -> t -> int

  (** {2 Arithmetic operations} *)

  val (+) : t -> t -> t
  val (-) : t -> t -> t
  val ( * ) : t -> t -> t
  val (/) : t -> t -> t

  (** {2 Trigonometric functions} *)

  val sin : t -> t
  val cos : t -> t
  val tan : t -> t

  (** {2 Inverse trigonometric functions} *)

  val asin : t -> t
  val acos : t -> t
  val atan : t -> t

  (** {2 Powers, exponents and logs} *)

  val sqrt : t -> t
  val log : t -> t
  val exp : t -> t
  val ( ** ) : t -> t -> t

  (** {2 Output} *)

  (** [print_t out x] writes [x] to [out]; [print_diff] in [Make] uses it
      to print the scalar leaves of a [diff]. *)
  val print_t : out_channel -> t -> unit
end
(** Output signature for [Make]. *)
module type Out = sig
  (** The scalar type; [Make] equates it with the input [t]. *)
  type t

  (** Opaque tags for variables being differentiated. *)
  type tag

  (** Output diff type.  A [diff] is either a constant [t] value [C x]
      or a triple [D(tag, x, dx)] where [x] and [dx] are themselves
      [diff]s.  [D(tag, x, dx)] represents a small increment [dx] in the
      value labeled by [tag] about the value [x]. *)
  type diff =
    | C of t
    | D of tag * diff * diff

  (** The additive identity in [diff]. *)
  val zero : diff

  (** The multiplicative identity in [diff]. *)
  val one : diff

  (** [lift f f'] takes [f : t -> t] to [diff -> diff] using [f'] to
      compute the derivative.  For example, if we had already defined
      [(/) : diff -> diff -> diff], and the input signature didn't
      provide [log], we could define it using
      [let log = lift log ((/) one)]. *)
  val lift : (t -> t) -> (diff -> diff) -> (diff -> diff)

  (** [lower f] takes a function defined on [diff] to the equivalent one
      defined on [t].  The function is called [lower] because it is the
      inverse of [lift], which lifts [f : t -> t] to [diff -> diff].
      @raise Failure if [f (C x)] does not evaluate to [C y] for some
      [y]. *)
  val lower : (diff -> diff) -> (t -> t)

  (** [lower_multi f] lowers [f] from [diff array -> diff] to
      [t array -> t]: every argument is wrapped in [C] before calling [f].
      @raise Failure if the result of [f] is not of the form [C y]. *)
  val lower_multi : (diff array -> diff) -> (t array -> t)

  (** [lower_multi_multi f] lowers [f] from [diff array -> diff array]
      to [t array -> t array].
      @raise Failure if any component of the result is not a constant
      [C y]. *)
  val lower_multi_multi : (diff array -> diff array) ->
    (t array -> t array)

  (** [lower_jacobian j] lowers the jacobian [j] from
      [diff array -> diff array array] to [t array -> t array array].
      @raise Failure if any entry of the result is not a constant
      [C y]. *)
  val lower_jacobian : (diff array -> diff array array) ->
    (t array -> t array array)

  (** [d f] returns the function which computes the derivative of
      [f]. *)
  val d : (diff -> diff) -> (diff -> diff)

  (** [partial i f] returns the function which computes the derivative
      of [f] with respect to its [i]th argument (counting from 0). *)
  val partial : int -> (diff array -> diff) -> (diff array -> diff)

  (** [jacobian f] returns the function computing the jacobian matrix
      [m] of [f] at a point, with elements
      [m.(i).(j) = d f.(i) / d x.(j)]. *)
  val jacobian : (diff array -> diff array) ->
    (diff array -> diff array array)

  (** {2 Comparison functions} *)

  (** [compare d1 d2] compares only the values stored in the
      derivative, that is the [x] of either [C x] or [D(_, x, _)].
      [(<)], [(>)], ... are defined in terms of this [compare], which is
      itself defined in terms of [compare] from the input module. *)
  val compare : diff -> diff -> int
  val (<) : diff -> diff -> bool
  val (<=) : diff -> diff -> bool
  val (>) : diff -> diff -> bool
  val (>=) : diff -> diff -> bool
  val (=) : diff -> diff -> bool
  val (<>) : diff -> diff -> bool

  (** {2 Algebra} *)

  val (+) : diff -> diff -> diff
  val (-) : diff -> diff -> diff
  val ( * ) : diff -> diff -> diff
  val (/) : diff -> diff -> diff

  (** {2 Trigonometric functions} *)

  val cos : diff -> diff
  val sin : diff -> diff
  val tan : diff -> diff

  (** {2 Inverse trigonometric functions} *)

  val acos : diff -> diff
  val asin : diff -> diff
  val atan : diff -> diff

  (** {2 Powers, exponents and logs} *)

  val sqrt : diff -> diff
  val log : diff -> diff
  val exp : diff -> diff
  val ( ** ) : diff -> diff -> diff

  (** {2 Output} *)

  (** [print_diff out d] writes [d] to [out] in constructor-like syntax,
      e.g. [D(1,C(2), C(1))]. *)
  val print_diff : out_channel -> diff -> unit
end
module Make(I : In) : Out with type t = I.t = struct
  type t = I.t
  type tag = int

  (* Terms are either constant or of the form x + dx*e, where e is an
     infinitesimal labelled by tag (e squared is zero).  We are careful
     to maintain the invariant that all the tags in x and dx are
     strictly larger than the tag of the enclosing D.  That is, you can
     think of a D(tag, x, dx) as a tree

           tag
          /   \
         x     dx

     which satisfies the heap property tag < tags(x) and tag < tags(dx).
     Consequently no tag is repeated along any root-to-leaf path, which
     keeps expressions with several live tags (as in [jacobian]) from
     growing by duplication. *)
  type diff =
    | C of t
    | D of tag * diff * diff

  let rec print_diff out = function
    | C(x) ->
        Printf.fprintf out "C(";
        I.print_t out x;
        Printf.fprintf out ")"
    | D(tag, x, dx) ->
        Printf.fprintf out "D(%d," tag;
        print_diff out x;
        Printf.fprintf out ", ";
        print_diff out dx;
        Printf.fprintf out ")"

  (* Additive and multiplicative identities in derivatives. *)
  let zero = C(I.zero)
  let one = C(I.one)

  (* Unique tags.  Each call returns a value larger than every tag handed
     out before, so a fresh tag is always the largest one in play. *)
  let new_tag =
    let count = ref 0 in
    fun () ->
      incr count;
      !count

  (* The arithmetic operators come first because everything below uses
     them.  To maintain the heap property when combining two D(_,_,_)
     values with different tags, we build the result under the smaller
     tag.  Both sub-terms of the operand carrying the smaller tag may be
     used directly inside D(smaller, _, _), and so may the whole operand
     carrying the larger tag, since all of its tags exceed the smaller
     one.  The sub-terms of the smaller operand, however, cannot be put
     directly under the larger tag: nothing relates their tags to it.
     Note that the tag comparisons below use the integer operators, as
     the diff comparisons are only defined at the end of the module. *)
  let rec (+) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.(+) x y)
    | C(_), D(tag, y, dy) -> D(tag, d1 + y, dy)
    | D(tag, x, dx), C(_) -> D(tag, x + d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x + y, dx + dy)
    | D(tagx, x, dx), D(tagy, _, _) when tagx < tagy ->
        D(tagx, x + d2, dx)
    | D(_, _, _), D(tagy, y, dy) ->
        D(tagy, d1 + y, dy)

  let rec (-) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.(-) x y)
    | C(_), D(tag, y, dy) -> D(tag, d1 - y, zero - dy)
    | D(tag, x, dx), C(_) -> D(tag, x - d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x - y, dx - dy)
    | D(tagx, x, dx), D(tagy, _, _) when tagx < tagy ->
        D(tagx, x - d2, dx)
    | D(_, _, _), D(tagy, y, dy) ->
        D(tagy, d1 - y, zero - dy)

  let rec ( * ) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.( * ) x y)
    | C(_), D(tag, y, dy) -> D(tag, d1 * y, d1 * dy)
    | D(tag, x, dx), C(_) -> D(tag, x * d2, dx * d2)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x * y, x * dy + dx * y)
    | D(tagx, x, dx), D(tagy, _, _) when tagx < tagy ->
        D(tagx, x * d2, dx * d2)
    | D(_, _, _), D(tagy, y, dy) ->
        D(tagy, d1 * y, d1 * dy)

  let rec (/) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.(/) x y)
    | D(tag, x, dx), C(_) -> D(tag, x / d2, dx / d2)
    | C(_), D(tag, y, dy) -> D(tag, d1 / y, zero - d1 * dy / (y * y))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x / y, dx / y - x * dy / (y * y))
    | D(tagx, x, dx), D(tagy, _, _) when tagx < tagy ->
        (* d2 has no component along tagx, so both parts of d1 are simply
           scaled by 1/d2 = D(tagy, 1/y, -dy/y^2), computed once. *)
        let inv = one / d2 in
        D(tagx, x * inv, dx * inv)
    | D(_, _, _), D(tagy, y, dy) ->
        D(tagy, d1 / y, zero - d1 * dy / (y * y))

  (* [lift f f'] extends [f] to diffs, using [f'] for the derivative
     (chain rule at every tag level). *)
  let lift f f' =
    let rec lf = function
      | C(x) -> C(f x)
      | D(tag, x, dx) -> D(tag, lf x, f' x * dx) in
    lf

  (* [lower f] evaluates [f] on a constant and unwraps the constant
     result; fails if [f] produced a value carrying a tag. *)
  let lower f x =
    match f (C x) with
    | C(y) -> y
    | D(_, _, _) -> failwith "lower expects numerical output"

  (* [extract_deriv tag d] is the coefficient of e_tag in [d];
     [drop_deriv tag d] is [d] with its e_tag component removed.  By the
     heap property, a subtree whose root tag exceeds [tag] cannot contain
     [tag], so both functions stop descending there. *)
  let rec extract_deriv tag = function
    | C(_) -> zero
    | D(tagx, _, dx) when tagx = tag -> dx
    | D(tagx, _, _) when tagx > tag -> zero
    | D(tagx, x, dx) ->
        D(tagx, extract_deriv tag x, extract_deriv tag dx)
  and drop_deriv tag = function
    | C(_) as c -> c
    | D(tagx, x, _) when tagx = tag -> x
    | D(tagx, _, _) as whole when tagx > tag -> whole
    | D(tagx, x, dx) ->
        D(tagx, drop_deriv tag x, drop_deriv tag dx)

  (* [perturb tag x] builds x + e_tag while preserving the heap property.
     Building D(tag, x, one) directly is only valid when every tag in x
     is larger than [tag].  A fresh tag is larger than every existing
     one, so when x already carries tags (nested derivatives, e.g. the
     outer variable passed to an inner [d]) the new perturbation must be
     pushed below them; otherwise later arithmetic treats the outer tag
     as two unrelated perturbations and drops derivative terms. *)
  let rec perturb tag = function
    | D(tx, x, dx) when tx < tag -> D(tx, perturb tag x, dx)
    | D(tx, x, dx) when tx = tag -> D(tx, x, dx + one)
    | x -> D(tag, x, one)

  (* [d f x] seeds [x] with a fresh perturbation, runs [f], and reads
     off the coefficient of that perturbation. *)
  let d f x =
    let tag = new_tag () in
    extract_deriv tag (f (perturb tag x))

  (* Order diffs by their value part only, ignoring all derivatives. *)
  let rec compare d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> I.compare x y
    | C(_), D(_, y, _) -> compare d1 y
    | D(_, x, _), C(_) -> compare x d2
    | D(_, x, _), D(_, y, _) -> compare x y

  let rec cos = function
    | C(x) -> C(I.cos x)
    | D(tag, x, dx) -> D(tag, cos x, zero - dx * sin x)
  and sin = function
    | C(x) -> C(I.sin x)
    | D(tag, x, dx) -> D(tag, sin x, dx * cos x)

  let rec tan = function
    | C(x) -> C(I.tan x)
    | D(tag, x, dx) ->
        let tx = tan x in
        D(tag, tx, (one + tx * tx) * dx)

  let rec sqrt = function
    | C(x) -> C(I.sqrt x)
    | D(tag, x, dx) ->
        let sx = sqrt x in
        D(tag, sx, dx / (sx + sx))

  let rec acos = function
    | C(x) -> C(I.acos x)
    | D(tag, x, dx) ->
        D(tag, acos x, zero - dx / sqrt (one - x * x))

  let rec asin = function
    | C(x) -> C(I.asin x)
    | D(tag, x, dx) ->
        D(tag, asin x, dx / sqrt (one - x * x))

  let rec atan = function
    | C(x) -> C(I.atan x)
    | D(tag, x, dx) ->
        D(tag, atan x, dx / (one + x * x))

  let rec log = function
    | C(x) -> C(I.log x)
    | D(tag, x, dx) -> D(tag, log x, dx / x)

  let rec exp = function
    | C(x) -> C(I.exp x)
    | D(tag, x, dx) ->
        let ex = exp x in
        D(tag, ex, dx * ex)

  (* For mixed tags, expand x^y = exp(y log x) to first order in each
     infinitesimal; the cross term in e_x e_y is kept via an explicit
     D(other_tag, zero, d_other) factor, which the arithmetic above
     re-orders into heap form. *)
  let rec ( ** ) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.( ** ) x y)
    | C(_), D(tag, y, dy) ->
        let p = d1 ** y in
        D(tag, p, p * log d1 * dy)
    | D(tag, x, dx), C(_) ->
        D(tag, x ** d2, x ** (d2 - one) * d2 * dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x ** y, x ** (y - one) * (dx * y + x * log x * dy))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy ->
        let ey = D(tagy, zero, dy) and lx = log x in
        D(tagx, x ** y * (one + lx * ey),
              dx * x ** (y - one) * (y + (one + y * lx) * ey))
    | D(tagx, x, dx), D(tagy, y, dy) ->
        let ex = D(tagx, zero, dx) and lx = log x in
        let xy = x ** y and xy1 = x ** (y - one) in
        D(tagy, xy + y * xy1 * ex,
              dy * (xy * lx + xy1 * (one + y * lx) * ex))

  (* Copy of [arr] with index [i] replaced by [x]. *)
  let replace arr i x =
    Array.mapi (fun j y -> if j <> i then y else x) arr

  let partial i f args =
    let one_d_f x = f (replace args i x) in
    d one_d_f args.(i)

  (* Seed every argument with its own fresh tag (tags increase with the
     argument index), evaluate [f] once, then for each output and each
     argument keep only that argument's perturbation.  Rows follow the
     length of the output, so non-square jacobians are supported. *)
  let jacobian f args =
    let tags = Array.map (fun _ -> new_tag ()) args in
    let dargs = Array.mapi (fun i arg -> perturb tags.(i) arg) args in
    let coefficient fi tj =
      Array.fold_left
        (fun res tag ->
           if tag = tj then extract_deriv tag res else drop_deriv tag res)
        fi tags in
    Array.map (fun fi -> Array.map (coefficient fi) tags) (f dargs)

  let c_ify arr = Array.map (fun x -> C x) arr

  let de_c_ify arr =
    Array.map
      (function
        | C(x) -> x
        | D(_, _, _) -> failwith "cannot lower [|D(...); ...; D(...)|]")
      arr

  let lower_multi f args =
    match f (c_ify args) with
    | C(x) -> x
    | D(_, _, _) -> failwith "cannot lower D(...)"

  let lower_multi_multi f args = de_c_ify (f (c_ify args))

  let lower_jacobian j args = Array.map de_c_ify (j (c_ify args))

  (* Define these all at once (and last) so as not to spoil the
     comparison operator namespace: every tag comparison above uses the
     integer operators. *)
  let (<) d1 d2 = compare d1 d2 < 0
  and (<=) d1 d2 = not (compare d1 d2 > 0)
  and (>) d1 d2 = compare d1 d2 > 0
  and (>=) d1 d2 = not (compare d1 d2 < 0)
  and (=) d1 d2 = compare d1 d2 = 0
  and (<>) d1 d2 = not (compare d1 d2 = 0)
end
(* [diff] over IEEE floats.  The scalar operations are bound explicitly
   instead of via [include Pervasives]: that module was removed in
   OCaml 5.0, and explicit bindings compile on every version. *)
module DFloat = Make(struct
  type t = float
  let one = 1.0
  let zero = 0.0
  let compare (x : float) y = compare x y
  let sin = sin
  let cos = cos
  let tan = tan
  let asin = asin
  let acos = acos
  let atan = atan
  let sqrt = sqrt
  let log = log
  let exp = exp
  let ( ** ) = ( ** )
  let print_t out = Printf.fprintf out "%g"
  (* Bound last so the definitions above still see the integer/pervasive
     operators rather than these float ones. *)
  let (+) = (+.)
  let (-) = (-.)
  let ( * ) = ( *. )
  let (/) = (/.)
end)