Below I've posted an OCaml functor which lifts an input module of mathematical functions (+,-,*,/,sin,cos,...) to equivalent functions which also support computing derivatives without any truncation error (i.e. (d sin) is exactly the function cos). This technique is commonly called automatic differentiation in C++ (where it is accomplished at compile time through template metaprogramming), but I was first introduced to it by SICM; you can also find examples in Haskell in the
The basic idea is to compute derivatives using the chain rule---if you have a function f and you know its derivative Df and similarly for g and Dg, then you can compute D(f o g)(x) = Df(g(x)) * Dg(x). Computationally you represent a ``derivative value'' by a pair (x,dx); then f((x,dx)) = (f(x), Df(x)*dx). To compute the derivative of some unknown function h, you simply extract the second part of h((x,1)) (this makes more sense---to me, anyway---if you think of (x,1) as the derivative output of the identity function).
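To make the pair representation concrete, here is a minimal sketch of the idea for a single variable. It is deliberately not the tagged structure posted below: the names `dual`, `lift`, `mul`, `deriv` and `h` are made up for this illustration, and it only works on plain floats with one derivative in flight at a time.

```ocaml
(* A value paired with its derivative: the (x, dx) of the text.
   All names here are illustrative and do not come from diff.ml. *)
type dual = { v : float; dv : float }

(* Lift f with known derivative f': f (x, dx) = (f x, f' x * dx). *)
let lift f f' { v; dv } = { v = f v; dv = f' v *. dv }

let sin_d = lift sin cos

(* Binary operators need the product rule written out by hand. *)
let mul a b = { v = a.v *. b.v; dv = a.dv *. b.v +. a.v *. b.dv }

(* Derivative of h at x: feed in (x, 1) and read off the second part. *)
let deriv h x = (h { v = x; dv = 1.0 }).dv

(* Example: h(x) = x * sin x, whose derivative is sin x + x * cos x. *)
let h x = mul x (sin_d x)

let () =
  let x = 0.7 in
  Printf.printf "automatic: %g  by hand: %g\n"
    (deriv h x) (sin x +. x *. cos x)
```

The two printed numbers should agree to roundoff, since no finite differencing is involved anywhere.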
There's a bit of extra bookkeeping in the structure pasted in below, mostly because you may have derivatives with respect to several different arguments floating around in the computation at the same time---so each derivative object acquires a unique tag. The only tricky thing about this is that you have to keep some extra terms in the derivatives of binary operators: (x,dx)*(y,dy) = ((x*y, dx*y), (x*dy, dx*dy)). (Here the inner (,) groups refer to the ``x'' label and the outer group refers to the ``y'' label.) You would think that the dx*dy term is second-order and could be dropped, but it's first-order in dx and first-order in dy, and they are not the same variable, so you have to keep it. Gerry Sussman pointed this bug out to me the first time I wrote some code to do this---apparently it had tripped them up in SICM, too. It's a bit nasty, because the bug doesn't show up until you try to take second derivatives (so you have more than one tag floating around in the computation). If you write your own code to do this, consider yourself warned :).
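One way to see why the cross term has to stay is to write each tagged increment as a formal infinitesimal. The symbols \varepsilon_x and \varepsilon_y below are my own notation for this note (they do not appear in the code), with the assumption that each squares to zero while their product does not:

```latex
\[
(x + dx\,\varepsilon_x)(y + dy\,\varepsilon_y)
  = xy + (dx\,y)\,\varepsilon_x + (x\,dy)\,\varepsilon_y
    + (dx\,dy)\,\varepsilon_x\varepsilon_y,
\qquad \varepsilon_x^2 = \varepsilon_y^2 = 0,\quad
\varepsilon_x\varepsilon_y \neq 0 .
\]
% Second derivative of x^2, using one tag per differentiation:
\[
(x + \varepsilon_1 + \varepsilon_2)^2
  = x^2 + 2x\,\varepsilon_1 + 2x\,\varepsilon_2
    + 2\,\varepsilon_1\varepsilon_2 .
\]
```

In the second line the coefficient of \varepsilon_1\varepsilon_2 is exactly the second derivative, 2. Dropping the cross term would give 0 instead, while every first derivative would still come out right, which is why the bug stays hidden until two tags meet.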
What do I use this for? I'm using it to show that my integrator algorithms are definitely symplectic maps. A (two-dimensional) map (q,p) |-> (Q(q,p), P(q,p)) is symplectic if (dQ/dq*dP/dp - dQ/dp*dP/dq)(q,p) = 1. A nice way to show this for a computational algorithm is to automatically differentiate it and evaluate the combination of derivatives above along a trajectory. The combination should be 1 to within roundoff error: the module below computes the derivatives without truncation error (that is, without finite differencing), but it is still subject to roundoff in finite-precision floating-point arithmetic if you use it on ordinary double-precision floats.
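As a rough illustration of that check, here is a self-contained sketch that does not use the functor below. It assumes a pendulum Hamiltonian H = p^2/2 - cos q and a kick-drift-kick leapfrog step, both chosen only for this example, and it carries the two partial derivatives (with respect to the initial q and p) alongside each value. The names `d2`, `add`, `scale`, `sin_d`, `leapfrog` and `max_defect` are all hypothetical.

```ocaml
(* A value together with its partials with respect to the q and p
   that were fed into the current step. *)
type d2 = { v : float; dq : float; dp : float }

let add a b = { v = a.v +. b.v; dq = a.dq +. b.dq; dp = a.dp +. b.dp }
let scale c a = { v = c *. a.v; dq = c *. a.dq; dp = c *. a.dp }
let sin_d a = { v = sin a.v; dq = cos a.v *. a.dq; dp = cos a.v *. a.dp }

(* One leapfrog step of size h for dq/dt = p, dp/dt = -sin q. *)
let leapfrog h (q, p) =
  let p_half = add p (scale (-. h /. 2.0) (sin_d q)) in
  let q' = add q (scale h p_half) in
  let p' = add p_half (scale (-. h /. 2.0) (sin_d q')) in
  (q', p')

(* Largest |dQ/dq*dP/dp - dQ/dp*dP/dq - 1| over n steps, re-seeding
   the derivatives at the start of every step. *)
let max_defect h n q0 p0 =
  let rec go i q p worst =
    if i = n then worst
    else
      let seed_q = { v = q; dq = 1.0; dp = 0.0 }
      and seed_p = { v = p; dq = 0.0; dp = 1.0 } in
      let (bq, bp) = leapfrog h (seed_q, seed_p) in
      let defect =
        abs_float (bq.dq *. bp.dp -. bq.dp *. bp.dq -. 1.0) in
      go (i + 1) bq.v bp.v (max worst defect)
  in
  go 0 q0 p0 0.0

let () =
  Printf.printf "max symplectic defect: %g\n" (max_defect 0.1 100 1.0 0.0)
```

For a symplectic step like this one the printed defect should sit at the level of floating-point roundoff; a non-symplectic integrator would show a defect that scales with some power of the step size instead.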
Sorry for the long post, but at least they don't come too often :). By the way, since this code is pretty long and involved, I guess I should specify a license: I'm releasing it under the GPL.
(** Code to implement differentiation on arbitrary types.

    diff.ml: Library to compute exact derivatives of arbitrary
    mathematical functions.

    Copyright (C) 2006 Will M. Farr

    This program is free software; you can redistribute it and/or
    modify it under the terms of the GNU General Public License as
    published by the Free Software Foundation; either version 2 of
    the License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
    General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
    02110-1301 USA. *)

(** Input signature for [Make].*)
module type In = sig
  type t

  (** The multiplicative identity *)
  val one : t

  (** The additive identity *)
  val zero : t

  (** Comparisons *)
  val compare : t -> t -> int

  val (+) : t -> t -> t
  val (-) : t -> t -> t
  val ( * ) : t -> t -> t
  val (/) : t -> t -> t
  (** Arithmetic operations *)

  val sin : t -> t
  val cos : t -> t
  val tan : t -> t
  (** Trigonometric functions *)

  val asin : t -> t
  val acos : t -> t
  val atan : t -> t
  (** Inverse trigonometric functions *)

  val sqrt : t -> t
  val log : t -> t
  val exp : t -> t
  val ( ** ) : t -> t -> t
  (** Powers, exponents and logs. *)
end

(** Output signature for [Make]. *)
module type Out = sig
  (** Input from [In] *)
  type t

  (** Opaque tags for variables being differentiated. *)
  type tag

  (** Output diff type. A [diff] is either a constant [t] value
      [C(x)] or a tuple of [D(tag, x, dx)] where [x] and [dx] can
      also be [diff]. [D(tag, x, dx)] represents a small increment
      [dx] in the value labeled by [tag] about the value [x]. *)
  type diff =
    | C of t
    | D of tag * diff * diff

  val zero : diff
  val one : diff
  (** Additive and multiplicative identities in [diff]. *)

  (** [lift f f'] takes [f : t -> t] to [diff -> diff] using [f'] to
      compute the derivative. For example, if we had already defined
      [(/) : diff -> diff -> diff], and the input signature didn't
      provide [log], we could define it using
      [let log = lift log ((/) one)]. *)
  val lift : (t -> t) -> (diff -> diff) -> (diff -> diff)

  (** [lower f] takes a function defined on [diff] to the equivalent
      one defined on [t]. It is an error if [f (C x)] does not
      evaluate to [C y] for some [y]. The function is called [lower]
      because it is the inverse of the function [lift] which lifts
      [f : t -> t] to [diff -> diff]. *)
  val lower : (diff -> diff) -> (t -> t)

  (** [lower_multi f] lowers [f] from [diff array -> diff] to
      [t array -> t]. It is an error if
      [f \[|(C x0); ...; (C xn)|\]] does not evaluate to [C y] for
      some [y]. *)
  val lower_multi : (diff array -> diff) -> (t array -> t)

  (** [lower_multi_multi f] lowers [f] from
      [diff array -> diff array] to [t array -> t array]. It is an
      error if [f \[|(C x0); ...; (C xn)|\]] does not evaluate to
      [\[|(C y0); ...; (C yn)|\]]. *)
  val lower_multi_multi :
    (diff array -> diff array) -> (t array -> t array)

  (** [d f] returns the function which computes the derivative of
      [f]. *)
  val d : (diff -> diff) -> (diff -> diff)

  (** [partial i f] returns the function which computes the
      derivative of [f] with respect to its [i]th argument (counting
      from 0). *)
  val partial : int -> (diff array -> diff) -> (diff array -> diff)

  (** [compare d1 d2] compares only the values stored in the
      derivative---that is either the [C x] or [D(_, x, _)]. [(<)],
      [(>)], ... are defined in terms of compare, which is defined in
      terms of [compare] from [In]. *)
  val compare : diff -> diff -> int

  val (<) : diff -> diff -> bool
  val (<=) : diff -> diff -> bool
  val (>) : diff -> diff -> bool
  val (>=) : diff -> diff -> bool
  val (=) : diff -> diff -> bool
  val (<>) : diff -> diff -> bool
  (** Comparison functions *)

  val (+) : diff -> diff -> diff
  val (-) : diff -> diff -> diff
  val ( * ) : diff -> diff -> diff
  val (/) : diff -> diff -> diff
  (** Algebra. *)

  val cos : diff -> diff
  val sin : diff -> diff
  val tan : diff -> diff
  (** Trig *)

  val acos : diff -> diff
  val asin : diff -> diff
  val atan : diff -> diff
  (** Inverse trig *)

  val sqrt : diff -> diff
  val log : diff -> diff
  val exp : diff -> diff
  val ( ** ) : diff -> diff -> diff
  (** Powers, exponents and logs *)
end

module Make(I : In) : Out with type t = I.t = struct
  type t = I.t

  type tag = int

  (* Terms are either constant or of the form (x + dx), with x
     represented by tag *)
  type diff =
    | C of t
    | D of tag * diff * diff

  (* Additive and multiplicative identities in derivatives. *)
  let zero = C(I.zero)
  let one = C(I.one)

  (* Unique tags *)
  let new_tag =
    let count = ref 0 in
    fun () ->
      count := !count + 1;
      !count

  (* Have to define the arithmetic operators first because they are
     used in [lift] and friends. *)
  let rec (+) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.(+) x y)
    | C(x), D(tag, y, dy) -> D(tag, d1 + y, dy)
    | D(tag, x, dx), C(y) -> D(tag, x + d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x + y, dx + dy)
    | D(tagx, x, dx), D(tagy, y, dy) ->
        D(tagx, D(tagy, x + y, dy), dx)

  let rec (-) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.(-) x y)
    | C(x), D(tag, y, dy) -> D(tag, d1 - y, zero - dy)
    | D(tag, x, dx), C(y) -> D(tag, x - d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x - y, dx - dy)
    | D(tagx, x, dx), D(tagy, y, dy) ->
        D(tagx, D(tagy, x - y, zero - dy), dx)

  let rec ( * ) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.( * ) x y)
    | C(x), D(tag, y, dy) -> D(tag, d1 * y, d1 * dy)
    | D(tag, x, dx), C(y) -> D(tag, x * d2, dx * d2)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x * y, x*dy + dx*y)
    | D(tagx, x, dx), D(tagy, y, dy) ->
        (* Different tags: the dx*dy cross term must be kept. *)
        D(tagx, D(tagy, x*y, x*dy), D(tagy, dx*y, dx*dy))

  let rec (/) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.(/) x y)
    | D(tag, x, dx), C(y) -> D(tag, x/d2, dx/d2)
    | C(x), D(tag, y, dy) -> D(tag, d1/y, zero - d1*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x/y, dx/y - x*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) ->
        D(tagx,
          D(tagy, x/y, zero - x*dy/(y*y)),
          D(tagy, dx/y, zero - dx*dy/(y*y)))

  let lift f f' =
    let rec lf = function
      | C(x) -> C(f x)
      | D(tag, x, dx) -> D(tag, lf x, (f' x)*dx)
    in
    lf

  (* Now that we have the algebra of derivatives worked out, we can
     define lift, lift2, lower and lower2. *)
  let lower f =
    fun x ->
      match f (C x) with
      | C(y) -> y
      | _ -> raise (Failure "lower expects numerical output")

  let rec extract_deriv tag = function
    | C(_) -> zero
    | D(tagx, x, dx) when tagx = tag ->
        extract_deriv tag x + drop_deriv tag dx
    | D(tagx, x, dx) ->
        D(tagx, extract_deriv tag x, extract_deriv tag dx)
  and drop_deriv tag = function
    | D(tagx, x, dx) when tagx = tag -> x
    | D(tagx, x, dx) -> D(tagx, drop_deriv tag x, drop_deriv tag dx)
    | x -> x (* Only matches C(_) *)

  let d f =
    fun x ->
      let tag = new_tag () in
      let res = f (D(tag, x, one)) in
      extract_deriv tag res

  let rec compare d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> I.compare x y
    | C(x), D(_, y, _) -> compare d1 y
    | D(_, x, _), C(y) -> compare x d2
    | D(_, x, _), D(_, y, _) -> compare x y

  let rec cos = function
    | C(x) -> C(I.cos x)
    | D(tag, x, dx) -> D(tag, cos x, zero - dx * (sin x))
  and sin = function
    | C(x) -> C(I.sin x)
    | D(tag, x, dx) -> D(tag, sin x, dx * (cos x))

  let rec tan = function
    | C(x) -> C(I.tan x)
    | D(tag, x, dx) ->
        let tx = tan x in
        D(tag, tx, (one + tx*tx)*dx)

  let rec sqrt = function
    | C(x) -> C(I.sqrt x)
    | D(tag, x, dx) ->
        let sx = sqrt x in
        D(tag, sx, dx / (sx + sx))

  let rec acos = function
    | C(x) -> C(I.acos x)
    | D(tag, x, dx) -> D(tag, acos x, zero - dx / (sqrt (one - x * x)))

  let rec asin = function
    | C(x) -> C(I.asin x)
    | D(tag, x, dx) -> D(tag, asin x, dx / (sqrt (one - x * x)))

  let rec atan = function
    | C(x) -> C(I.atan x)
    | D(tag, x, dx) -> D(tag, atan x, dx / (one + x*x))

  let rec log = function
    | C(x) -> C(I.log x)
    | D(tag, x, dx) -> D(tag, log x, dx/x)

  let rec exp = function
    | C(x) -> C(I.exp x)
    | D(tag, x, dx) ->
        let ex = exp x in
        D(tag, ex, dx*ex)

  let rec ( ** ) d1 d2 =
    match d1, d2 with
    | C(x), C(y) -> C(I.( ** ) x y)
    | C(x), D(tag, y, dy) -> D(tag, d1**y, d1**y * (log d1) * dy)
    | D(tag, x, dx), C(y) -> D(tag, x**d2, x**(d2 - one)*d2*dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x**y, x**(y - one)*(dx*y + dy*x*(log x)))
    | D(tagx, x, dx), D(tagy, y, dy) ->
        (* The tagx part must carry a factor of dx in both slots, just
           as in the mixed-tag cases of ( * ) and (/); without it the
           result is only right when dx happens to be one. *)
        D(tagx,
          D(tagy, x**y, x**y*(log x)*dy),
          D(tagy,
            y*x**(y - one)*dx,
            x**(y - one)*(one + y*(log x))*dx*dy))

  let replace arr i x =
    Array.mapi (fun j y -> if j <> i then y else x) arr

  let partial i f =
    fun args ->
      let x = args.(i) in
      let one_d_f x = f (replace args i x) in
      (d one_d_f) x

  let c_ify arr = Array.map (fun x -> (C x)) arr

  let de_c_ify arr =
    Array.map
      (function
        | C(x) -> x
        | _ -> raise (Failure "cannot lower [|D(...); ...; D(...)|]"))
      arr

  let lower_multi f =
    fun args ->
      match (f (c_ify args)) with
      | C(x) -> x
      | _ -> raise (Failure "cannot lower D(...)")

  let lower_multi_multi f =
    fun args -> de_c_ify (f (c_ify args))

  (* Define these all at once (and last) so as not to spoil the
     comparison operator namespace. *)
  let (<) d1 d2 = compare d1 d2 < 0
  and (<=) d1 d2 = not (compare d1 d2 > 0)
  and (>) d1 d2 = compare d1 d2 > 0
  and (>=) d1 d2 = not (compare d1 d2 < 0)
  and (=) d1 d2 = compare d1 d2 = 0
  and (<>) d1 d2 = not (compare d1 d2 = 0)
end
3 comments:
Andrew, your python code is completely different to this ocaml code. You're using symbolic differentiation whereas the ocaml code is using a quite different algorithm 'automatic differentiation'. With automatic differentiation you can differentiate things like entire fluid dynamic simulations. You couldn't begin to do this with symbolic differentiation.
For my own take on automatic differentiation I have written something here.
Here is my example of the automatic differentiation with Python
I think it's great that automatic differentiation is also becoming popular in programming languages like ocaml.
The first AD tools were mostly written in Fortran and C, and I believe the most efficient AD tools are still in Fortran and C.
Lately, with the operator overloading techniques possible in Fortran90 and C++, these tools have been becoming more user-friendly while maintaining their performance.
My favourite tools are ADC03 for C/C++ and ADF03 for Fortran.