15 November 2006

The Paper is Out!

It's out! The paper of the century---get your copy before they're sold out at the newsstand. (If you're into N-body simulation methods, that is.)

Seriously: I'm very glad to have it out the door. The ball is now in ApJ's court, and I can get back to focusing on my thesis (for which I just submitted a title: "Numerical Relativity from a Gauge Theory Perspective").

Section 6.2.1 uses the OCaml automatic differentiation code I posted about earlier to show that the integrator algorithm is as symplectic as we claim. This is something I'd like to see more physicists take up, particularly in cosmological N-body simulations, where preserving the "coldness" of the dark matter phase-space distribution is very important. Everyone uses the leapfrog algorithm (for now; it might turn out that our algorithm is better) because it's symplectic with constant timesteps (and, though we're the first to realize it, with block-power-of-two timesteps; see Section 3, and the numerical experiments in Sections 6.1 and 6.2.1), but they never measure whether their implementations really are symplectic. If they were in the habit of doing this, I'm sure someone before us would have realized that block-power-of-two steps preserve symplecticity while ordinary adaptive steps break it (rather than assuming that any non-constant timestep breaks it).

Anyway, I'm excited to see what the reception of the physics community will be. I've got my fingers crossed....

07 November 2006

SRFI-42 Comprehensions for SRFI-4 Vectors

I spend quite a bit of time working with SRFI-4 homogeneous numeric vectors. Below is some code that adds SRFI-4-vector generators and comprehensions to SRFI-42's Eager Comprehensions. I am impressed by the modularity of SRFI-42!

#|
Derived (almost verbatim) from the code for vector comprehensions in the SRFI-42 reference implementation.  The copyright on that code is:

Copyright (C) Sebastian Egner (2003). All Rights Reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

Modifications are copyright (C) Will M. Farr (2006), governed by the same conditions.  Feel free to email comments to .  
|#

(module srfi-4-comprehensions mzscheme
  (require (lib "42.ss" "srfi")
           (lib "4.ss" "srfi"))
  
  (provide :s8vector :u8vector :s16vector :u16vector :s32vector 
           :u32vector :s64vector :u64vector :f64vector :f32vector
           
           s8vector-ec u8vector-ec s16vector-ec u16vector-ec s32vector-ec 
           u32vector-ec s64vector-ec u64vector-ec f64vector-ec f32vector-ec
           
           s8vector-of-length-ec u8vector-of-length-ec s16vector-of-length-ec 
           u16vector-of-length-ec s32vector-of-length-ec 
           u32vector-of-length-ec s64vector-of-length-ec u64vector-of-length-ec 
           f64vector-of-length-ec f32vector-of-length-ec)
  
  (define-for-syntax symbol-append
    (case-lambda
      ((s) s)
      ((s1 s2 . ss)
       (apply symbol-append (string->symbol (string-append (symbol->string s1) (symbol->string s2))) ss))))
  
  (define-syntax make/prefix
    (lambda (stx)
      (syntax-case stx ()
        ((make-prefix-generator prefix)
         (let ((pre-sym (syntax-object->datum (syntax prefix))))
           (with-syntax ((vlength (datum->syntax-object (syntax prefix) (symbol-append pre-sym 'vector-length)))
                         (vref (datum->syntax-object (syntax prefix) (symbol-append pre-sym 'vector-ref)))
                         (vgen (datum->syntax-object (syntax prefix) (symbol-append ': pre-sym 'vector)))
                         (vfilter (datum->syntax-object (syntax prefix) (symbol-append 'ec-: pre-sym 'vector-filter)))
                         ;; SRFI-4 names its constructors make-TAGvector (e.g. make-f64vector).
                         (vmake (datum->syntax-object (syntax prefix) (symbol-append 'make- pre-sym 'vector)))
                         (vset! (datum->syntax-object (syntax prefix) (symbol-append pre-sym 'vector-set!)))
                         (v->list (datum->syntax-object (syntax prefix) (symbol-append pre-sym 'vector->list)))
                         (list->v (datum->syntax-object (syntax prefix) (symbol-append 'list-> pre-sym 'vector)))
                         (v-ec (datum->syntax-object (syntax prefix) (symbol-append pre-sym 'vector-ec)))
                         (v-of-length-ec (datum->syntax-object (syntax prefix) (symbol-append pre-sym 'vector-of-length-ec))))
             (syntax 
              (begin  
                (define-syntax vgen
                  (syntax-rules (index)
                    ((vgen cc var arg)
                     (vgen cc var (index i) arg) )
                    ((vgen cc var (index i) arg)
                     (:do cc
                          (let ((vec arg) (len 0)) 
                            (set! len (vlength vec)))
                          ((i 0))
                          (< i len)
                          (let ((var (vref vec i))))
                          #t
                          ((+ i 1)) ))
                    ((vgen cc var (index i) arg1 arg2 arg (... ...))
                     (:parallel cc (vgen cc var arg1 arg2 arg (... ...)) (:integers i)) )
                    ((vgen cc var arg1 arg2 arg (... ...))
                     (:do cc
                          (let ((vec #f)
                                (len 0)
                                (vecs (vfilter (list arg1 arg2 arg (... ...)))) ))
                          ((k 0))
                          (if (< k len)
                              #t
                              (if (null? vecs)
                                  #f
                                  (begin (set! vec (car vecs))
                                         (set! vecs (cdr vecs))
                                         (set! len (vlength vec))
                                         (set! k 0)
                                         #t )))
                          (let ((var (vref vec k))))
                          #t
                          ((+ k 1)) ))))
                (define (vfilter vecs)
                  (if (null? vecs)
                      '()
                      (if (zero? (vlength (car vecs)))
                          (vfilter (cdr vecs))
                          (cons (car vecs) (vfilter (cdr vecs))) )))
                (define-syntax v-ec
                  (syntax-rules ()
                    ((v-ec etc1 etc (... ...))
                     (list->v (list-ec etc1 etc (... ...))) )))
                (define-syntax v-of-length-ec
                  (syntax-rules (nested)
                    ((v-of-length-ec k (nested q1 (... ...)) q etc1 etc (... ...))
                     (v-of-length-ec k (nested q1 (... ...) q) etc1 etc (... ...)) )
                    ((v-of-length-ec k q1 q2             etc1 etc (... ...))
                     (v-of-length-ec k (nested q1 q2)    etc1 etc (... ...)) )
                    ((v-of-length-ec k expression)
                     (v-of-length-ec k (nested) expression) )
                    ((v-of-length-ec k qualifier expression)
                     (let ((len k))
                       (let ((vec (vmake len))
                             (i 0) )
                         (do-ec qualifier
                                (if (< i len)
                                    (begin (vset! vec i expression)
                                           (set! i (+ i 1)) )
                                    (error "vector is too short for the comprehension") ))
                         (if (= i len)
                             vec
                             (error "vector is too long for the comprehension") ))))))))))))))

  (make/prefix s8)
  (make/prefix u8)
  (make/prefix s16)
  (make/prefix u16)
  (make/prefix s32)
  (make/prefix u32)
  (make/prefix s64)
  (make/prefix u64)
  (make/prefix f32)
  (make/prefix f64))

24 October 2006

Updated Scheme Code for Derivatives

I've updated my functional differential geometry code to use the same techniques for computing derivatives that I've been writing about in OCaml these past few days. It should (though I haven't rigorously tested this) improve the performance of computing structure constants, and anywhere else in the code where I need multiple derivatives. Structure constants involve the commutator of two tangent vectors in the group, and therefore second derivatives, and therefore multiple tags in the derivative objects, which the tree structure handles more efficiently. Issuing darcs get http://web.mit.edu/farr/www/SchemeCode/ should get the new version, or you can just download the code from that web address. Scheme can really be a breath of fresh air (compare the implementation of the macro define-binary-derivative in derivatives.ss to the repeated pattern that appears in the OCaml code), but its performance leaves something to be desired for large-scale numerical work. (The OCaml code runs *much* faster than the Scheme code, and still takes a couple of hours to measure the non-symplecticity in a 25-body simulation.)

As always, comments from my user(s?) are definitely welcome at farr@mit.edu.

23 October 2006

Better Jacobians in Computing Derivatives

And another fix to the code computing derivatives: previously, I was trying to compute the Jacobian of a function with only one call to that function:

let jacobian f args = 
    let n = Array.length args in 
    let tags = Array.init n (fun _ -> new_tag ()) in 
    let dargs = Array.mapi (fun i arg -> D(tags.(i), arg, one)) args in 
    let result = f dargs in 
    Array.init n 
      (fun i -> 
        let fi = result.(i) in 
        Array.init n 
          (fun j -> 
            let tj = tags.(j) in 
            Array.fold_left 
              (fun res tag -> 
                if tag = tj then 
                  extract_deriv tag res
                else
                  drop_deriv tag res)
              fi
              tags))
That's efficient in the sense that it doesn't re-compute results within f. Unfortunately, it introduces n tags, where n is the dimensionality of the argument to f. Each tag labels a fundamental ``incremental value'' dx, dy, .... A particular D(...) expression is a tree representing an expression a0 + a1*dx + a2*dx*dy + ... that potentially contains one term for each of the 2^n different combinations (products of subsets) of dx, dy, .... This means that, though we save on repeated invocations of f, the cost of computing the Jacobian becomes exponential in the dimension of the argument, because each operation (+, sin, etc.) has to manipulate trees of depth n containing up to 2^n terms.
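To see the blowup concretely, here is a stripped-down sketch that follows the same heap-ordered multiplication rules as the diff-tree code, but on floats only and with just addition and multiplication. The names add, mul, leaves, all_tagged and one_tagged are illustrative assumptions, not part of the posted module. It counts the stored coefficients in x0*x1*...*x(n-1): with every factor carrying its own tag the count doubles per factor, while with a single tagged factor it stays at two.

```ocaml
(* Sketch: the heap-ordered tree D(tag, x, dx) = x + dx*eps_tag, reduced
   to floats with only addition and multiplication. *)
type diff =
  | C of float
  | D of int * diff * diff

let rec add d1 d2 =
  match d1, d2 with
  | C x, C y -> C (x +. y)
  | C _, D (t, y, dy) -> D (t, add d1 y, dy)
  | D (t, x, dx), C _ -> D (t, add x d2, dx)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty -> D (tx, add x y, add dx dy)
  | D (tx, x, dx), D (ty, _, _) when tx < ty -> D (tx, add x d2, dx)
  | D (_, _, _), D (ty, y, dy) -> D (ty, add d1 y, dy)

let rec mul d1 d2 =
  match d1, d2 with
  | C x, C y -> C (x *. y)
  | C _, D (t, y, dy) -> D (t, mul d1 y, mul d1 dy)
  | D (t, x, dx), C _ -> D (t, mul x d2, mul dx d2)
  | D (tx, x, dx), D (ty, y, dy) when tx = ty ->
      D (tx, mul x y, add (mul x dy) (mul dx y))
  | D (tx, x, dx), D (ty, _, _) when tx < ty -> D (tx, mul x d2, mul dx d2)
  | D (_, _, _), D (ty, y, dy) -> D (ty, mul d1 y, mul d1 dy)

(* Number of stored coefficients (C leaves) in a value. *)
let rec leaves = function
  | C _ -> 1
  | D (_, x, dx) -> leaves x + leaves dx

(* x0 * x1 * ... * x(n-1) with xi = i + 2 and every factor tagged,
   as in the one-call Jacobian. *)
let all_tagged n =
  List.init n (fun i -> D (i, C (float_of_int (i + 2)), C 1.0))
  |> List.fold_left mul (C 1.0)

(* The same product with only factor j tagged, as in the
   column-by-column Jacobian. *)
let one_tagged n j =
  List.init n (fun i ->
      let v = C (float_of_int (i + 2)) in
      if i = j then D (0, v, C 1.0) else v)
  |> List.fold_left mul (C 1.0)

let () =
  List.iter
    (fun n ->
      Printf.printf "n = %2d: all tagged %5d leaves, one tagged %d leaves\n"
        n (leaves (all_tagged n)) (leaves (one_tagged n 0)))
    [2; 4; 8; 12]
```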

The following simple change wastes effort computing f dargs n times (once per column of the Jacobian), but each call carries only one tag, so the cost scales only linearly with the dimension of args. This is a huge win. I've also added a function insert_tag, which inserts a new D(tag, x, one) term within the derivative-computing wrappers (but does so in a way consistent with the heap property of the diff-tree). The modified functions are below:

  let rec insert_tag tag = function 
    | (C(x) as d) -> D(tag, d, one)
    | (D(tagd, x, dx) as d) when tag < tagd -> 
        D(tag, d, one)
    | D(tagd, x, dx) when tagd = tag -> 
        raise (Failure "Cannot insert tag which is already present.")
    | D(tagd, x, dx) -> 
        D(tagd, insert_tag tag x, dx)
          
  let d f = 
    fun x -> 
      let tag = new_tag () in 
      let res = f (insert_tag tag x) in
      extract_deriv tag res

  let jacobian f args = 
    let n = Array.length args in 
    let jac = Array.make_matrix n n zero in 
    for j = 0 to Pervasives.(-) n 1 do 
      let tag = new_tag () in 
      let dargs = Array.mapi 
          (fun i arg -> if Pervasives.(=) i j then insert_tag tag arg else arg)
          args in 
      let result = f dargs in 
      Array.iteri (fun i res -> 
        jac.(i).(j) <- extract_deriv tag res)
        result
    done;
    jac
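The column-at-a-time strategy does not depend on the tree representation. Here is a minimal sketch of it using plain one-tag dual numbers; the record type dual and the helpers add, mul, dsin and jacobian are illustrative assumptions rather than part of the module above, and, like the code above, it assumes f maps n inputs to n outputs.

```ocaml
(* Sketch: one-tag dual numbers v + d*eps with eps^2 = 0. *)
type dual = { v : float; d : float }

let add a b = { v = a.v +. b.v; d = a.d +. b.d }
let mul a b = { v = a.v *. b.v; d = a.v *. b.d +. a.d *. b.v }
let dsin a = { v = sin a.v; d = a.d *. cos a.v }

(* Column j of the Jacobian: seed only argument j with d = 1, call f,
   and read the d parts of the outputs.  That is n calls to f, each on
   single-tag values, so the cost grows linearly with n.  Assumes f
   returns as many outputs as it takes inputs. *)
let jacobian (f : dual array -> dual array) (args : float array) =
  let n = Array.length args in
  let jac = Array.make_matrix n n 0.0 in
  for j = 0 to n - 1 do
    let dargs =
      Array.mapi (fun i x -> { v = x; d = (if i = j then 1.0 else 0.0) }) args
    in
    Array.iteri (fun i r -> jac.(i).(j) <- r.d) (f dargs)
  done;
  jac

(* Example: f(x, y) = (x*y, sin x + y). *)
let f z = [| mul z.(0) z.(1); add (dsin z.(0)) z.(1) |]

let () =
  jacobian f [| 1.0; 2.0 |]
  |> Array.iter (fun row ->
         Array.iter (Printf.printf "%9.5f ") row;
         print_newline ())
```

For f(x, y) = (x*y, sin x + y) at (1, 2), the printed matrix should be [[2, 1], [cos 1, 1]].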

19 October 2006

Correction to Automatic Differentiation Code

I've uncovered a *large* inefficiency in the OCaml code I posted for Automatic Differentiation. The problem is that I don't take any care to avoid duplicating tags in a nested D(tag, x, dx) expression. That's not a problem in a computation with only one tag, but if you're trying to take the Jacobian of some multi-dimensional function, you have derivatives with respect to each of the arguments floating through the computation. If you duplicate tags every time you construct a new D(...) expression, then the nested length of the expression grows *really* fast. The key to avoiding duplication is to think of a given D(tag, x, dx) expression as a tree with the heap property: all tags in the expressions x and dx are larger than tag. Then you just have to rewrite the binary operations (these are the only ones that can combine D(...) expressions with different tags) to preserve the heap property. Updated code is posted below (sorry for another long post, but this change makes a big difference)! I've added the jacobian and lower_jacobian functions because I needed them to test for symplecticity: a map is symplectic exactly when its Jacobian J satisfies J^T.S.J = S, where S is the standard symplectic form [[0, I], [-I, 0]], I is the identity matrix, and 0 is the zero matrix. As before, this code is released under the GPL.

(** Code to implement differentiation on arbitrary types. 

   diff.ml: Library to compute exact derivatives of arbitrary
   mathematical functions.

   Copyright (C) 2006 Will M. Farr 
   
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

*)

(** Input signature for [Make].*)
module type In = sig
  type t

(** The multiplicative identity *)
  val one : t

(** The additive identity *)
  val zero : t

(** Comparisons *)
  val compare : t -> t -> int

  val (+) : t -> t -> t
  val (-) : t -> t -> t
  val ( * ) : t -> t -> t
  val (/) : t -> t -> t
(** Arithmetic operations *)


  val sin : t -> t
  val cos : t -> t
  val tan : t -> t 
(** Trigonometric functions *)

  val asin : t -> t
  val acos : t -> t
  val atan : t -> t
(** Inverse trigonometric functions *)

  val sqrt : t -> t
  val log : t -> t
  val exp : t -> t
  val ( ** ) : t -> t -> t
(** Powers, exponents and logs. *)

  val print_t : out_channel -> t -> unit
end

(** Output signature for [Make]. *)
module type Out = sig 
(** Input from [In] *)
  type t 

(** Opaque tags for variables being differentiated. *)
  type tag

(** Output diff type.  A [diff] is either a constant [t] value [C(x)]
   or a tuple of [D(tag, x, dx)] where [x] and [dx] can also be
   [diff].  [D(tag, x, dx)] represents a small increment [dx] in the
   value labeled by [tag] about the value [x]. *)
  type diff = 
    | C of t 
    | D of tag * diff * diff

  val zero : diff
  val one : diff
(** Additive and multiplicative identities in [diff]. *)

(** [lift f f'] takes [f : t -> t] to [diff -> diff] using [f'] to
   compute the derivative.  For example, if we had already defined
   [(/) : diff -> diff -> diff], and the input signature didn't
   provide [log], we could define it using [let log = lift log ((/)
   one)]. *)
  val lift : (t -> t) -> (diff -> diff) -> (diff -> diff)

(** [lower f] takes a function defined on [diff] to the equivalent one
   defined on [t].  It is an error if [f (C x)] does not evaluate to
   [C y] for some [y].  The function is called [lower] because it is
   the inverse of the function [lift] which lifts [f : t -> t] to
   [diff -> diff].  *)
  val lower : (diff -> diff) -> (t -> t)

(** [lower_multi f] lowers [f] from [diff array -> diff] to [t array ->
   t].  It is an error if [f \[|(C x0); ...; (C xn)|\]] does not
   evaluate to [C y] for some [y]. *)
  val lower_multi : (diff array -> diff) -> (t array -> t)

(** [lower_multi_multi f] lowers [f] from [diff array -> diff array]
   to [t array -> t array].  *)
  val lower_multi_multi : (diff array -> diff array) -> 
    (t array -> t array)

(** [lower_jacobian j] lowers the jacobian [j] from [diff array ->
   diff array array] to [t array -> t array array]. *)
  val lower_jacobian : (diff array -> diff array array) -> 
    (t array -> t array array)

(** [d f] returns the function which computes the derivative of
   [f]. *)
  val d : (diff -> diff) -> (diff -> diff)

(** [partial i f] returns the function which computes the derivative
    of [f] with respect to its [i]th argument (counting from 0). *)
  val partial : int -> (diff array -> diff) -> (diff array -> diff)

(** [jacobian f] returns the jacobian matrix with elements [m.(i).(j)
   = d f.(i)/d x.(j)]. *)
  val jacobian : (diff array -> diff array) -> 
    (diff array -> diff array array)

(** [compare d1 d2] compares only the values stored in the
   derivative---that is either the [C x] or [D(_, x, _)].  [(<)],
   [(>)], ... are defined in terms of compare, which is defined in
   terms of [compare] from [In]. *)
  val compare : diff -> diff -> int

  val (<) : diff -> diff -> bool
  val (<=) : diff -> diff -> bool
  val (>) : diff -> diff -> bool
  val (>=) : diff -> diff -> bool
  val (=) : diff -> diff -> bool
  val (<>) : diff -> diff -> bool
(** Comparison functions *)

  val (+) : diff -> diff -> diff
  val (-) : diff -> diff -> diff
  val ( * ) : diff -> diff -> diff
  val (/) : diff -> diff -> diff
(** Algebra. *)

  val cos : diff -> diff
  val sin : diff -> diff
  val tan : diff -> diff
(** Trig *) 

  val acos : diff -> diff
  val asin : diff -> diff
  val atan : diff -> diff
(** Inverse trig *)

  val sqrt : diff -> diff
  val log : diff -> diff
  val exp : diff -> diff
  val ( ** ) : diff -> diff -> diff
(** Powers, exponents and logs *)

  val print_diff : out_channel -> diff -> unit
end

module Make(I : In) : Out with type t = I.t = struct
  type t = I.t

  type tag = int

(* Terms are either constant or of the form (x + dx), where dx is the
   increment in the variable labeled by tag.  We are careful to
   maintain the invariant that all the tags in x and dx are larger than
   the tag of (x + dx).  (That is, you can think of a D(tag, x, dx) as
   a tree:

          tag
         /   \
        x     dx

   which satisfies the heap property tag < tagx and tag < tagdx.)
*)
  type diff = 
    | C of t
    | D of tag * diff * diff

  let rec print_diff out = function 
    | C(x) -> Printf.fprintf out "C(";
        I.print_t out x;
        Printf.fprintf out ")"
    | D(tag, x, dx) -> 
        Printf.fprintf out "D(%d," tag;
        print_diff out x;
        Printf.fprintf out ", ";
        print_diff out dx;
        Printf.fprintf out ")"

(* Additive and multiplicative identities in derivatives. *)
  let zero = C(I.zero)
  let one = C(I.one)

(* Unique tags *)
  let new_tag = 
    let count = ref 0 in 
    fun () -> 
      count := !count + 1;
      !count

(* Have to define the arithmetic operators first because they are used
   in [lift] and friends.  To maintain the heap property of the tags,
   when operating on two D(_,_,_) objects we build the result around
   the smaller of tagx and tagy.  Inside D(smaller, _, _) we may use
   any of the sub-derivatives from either argument, since all of their
   tags exceed smaller.  But any D(larger, _, _) constructed directly
   may only use the sub-derivatives of the larger-tagged argument;
   there is no guarantee that the tags in the sub-derivatives of
   D(smaller, _, _) are in any relation to larger.  This is the reason
   for the somewhat obfuscated code below. *)
  let rec (+) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(+) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1 + y, dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x + d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x + y, dx + dy)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy ->
        D(tagx, x + d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, d1 + y, dy)

  let rec (-) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(-) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1 - y, zero - dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x - d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x - y, dx - dy)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy -> 
        D(tagx, x - d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, d1 - y, zero - dy)

  let rec ( * ) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.( * ) x y)
    | C(x), D(tag, y, dy) -> D(tag, d1 * y, d1 * dy)
    | D(tag, x, dx), C(y) -> D(tag, x * d2, dx * d2)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x * y, x*dy + dx*y)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy -> 
        D(tagx, x*d2, dx*d2)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, d1*y, d1*dy)

  let rec (/) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(/) x y)
    | D(tag, x, dx), C(y) -> D(tag, x/d2, dx/d2)
    | C(x), D(tag, y, dy) -> D(tag, d1/y, zero - d1*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x/y, dx/y - x*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy -> 
        let y2 = y * y in 
        let mdyoy2 = zero - dy/y2 and 
            ooy = one/y in 
        D(tagx, x * (D(tagy, ooy, mdyoy2)),
          dx*(D(tagy, ooy, mdyoy2)))
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, d1/y, zero - d1*dy/(y*y))

  let lift f f' = 
    let rec lf = function 
      | C(x) -> C(f x)
      | D(tag, x, dx) -> D(tag, lf x, (f' x)*dx) in 
    lf

(* Now that we have the algebra of derivatives worked out, we can
   define lower, extract_deriv, drop_deriv and d. *)
  let lower f = 
    fun x -> 
      match f (C x) with 
      | C(y) -> y
      | _ -> raise (Failure "lower expects numerical output")

  let rec extract_deriv tag = function 
    | C(_) -> zero
    | D(tagx, x, dx) when tagx = tag -> 
        dx
    | D(tagx, x, dx) -> 
        D(tagx, extract_deriv tag x, extract_deriv tag dx) 
  and drop_deriv tag = function 
    | D(tagx, x, dx) when tagx = tag -> 
        x
    | D(tagx, x, dx) -> 
        D(tagx, drop_deriv tag x, drop_deriv tag dx)
    | x -> x (* Only matches C(_) *)
          
  let d f = 
    fun x -> 
      let tag = new_tag () in 
      let res = f (D(tag, x, one)) in 
      extract_deriv tag res

  let rec compare d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> I.compare x y
    | C(x), D(_, y, _) -> compare d1 y
    | D(_, x, _), C(y) -> compare x d2
    | D(_, x, _), D(_, y, _) -> compare x y

  let rec cos = function 
    | C(x) -> C(I.cos x)
    | D(tag, x, dx) -> D(tag, cos x, zero - dx * (sin x)) and 
      sin = function 
        | C(x) -> C(I.sin x)
        | D(tag, x, dx) -> D(tag, sin x, dx * (cos x))

  let rec tan = function 
    | C(x) -> C(I.tan x)
    | D(tag, x, dx) -> 
        let tx = tan x in 
        D(tag, tx, (one + tx*tx)*dx)

  let rec sqrt = function 
    | C(x) -> C(I.sqrt x)
    | D(tag, x, dx) -> 
        let sx = sqrt x in 
        D(tag, sx, dx / (sx + sx))

  let rec acos = function 
    | C(x) -> C(I.acos x)
    | D(tag, x, dx) -> 
        D(tag, acos x, zero - dx / (sqrt (one - x * x)))

  let rec asin = function 
    | C(x) -> C(I.asin x)
    | D(tag, x, dx) -> 
        D(tag, asin x, dx / (sqrt (one - x * x)))

  let rec atan = function 
    | C(x) -> C(I.atan x)
    | D(tag, x, dx) -> 
        D(tag, atan x, dx / (one + x*x))

  let rec log = function 
    | C(x) -> C(I.log x)
    | D(tag, x, dx) -> D(tag, log x, dx/x)

  let rec exp = function 
    | C(x) -> C(I.exp x)
    | D(tag, x, dx) -> 
        let ex = exp x in 
        D(tag, ex, dx*ex)

  let rec ( ** ) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.( ** ) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1**y, d1**y * (log d1) * dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x**d2, x**(d2 - one)*d2*dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x**y, x**(y - one) * (dx*y + x*(log x)*dy))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx < tagy -> 
        D(tagx, x**y*(one + (log x)*(D(tagy, zero, dy))),
          dx*x**(y - one)*(y + (one + y*(log x))*(D(tagy, zero, dy))))
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagy, x**y + y*x**(y-one)*(D(tagx, zero, dx)),
          dy*(x**y*(log x) + x**(y-one)*(one+y*(log x))*(D(tagx, zero, dx))))
                      

  let replace arr i x = 
    Array.mapi 
      (fun j y -> 
        if j <> i then y else x)
      arr

  let partial i f = 
    fun args -> 
      let x = args.(i) in 
      let one_d_f x = 
        f (replace args i x) in 
      (d one_d_f) x

  let jacobian f args = 
    let n = Array.length args in 
    let tags = Array.init n (fun _ -> new_tag ()) in 
    let dargs = Array.mapi (fun i arg -> D(tags.(i), arg, one)) args in 
    let result = f dargs in 
    Array.init n 
      (fun i -> 
        let fi = result.(i) in 
        Array.init n 
          (fun j -> 
            let tj = tags.(j) in 
            Array.fold_left 
              (fun res tag -> 
                if tag = tj then 
                  extract_deriv tag res
                else
                  drop_deriv tag res)
              fi
              tags))
   
  let c_ify arr = 
    Array.map (fun x -> (C x)) arr

  let de_c_ify arr = 
    Array.map 
      (function 
        | C(x) -> x
        | d -> 
            raise (Failure "cannot lower [|D(...); ...; D(...)|]"))
      arr

  let lower_multi f = 
    fun args -> 
      match (f (c_ify args)) with 
      | C(x) -> x
      | _ -> raise (Failure "cannot lower D(...)")

  let lower_multi_multi f = 
    fun args -> 
      de_c_ify (f (c_ify args))

  let lower_jacobian j args = 
     Array.map 
      (fun carr -> 
        de_c_ify carr)
      (j (c_ify args))

(* Define these all at once (and last) so as not to spoil the
   comparison operator namespace. *)
  let (<) d1 d2 = compare d1 d2 < 0
  and (<=) d1 d2 = not (compare d1 d2 > 0)
  and (>) d1 d2 = compare d1 d2 > 0
  and (>=) d1 d2 = not (compare d1 d2 < 0)
  and (=) d1 d2 = compare d1 d2 = 0
  and (<>) d1 d2 = not (compare d1 d2 = 0)
end

module DFloat = Make(struct 
  type t = float 

  include Pervasives

  let (+) = (+.)
  let (-) = (-.)
  let (/) = (/.)
  let ( * ) = ( *. )

  let one = 1.0
  let zero = 0.0

  let print_t out = Printf.fprintf out "%g"
end)
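As a usage sketch of the symplecticity test J^T.S.J = S mentioned at the top of this post, the following standalone program checks it for one leapfrog (kick-drift-kick) step of the Henon-Heiles potential, with phase-space ordering (q1, q2, p1, p2). To stay self-contained it uses simple one-tag dual numbers instead of the Diff functor; the test problem, step size, starting point, and all function names here are arbitrary illustrative choices. An explicit Euler step is included as a non-symplectic contrast.

```ocaml
(* Sketch: one-tag dual numbers, enough for a column-by-column Jacobian. *)
type dual = { v : float; d : float }

let cst x = { v = x; d = 0.0 }
let add a b = { v = a.v +. b.v; d = a.d +. b.d }
let sub a b = { v = a.v -. b.v; d = a.d -. b.d }
let mul a b = { v = a.v *. b.v; d = a.v *. b.d +. a.d *. b.v }
let scale c a = { v = c *. a.v; d = c *. a.d }

(* Force -grad V for the Henon-Heiles potential
   V = (q1^2 + q2^2)/2 + q1^2 q2 - q2^3/3. *)
let force q1 q2 =
  ( sub (cst 0.0) (add q1 (scale 2.0 (mul q1 q2))),
    sub (cst 0.0) (sub (add q2 (mul q1 q1)) (mul q2 q2)) )

(* Kick-drift-kick leapfrog with fixed dt; z = [| q1; q2; p1; p2 |]. *)
let leapfrog dt z =
  let h = 0.5 *. dt in
  let f1, f2 = force z.(0) z.(1) in
  let p1 = add z.(2) (scale h f1) and p2 = add z.(3) (scale h f2) in
  let q1 = add z.(0) (scale dt p1) and q2 = add z.(1) (scale dt p2) in
  let g1, g2 = force q1 q2 in
  [| q1; q2; add p1 (scale h g1); add p2 (scale h g2) |]

(* Explicit Euler, which is not symplectic, for contrast. *)
let euler dt z =
  let f1, f2 = force z.(0) z.(1) in
  [| add z.(0) (scale dt z.(2)); add z.(1) (scale dt z.(3));
     add z.(2) (scale dt f1); add z.(3) (scale dt f2) |]

(* jac.(i).(j) = d out_i / d z_j, one forward pass per column. *)
let jacobian f z =
  let n = Array.length z in
  let cols =
    Array.init n (fun j ->
        f (Array.mapi (fun i x -> { v = x; d = (if i = j then 1.0 else 0.0) }) z)
        |> Array.map (fun r -> r.d))
  in
  Array.init n (fun i -> Array.init n (fun j -> cols.(j).(i)))

(* max |(J^T S J - S)_ab| with S = [[0, I]; [-I, 0]]. *)
let symplectic_error jac =
  let n = Array.length jac in
  let m = n / 2 in
  let s i k = if k = i + m then 1.0 else if i = k + m then -1.0 else 0.0 in
  let err = ref 0.0 in
  for a = 0 to n - 1 do
    for b = 0 to n - 1 do
      let sum = ref 0.0 in
      for i = 0 to n - 1 do
        for k = 0 to n - 1 do
          sum := !sum +. (jac.(i).(a) *. s i k *. jac.(k).(b))
        done
      done;
      err := max !err (abs_float (!sum -. s a b))
    done
  done;
  !err

let () =
  let z = [| 0.1; -0.2; 0.3; 0.05 |] in
  Printf.printf "leapfrog: %g\n" (symplectic_error (jacobian (leapfrog 0.1) z));
  Printf.printf "euler:    %g\n" (symplectic_error (jacobian (euler 0.1) z))
```

For explicit Euler the top-right block of J^T.S.J works out to I + dt^2 K, where K is the Hessian of the potential, so its reported error should be of order dt^2 rather than roundoff.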

16 October 2006

``Automatic Differentiation'' in OCaml

Below I've posted an OCaml functor which lifts an input module of mathematical functions (+,-,*,/,sin,cos,...) to equivalent functions which also support computing derivatives without any truncation error (i.e. (d sin) is (exactly) the function cos). This technique is commonly called automatic differentiation (in C++ it is often implemented with template metaprogramming, so that much of the work happens at compile time), but I was first introduced to it by SICM; you can also find examples in Haskell in the Functional Programming and Numerical Computation paper.

The basic idea is to compute derivatives using the chain rule---if you have a function f and you know its derivative Df and similarly for g and Dg, then you can compute D(f o g)(x) = Df(g(x)) * Dg(x). Computationally you represent a ``derivative value'' by a pair (x,dx); then f((x,dx)) = (f(x), Df(x)*dx). To compute the derivative of some unknown function h, you simply extract the second part of h((x,1)) (this makes more sense---to me, anyway---if you think of (x,1) as the derivative output of the identity function).
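Here is the single-tag version of that idea as a minimal sketch: a record holding (x, dx), an illustrative lift that applies f together with its known derivative, and deriv, which evaluates h at (x, 1) and keeps the second part. It deliberately has no tags, so it can only handle one derivative direction at a time; the names are assumptions for illustration, not the functor's API.

```ocaml
(* Sketch: a "derivative value" (x, dx) with one implicit tag. *)
type dual = { x : float; dx : float }

(* f((x, dx)) = (f x, Df x * dx), given f and its derivative df. *)
let lift f df a = { x = f a.x; dx = df a.x *. a.dx }

let dsin = lift sin cos
let dexp = lift exp exp

(* D(h) at x: evaluate h on (x, 1) and keep the second part. *)
let deriv h x = (h { x; dx = 1.0 }).dx

let () =
  let x = 0.7 in
  Printf.printf "d sin at %g: %g (cos: %g)\n" x (deriv dsin x) (cos x);
  (* The chain rule comes for free from composition. *)
  Printf.printf "d (sin o exp) at %g: %g (expected: %g)\n" x
    (deriv (fun a -> dsin (dexp a)) x)
    (cos (exp x) *. exp x)
```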

There's a bit of extra bookkeeping in the structure pasted in below, mostly because you may have derivatives with respect to several different arguments floating around in the computation at the same time---so each derivative object acquires a unique tag. The only tricky thing about this is that you have to keep some extra terms in the derivatives of binary operators: (x,dx)*(y,dy) = ((x*y, dx*y), (x*dy, dx*dy)). (Here the inner (,) groups refer to the ``x'' label and the outer group refers to the ``y'' label.) You would think that the dx*dy term is second-order and could be dropped, but it's first-order in dx and first-order in dy, and they are not the same variable, so you have to keep it. Gerry Sussman pointed this bug out to me the first time I wrote some code to do this---apparently it had tripped them up in SICM, too. It's a bit nasty, because the bug doesn't show up until you try to take second derivatives (so you have more than one tag floating around in the computation). If you write your own code to do this, consider yourself warned :).
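One way to see why the dx*dy term matters: flatten a value with two distinct increments into the four coefficients of a + bx*ex + by*ey + bxy*ex*ey, where ex^2 = ey^2 = 0 but ex*ey is kept. Nested differentiation gives the same variable two tags at once, and the ex*ey coefficient then holds the second derivative. The record d2 and the functions below are an illustrative sketch, not the tree representation used in the module; mul_dropped is the buggy variant that throws the cross terms away.

```ocaml
(* Sketch: a value with two distinct increments, stored as the
   coefficients of a + bx*ex + by*ey + bxy*ex*ey, where ex^2 = ey^2 = 0
   but the mixed term ex*ey is kept. *)
type d2 = { a : float; bx : float; by : float; bxy : float }

(* Correct product: keeps the terms first-order in each increment. *)
let mul p q =
  { a = p.a *. q.a;
    bx = p.a *. q.bx +. p.bx *. q.a;
    by = p.a *. q.by +. p.by *. q.a;
    bxy = p.a *. q.bxy +. p.bx *. q.by +. p.by *. q.bx +. p.bxy *. q.a }

(* Buggy product: treats bx*by as second order and drops it. *)
let mul_dropped p q = { (mul p q) with bxy = p.a *. q.bxy +. p.bxy *. q.a }

(* Nested differentiation gives x two tags at once: x0 + ex + ey.  The
   ex*ey coefficient of f is then f''(x0). *)
let second_derivative mul f x0 =
  (f mul { a = x0; bx = 1.0; by = 1.0; bxy = 0.0 }).bxy

let cube mul x = mul (mul x x) x

let () =
  Printf.printf "x^3 at 2: kept %g, dropped %g, exact %g\n"
    (second_derivative mul cube 2.0)
    (second_derivative mul_dropped cube 2.0)
    (6.0 *. 2.0)
```

For x^3 at x = 2, mul gives 12 (that is, 6x), while mul_dropped gives 0, even though the first derivatives agree.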

What do I use this for? I'm using it to show that my integrator algorithms are definitely symplectic maps. A (two-dimensional) map (q,p) |-> (Q(q,p), P(q,p)) is symplectic if (dQ/dq*dP/dp - dQ/dp*dP/dq)(q,p) = 1, i.e. if its Jacobian determinant is 1 everywhere. A nice way to show this for a numerical algorithm is to automatically differentiate it and evaluate the combination of derivatives above along a trajectory. The combination should equal 1 to within roundoff error: the module below computes derivatives without truncation error (that is, without finite differencing), but on ordinary double-precision floats it is still subject to the roundoff of finite-precision arithmetic.
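A small sketch of this test, assuming a one-dimensional harmonic oscillator H = p^2/2 + q^2/2 as a stand-in for the real integrators: each partial derivative comes from one forward pass with single-tag dual numbers. The fixed step 0.2 and the state-dependent rule dt = 0.1(1 + q^2) are hypothetical illustrative choices; the second shows that naively varying the timestep with the state breaks symplecticity.

```ocaml
(* Sketch: one-tag dual numbers v + d*eps. *)
type dual = { v : float; d : float }

let cst x = { v = x; d = 0.0 }
let add a b = { v = a.v +. b.v; d = a.d +. b.d }
let sub a b = { v = a.v -. b.v; d = a.d -. b.d }
let mul a b = { v = a.v *. b.v; d = a.v *. b.d +. a.d *. b.v }

(* One kick-drift-kick leapfrog step for H = p^2/2 + q^2/2.  The step
   dt is itself a dual number, so a state-dependent dt is differentiated
   along with everything else. *)
let leapfrog dt (q, p) =
  let h = mul (cst 0.5) dt in
  let p1 = sub p (mul h q) in
  let q1 = add q (mul dt p1) in
  (q1, sub p1 (mul h q1))

(* dQ/dq * dP/dp - dQ/dp * dP/dq, using one forward pass per input. *)
let det_jacobian step q p =
  let qa, pa = step ({ v = q; d = 1.0 }, cst p) in
  let qb, pb = step (cst q, { v = p; d = 1.0 }) in
  (qa.d *. pb.d) -. (qb.d *. pa.d)

(* Fixed step dt = 0.2. *)
let fixed = leapfrog (cst 0.2)

(* Hypothetical naive adaptive rule: dt = 0.1 (1 + q^2), chosen from the
   state at the start of the step. *)
let adaptive (q, p) = leapfrog (mul (cst 0.1) (add (cst 1.0) (mul q q))) (q, p)

let () =
  Printf.printf "fixed dt:    det J = %.12f\n" (det_jacobian fixed 1.0 1.0);
  Printf.printf "adaptive dt: det J = %.12f\n" (det_jacobian adaptive 1.0 1.0)
```

At (q, p) = (1, 1) the fixed step should give 1 up to roundoff, while the adaptive rule gives about 1.2036.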

Sorry for the long post, but at least they don't come too often :). By the way, since this code is pretty long and involved, I guess I should specify a license: I'm releasing it under the GPL.

(** Code to implement differentiation on arbitrary types. 

   diff.ml: Library to compute exact derivatives of arbitrary
   mathematical functions.

   Copyright (C) 2006 Will M. Farr 
   
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

*)

(** Input signature for [Make].*)
module type In = sig
  type t

(** The multiplicative identity *)
  val one : t

(** The additive identity *)
  val zero : t

(** Comparisons *)
  val compare : t -> t -> int

  val (+) : t -> t -> t
  val (-) : t -> t -> t
  val ( * ) : t -> t -> t
  val (/) : t -> t -> t
(** Arithmetic operations *)


  val sin : t -> t
  val cos : t -> t
  val tan : t -> t 
(** Trigonometric functions *)

  val asin : t -> t
  val acos : t -> t
  val atan : t -> t
(** Inverse trigonometric functions *)

  val sqrt : t -> t
  val log : t -> t
  val exp : t -> t
  val ( ** ) : t -> t -> t
(** Powers, exponents and logs. *)
end

(** Output signature for [Make]. *)
module type Out = sig 
(** Input from [In] *)
  type t 

(** Opaque tags for variables being differentiated. *)
  type tag

(** Output diff type.  A [diff] is either a constant [t] value [C(x)]
   or a tuple of [D(tag, x, dx)] where [x] and [dx] can also be
   [diff].  [D(tag, x, dx)] represents a small increment [dx] in the
   value labeled by [tag] about the value [x]. *)
  type diff = 
    | C of t 
    | D of tag * diff * diff

  val zero : diff
  val one : diff
(** Additive and multiplicative identities in [diff]. *)

(** [lift f f'] takes [f : t -> t] to [diff -> diff] using [f'] to
   compute the derivative.  For example, if we had already defined
   [(/) : diff -> diff -> diff], and the input signature didn't
   provide [log], we could define it using [let log = lift log ((/)
   one)]. *)
  val lift : (t -> t) -> (diff -> diff) -> (diff -> diff)

(** [lower f] takes a function defined on [diff] to the equivalent one
   defined on [t].  It is an error if [f (C x)] does not evaluate to
   [C y] for some [y].  The function is called [lower] because it is
   the inverse of the function [lift] which lifts [f : t -> t] to
   [diff -> diff].  *)
  val lower : (diff -> diff) -> (t -> t)

(** [lower_multi f] lowers [f] from [diff array -> diff] to [t array ->
   t].  It is an error if [f \[|(C x0); ...; (C xn)|\]] does not
   evaluate to [C y] for some [y].*)
  val lower_multi : (diff array -> diff) -> (t array -> t)

(** [lower_multi_multi f] lowers [f] from [diff array -> diff array]
   to [t array -> t array].  It is an error if
   [f \[|(C x0); ...; (C xn)|\]] does not evaluate to
   [\[|(C y0); ...; (C ym)|\]]. *)
  val lower_multi_multi : (diff array -> diff array) -> 
    (t array -> t array)

(** [d f] returns the function which computes the derivative of
   [f]. *)
  val d : (diff -> diff) -> (diff -> diff)

(** [partial i f] returns the function which computes the derivative
    of [f] with respect to its [i]th argument (counting from 0). *)
  val partial : int -> (diff array -> diff) -> (diff array -> diff)

(** [compare d1 d2] compares only the values stored in the
   derivative---that is either the [C x] or [D(_, x, _)].  [(<)],
   [(>)], ... are defined in terms of compare, which is defined in
   terms of [compare] from [In]. *)
  val compare : diff -> diff -> int

  val (<) : diff -> diff -> bool
  val (<=) : diff -> diff -> bool
  val (>) : diff -> diff -> bool
  val (>=) : diff -> diff -> bool
  val (=) : diff -> diff -> bool
  val (<>) : diff -> diff -> bool
(** Comparison functions *)

  val (+) : diff -> diff -> diff
  val (-) : diff -> diff -> diff
  val ( * ) : diff -> diff -> diff
  val (/) : diff -> diff -> diff
(** Algebra. *)

  val cos : diff -> diff
  val sin : diff -> diff
  val tan : diff -> diff
(** Trig *) 

  val acos : diff -> diff
  val asin : diff -> diff
  val atan : diff -> diff
(** Inverse trig *)

  val sqrt : diff -> diff
  val log : diff -> diff
  val exp : diff -> diff
  val ( ** ) : diff -> diff -> diff
(** Powers, exponents and logs *)

end

module Make(I : In) : Out with type t = I.t = struct
  type t = I.t

  type tag = int

(* Terms are either constant or of the form (x + dx), with x
   represented by tag *)
  type diff = 
    | C of t
    | D of tag * diff * diff

(* Additive and multiplicative identities in derivatives. *)
  let zero = C(I.zero)
  let one = C(I.one)

(* Unique tags *)
  let new_tag = 
    let count = ref 0 in 
    fun () -> 
      count := !count + 1;
      !count

(* Have to define the arithmetic operators first because they are used
   in [lift] and friends. *)
  let rec (+) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(+) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1 + y, dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x + d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy ->
        D(tagx, x + y, dx + dy)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagx, D(tagy, x + y, dy), dx)

  let rec (-) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(-) x y)
    | C(x), D(tag, y, dy) -> 
        D(tag, d1 - y, zero - dy)
    | D(tag, x, dx), C(y) -> 
        D(tag, x - d2, dx)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x - y, dx - dy)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagx, D(tagy, x - y, zero - dy), dx)

  let rec ( * ) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.( * ) x y)
    | C(x), D(tag, y, dy) -> D(tag, d1 * y, d1 * dy)
    | D(tag, x, dx), C(y) -> D(tag, x * d2, dx * d2)
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x * y, x*dy + dx*y)
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagx, 
          D(tagy, x*y, x*dy),
          D(tagy, dx*y, dx*dy))

  let rec (/) d1 d2 = 
    match d1, d2 with 
    | C(x), C(y) -> C(I.(/) x y)
    | D(tag, x, dx), C(y) -> D(tag, x/d2, dx/d2)
    | C(x), D(tag, y, dy) -> D(tag, d1/y, zero - d1*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
        D(tagx, x/y, dx/y - x*dy/(y*y))
    | D(tagx, x, dx), D(tagy, y, dy) -> 
        D(tagx,
          D(tagy, x/y, zero - x*dy/(y*y)),
          D(tagy, dx/y, zero - dx*dy/(y*y)))

  let lift f f' = 
    let rec lf = function 
      | C(x) -> C(f x)
      | D(tag, x, dx) -> D(tag, lf x, (f' x)*dx) in 
    lf

(* Now that we have the algebra of derivatives worked out (and [lift]
   above), we can define [lower], the derivative operators and the
   elementary functions. *)
  let lower f = 
    fun x -> 
      match f (C x) with 
      | C(y) -> y
      | _ -> raise (Failure "lower expects numerical output")

let rec extract_deriv tag = function 
  | C(_) -> zero
  | D(tagx, x, dx) when tagx = tag -> 
      extract_deriv tag x + drop_deriv tag dx 
  | D(tagx, x, dx) -> 
      D(tagx, extract_deriv tag x, extract_deriv tag dx) and 
    drop_deriv tag = function 
      | D(tagx, x, dx) when tagx = tag -> 
          x
      | D(tagx, x, dx) -> 
          D(tagx, drop_deriv tag x, drop_deriv tag dx)
      | x -> x (* Only matches C(_) *)

let d f = 
  fun x -> 
    let tag = new_tag () in 
    let res = f (D(tag, x, one)) in 
    extract_deriv tag res

let rec compare d1 d2 = 
  match d1, d2 with 
  | C(x), C(y) -> I.compare x y
  | C(x), D(_, y, _) -> compare d1 y
  | D(_, x, _), C(y) -> compare x d2
  | D(_, x, _), D(_, y, _) -> compare x y

let rec cos = function 
  | C(x) -> C(I.cos x)
  | D(tag, x, dx) -> D(tag, cos x, zero - dx * (sin x)) and 
    sin = function 
      | C(x) -> C(I.sin x)
      | D(tag, x, dx) -> D(tag, sin x, dx * (cos x))

let rec tan = function 
  | C(x) -> C(I.tan x)
  | D(tag, x, dx) -> 
      let tx = tan x in 
      D(tag, tx, (one + tx*tx)*dx)

let rec sqrt = function 
  | C(x) -> C(I.sqrt x)
  | D(tag, x, dx) -> 
      let sx = sqrt x in 
      D(tag, sx, dx / (sx + sx))

let rec acos = function 
  | C(x) -> C(I.acos x)
  | D(tag, x, dx) -> 
      D(tag, acos x, zero - dx / (sqrt (one - x * x)))

let rec asin = function 
  | C(x) -> C(I.asin x)
  | D(tag, x, dx) -> 
      D(tag, asin x, dx / (sqrt (one - x * x)))

let rec atan = function 
  | C(x) -> C(I.atan x)
  | D(tag, x, dx) -> 
      D(tag, atan x, dx / (one + x*x))

let rec log = function 
  | C(x) -> C(I.log x)
  | D(tag, x, dx) -> D(tag, log x, dx/x)

let rec exp = function 
  | C(x) -> C(I.exp x)
  | D(tag, x, dx) -> 
      let ex = exp x in 
      D(tag, ex, dx*ex)

let rec ( ** ) d1 d2 = 
  match d1, d2 with 
  | C(x), C(y) -> C(I.( ** ) x y)
  | C(x), D(tag, y, dy) -> 
      D(tag, d1**y, d1**y * (log d1) * dy)
  | D(tag, x, dx), C(y) -> 
      D(tag, x**d2, x**(d2 - one)*d2*dx)
  | D(tagx, x, dx), D(tagy, y, dy) when tagx = tagy -> 
      D(tagx, x**y, x**(y - one)*(dx*y + dy*x*(log x)))
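  | D(tagx, x, dx), D(tagy, y, dy) -> 
      (* (x + dx ex)^(y + dy ey): the ex coefficient is d(x^y)/dx times dx
         plus a mixed ex*ey term, so both entries carry the factor dx. *)
      D(tagx,
        D(tagy, x**y, x**y*(log x)*dy),
        D(tagy, y*x**(y - one)*dx, x**(y - one)*(one + y*(log x))*dx*dy))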

  let replace arr i x = 
    Array.mapi 
      (fun j y -> 
        if j <> i then y else x)
      arr

  let partial i f = 
    fun args -> 
      let x = args.(i) in 
      let one_d_f x = 
        f (replace args i x) in 
      (d one_d_f) x

  let c_ify arr = 
    Array.map (fun x -> (C x)) arr

  let de_c_ify arr = 
    Array.map 
      (function 
        | C(x) -> x
        | _ -> raise (Failure "cannot lower [|D(...); ...; D(...)|]"))
      arr

  let lower_multi f = 
    fun args -> 
      match (f (c_ify args)) with 
      | C(x) -> x
      | _ -> raise (Failure "cannot lower D(...)")

  let lower_multi_multi f = 
    fun args -> 
      de_c_ify (f (c_ify args))

(* Define these all at once (and last) so as not to spoil the
   comparison operator namespace. *)
let (<) d1 d2 = 
  compare d1 d2 < 0 and 
    (<=) d1 d2 = 
  not (compare d1 d2 > 0) and 
    (>) d1 d2 = 
  compare d1 d2 > 0 and
    (>=) d1 d2 = 
  not (compare d1 d2 < 0) and 
    (=) d1 d2 = 
  compare d1 d2 = 0 and 
    (<>) d1 d2 = 
  not (compare d1 d2 = 0)
end
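The tag in [D(tag, x, dx)] is what keeps nested derivatives apart. Below is a stripped-down, float-only copy of the representation above, restricted to + and * (the names add, mul and to_float are hypothetical). It runs the classic test d/dx [x * (d/dy (x + y))], which should be 1 because the inner derivative is identically 1. A forward-mode implementation that cannot tell the two perturbations apart returns 2 instead (the "perturbation confusion" problem).

```ocaml
(* C x is a constant; D (tag, x, dx) represents x + dx*eps_tag. *)
type diff = C of float | D of int * diff * diff

let counter = ref 0
let new_tag () = incr counter; !counter

let rec add a b =
  match a, b with
  | C x, C y -> C (x +. y)
  | C _, D (t, y, dy) -> D (t, add a y, dy)
  | D (t, x, dx), C _ -> D (t, add x b, dx)
  | D (t, x, dx), D (u, y, dy) when t = u -> D (t, add x y, add dx dy)
  | D (t, x, dx), D (u, y, dy) -> D (t, D (u, add x y, dy), dx)

let rec mul a b =
  match a, b with
  | C x, C y -> C (x *. y)
  | C _, D (t, y, dy) -> D (t, mul a y, mul a dy)
  | D (t, x, dx), C _ -> D (t, mul x b, mul dx b)
  | D (t, x, dx), D (u, y, dy) when t = u ->
      D (t, mul x y, add (mul x dy) (mul dx y))
  | D (t, x, dx), D (u, y, dy) ->
      D (t, D (u, mul x y, mul x dy), D (u, mul dx y, mul dx dy))

(* Pull out the coefficient of eps_tag, dropping eps_tag^2 terms. *)
let rec extract_deriv tag = function
  | C _ -> C 0.0
  | D (t, x, dx) when t = tag -> add (extract_deriv tag x) (drop_deriv tag dx)
  | D (t, x, dx) -> D (t, extract_deriv tag x, extract_deriv tag dx)
and drop_deriv tag = function
  | D (t, x, _) when t = tag -> x
  | D (t, x, dx) -> D (t, drop_deriv tag x, drop_deriv tag dx)
  | c -> c

(* Each call to d perturbs its argument with a fresh tag. *)
let d f x =
  let tag = new_tag () in
  extract_deriv tag (f (D (tag, x, C 1.0)))

let to_float = function C x -> x | D _ -> failwith "not a constant"
```

With the tags in place, the inner d only extracts its own perturbation, so the outer derivative sees x * 1 and returns 1.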

28 August 2006

Working Toward Gauge-Covariant Derivatives

I've posted the latest version of my functional differential geometry code (darcs pull http://web.mit.edu/farr/www/SchemeCode/ should get it for you). I now have code for arbitrary linear representations of Lie groups. I'm working up to implementing the gauge-covariant derivative (a generalization of the covariant derivative of GR to arbitrary symmetry operations). Unfortunately, the code is presently really slow. So, like practically every other Schemer out there, I have coded up a quick memoization HOF (higher-order function), which I reproduce below. I suspect that it will prove useful in optimizing the code, which performs many redundant computations.

UPDATE: I've fixed up the code for representations of Lie groups. It's a lot cleaner, and mirrors the math more closely now. Latest version in the darcs repository.

;    Copyright (C) 2006  Will M. Farr 
;
;    This program is free software; you can redistribute it and/or modify
;    it under the terms of the GNU General Public License as published by
;    the Free Software Foundation; either version 2 of the License, or
;    (at your option) any later version.
;
;    This program is distributed in the hope that it will be useful,
;    but WITHOUT ANY WARRANTY; without even the implied warranty of
;    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
;    GNU General Public License for more details.
;
;    You should have received a copy of the GNU General Public License along
;    with this program; if not, write to the Free Software Foundation, Inc.,
;    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

(module memoization mzscheme
  (provide memoize1 memoize)
  
  (define (memoize1 proc1)
    (let ((results (make-hash-table 'weak 'equal)))
      (lambda (x)
        (hash-table-get 
         results 
         x
         (lambda ()
           (let ((result (proc1 x)))
             (hash-table-put! results x result)
             result))))))
  
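  ;; Note: the table holds its keys weakly, so a cached result can be
  ;; discarded once nothing else refers to its key.  In memoize the key
  ;; is the argument list consed fresh on every call, so cached results
  ;; may not survive the next garbage collection.
  (define (memoize proc)
    (let ((aux (memoize1 (lambda (x) (apply proc x)))))
      (lambda args (aux args)))))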
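For comparison, here is the same memoization idea as an OCaml higher-order function. It is a sketch built on the standard Hashtbl module, and the names memoize1, memoize2, slow_square and fast_square are illustrative. Unlike the weak table in the Scheme version, this Hashtbl holds its keys strongly, so cached results are never reclaimed. Functions of several arguments are memoized by keying on a tuple.

```ocaml
(* Memoize a one-argument function: results are cached in a Hashtbl
   keyed by the argument (structural hashing and equality). *)
let memoize1 f =
  let table = Hashtbl.create 16 in
  fun x ->
    match Hashtbl.find_opt table x with
    | Some y -> y
    | None ->
        let y = f x in
        Hashtbl.add table x y;
        y

(* Two arguments: memoize the uncurried function of a pair. *)
let memoize2 f =
  let g = memoize1 (fun (x, y) -> f x y) in
  fun x y -> g (x, y)

(* Illustrative use: count how often the underlying function runs. *)
let calls = ref 0
let slow_square x = incr calls; x * x
let fast_square = memoize1 slow_square
```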

22 August 2006

Adventures in OCaml Land: Barnes-Hut Trees

I'm presently working on a simple cosmological gravitational n-body code (if you don't know what a "gravitational n-body code" is, you can get a really great introduction to gravitational simulation here). The simplest way to evaluate the forces on a body (the fundamental part of calculating its trajectory in an n-body simulation) is to iterate over all other bodies, summing the standard Newtonian gravitational force. Unfortunately, this is O(n) work for each body, resulting in an algorithm which scales as O(n^2) to advance the entire system. When n ~ 1e6 to 1e9, this isn't going to work very well.

Fortunately, Josh Barnes and Piet Hut had a better idea: recursively divide the simulation space into an oct-tree (oct- because we live in 3 dimensions, so splitting in two on each dimension gives 8 sub-trees for each tree) until each cell contains only one body. To compute the force on a particular body, iterate over the tree; for each cell, if the size of that cell is "small" compared with the distance to the body, approximate the forces due to all bodies in the cell by the force from a pseudo-body which resides at their center of mass and has a mass equal to the total mass of the bodies in the cell. (Whew---that's awkward to say. Hopefully it makes sense.) If the cell is not sufficiently "small", then consider its sub-cells recursively. In general (assuming you don't have to walk the entire tree to determine the force on each body), this process is O(log(n)) for each body's force, for a total time O(n*log(n)) to advance the system. Constructing the tree is also O(n*log(n)), so the whole process is O(n*log(n)). Much better!
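In the functor below, tree_of_body_list builds the children of a cell by filtering the whole body list once per sub-cell (2^d passes per level). Here is a sketch of an alternative, assuming the same [|low0; high0; low1; high1; ...|] bounds layout. It computes each point's child index directly from its coordinates, using the bit encoding of sub_bounds (bit j set means the upper half of dimension j), and buckets the points in a single pass. The helpers child_index and bucket are hypothetical and not part of the posted code.

```ocaml
(* Index (0 .. 2^dim - 1) of the child cell containing point [x], for a
   cell with bounds [|low0; high0; low1; high1; ...|].  Bit j is set when
   x.(j) lies in the upper half [mid, high) of dimension j. *)
let child_index (bds : float array) (x : float array) =
  let idx = ref 0 in
  Array.iteri
    (fun j xj ->
      let mid = (bds.(2 * j) +. bds.(2 * j + 1)) /. 2.0 in
      if xj >= mid then idx := !idx lor (1 lsl j))
    x;
  !idx

(* Distribute points among the 2^dim children in one pass over the
   list, instead of filtering the whole list once per child. *)
let bucket (bds : float array) (points : float array list) =
  let children = Array.make (1 lsl (Array.length bds / 2)) [] in
  List.iter
    (fun x ->
      let k = child_index bds x in
      children.(k) <- x :: children.(k))
    points;
  children
```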

Real cosmological codes are much more complicated than this prescription (in fact, they often use a different method for computing the really-long-range forces, based on using Fourier transforms to solve the Poisson equation for the potential on a grid, but that's a story for another day). They also expend lots of work so they can use these trees in a distributed computation (if you want 1e9 particles, at 100 bytes/particle, just writing down your system state takes 100 GB, so you'd better distribute the computation or it won't fit in main memory). Fortunately, I'm not involved in that type of work. When I say "simple" cosmological n-body code, I mean one which can simulate maybe 1e6 bodies on a workstation. The code I've posted below (released under the GPL) is a Barnes-Hut functor in OCaml. It forms the basis of my code, and, as you can see in the early comments, performs respectably, even with 1e6 bodies.

By the way, I'll leave it as an exercise for the reader to formulate the force accumulation algorithm in terms of fold_w_abort :). Also note that this code works (I think) in any number of dimensions (only tested in 3).

(** Barnes-Hut tree functor.  Takes any structure matching [BODY],
   i.e. one which defines [dim] and the [m] and [q] functions.

   native code time for 1M-body tree construction: 82s.  Memory usage
   (just for construction): 223M.  PowerBook G4 800 MHz, 1GB ram, Mac
   OS 10.4.7, 18 August 2006.

   Copyright (C) 2006 Will M. Farr

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
   
   @author Will M. Farr *)

module type BODY = 
sig
  type body
  val dim : int
  val m : body -> float
  val q : body -> float array
end

module Make(B : BODY) = 
struct 

(** [max a b] is a local, float-specialized, version of the
   [Pervasives.max] function. *)
let max (a : float) (b : float) = 
  if a > b then a else b

(** [min a b] is a local, float-specialized, version of the
   [Pervasives.min] function. *)
let min (a : float) (b : float) = 
  if a < b then a else b

(** Bounds are stored (for efficiency) as an array of [\[|low0, high0,
   low1, high1, ... |\]].  Each bound is a half-open interval: [lowi]
   <= xi < [highi]*)
  type bounds = float array

(** A cell contains, in order, the total mass, center of mass, cell
  bounds, cell size squared, sub-cells. *)
  type tree = 
    | Empty
    | Body of B.body
    | Cell of float * float array * bounds * float * tree array

(** [m t] returns the mass of the tree [t] *)
  let m = function 
    | Empty -> 0.0
    | Body(b) -> B.m b
    | Cell(m, _, _, _, _) -> m

(** [q t] returns the center of mass of the tree [t] *)
  let q = function 
    | Empty -> Array.make B.dim 0.0
    | Body(b) -> B.q b
    | Cell(_, com, _, _, _) -> com

(** [low_bound bds i] returns the lower bound on dimension [i] given
   bounds [bds]. *)
let low_bound (bds : bounds) i = bds.(2*i)

(** [high_bound bds i] returns the high bound on dimension [i] given
   bounds [bds]. *)
let high_bound (bds : bounds) i = bds.(2*i+1)

(** [in_bounds bs v] checks whether [v] is in the bounds given by
   [bs]. *)
  let in_bounds (bs : bounds) v = 
    let n = Array.length v in 
    let rec loop i = 
      if i >= n then 
        true
      else if v.(i) >= low_bound bs i && v.(i) < high_bound bs i then 
        loop (i + 1)
      else
        false in 
    loop 0

(** [even i] and [odd i] test the parity of the integer [i]. *)
  let even i = (i mod 2) = 0
  let odd i = not (even i)

(** [pow_int i n] computes [i]^[n] by repeated squaring (for [n >= 0]). *)
  let rec pow_int i n = 
    if n = 0 then 
      1
    else if n = 1 then 
      i
    else 
      (* Compute i^(n/2) once and square it. *)
      let half = pow_int i (n / 2) in 
      if even n then half*half else i*half*half

(** [bit_set i n] is [true] if bit [i] of the integer [n] (bit 0 is
   the least-significant bit of [n]) is 1 and [false] otherwise. *)
  let bit_set i n = 
    n land (1 lsl i) > 0 

(** [sub_bounds bs] returns an array of the 2^d (where d is the
   dimension of the space---[bs] is 2*d elements long) sub-bounds of
   the [bs] obtained by splitting the bounds in half on each
   dimension. *)
  let sub_bounds bs = 
    let nb = 2*B.dim and
        nsb = 1 lsl B.dim in
    Array.init nsb 
      (fun i -> 
        (* Bit j of i encodes whether the sub-bound in dimension j is
           the high (1) or low (0) half. *)
        let sb = Array.make nb 0.0 in 
        for j = 0 to B.dim - 1 do 
          (* j labels the dimension *)
          let mid = ((low_bound bs j) +. (high_bound bs j))/.2.0 in 
          if bit_set j i then 
            begin
              sb.(2*j) <- mid;
              sb.(2*j+1) <- high_bound bs j
            end
          else
            begin
              sb.(2*j) <- low_bound bs j;
              sb.(2*j+1) <- mid
            end
        done;
        sb)

(** [make_null_bounds ()] returns a fresh set of bounds which enclose
   {b no} possible object. *)
let make_null_bounds () = 
  Array.init (2*B.dim)
    (fun i -> 
      if even i then 
        infinity
      else
        neg_infinity)

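(** [expand q] returns a value slightly larger than [q], so that the
   bounds returned by [bounds_of_bodies bs] are guaranteed to enclose
   [bs] (the upper bounds are exclusive). *)
let expand =
  let factor = sqrt epsilon_float in 
  fun q -> 
    (* A purely relative expansion would leave q = 0.0 unchanged and
       drop a body lying exactly on the upper bound. *)
    if q = 0.0 then factor else q +. (abs_float q)*.factor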

(** [bounds_of_bodies bs] returns a bounds which completely enclose
   the given bodies [bs]. *)
  let bounds_of_bodies bs = 
    let bds = make_null_bounds () in 
    List.iter
      (fun b -> 
        let bq = B.q b in 
        Array.iteri 
          (fun i q -> 
            bds.(2*i) <- min bds.(2*i) q;
            bds.(2*i+1) <- max bds.(2*i+1) (expand q)) 
          bq)
      bs;
    bds

(** [bounds_size_squared bds] returns the size squared of the given
   bounds (i.e. the sum of the squared widths along each
   dimension). *)
let bounds_size_squared bds = 
  let size = ref 0.0 and 
      n = (Array.length bds)/2 in 
  for i = 0 to n - 1 do 
    (* Computed inline rather than with the external [Vector.square]. *)
    let width = bds.(2*i+1) -. bds.(2*i) in 
    size := !size +. width*.width
  done;
  !size

(** [mass_and_com sts] returns the mass and center-of-mass of the [sts] *)
  let mass_and_com sts = 
    let mass = 
      Array.fold_left
        (fun mass t -> 
          mass +. (m t))
        0.0 
        sts in 
    let com = 
      Array.fold_left 
        (fun com t -> 
          match t with 
          | Empty -> com
          | _ -> 
              let mt = m t and 
                  qt = q t in
              for i = 0 to Array.length qt - 1 do 
                com.(i) <- com.(i) +. qt.(i)*.mt/.mass
              done;
              com)
        (Array.make B.dim 0.0)
        sts in 
    (mass, com)

(** [tree_of_body_list bs] constructs a tree which contains [bs]. *)
let rec tree_of_body_list = function
  | [] -> Empty
  | [b] -> Body(b)
  | bs -> 
      let bds = bounds_of_bodies bs in
      let s = bounds_size_squared bds in 
      let sub_bds = sub_bounds bds in 
      let sub_trees = 
        Array.map 
          (fun bd -> 
            let sub_bs = List.filter (fun b -> in_bounds bd (B.q b)) bs in 
            tree_of_body_list sub_bs)
          sub_bds in 
      let m, com = mass_and_com sub_trees in 
      Cell(m, com, bds, s, sub_trees)

(** [tree_of_bodies bs] returns the tree which contains [bs]. *)
  let tree_of_bodies bs = 
    let lbs = Array.to_list bs in 
    tree_of_body_list lbs

(** [fold fn start t] is the fundamental tree iterator.  Alas, there
   is no guarantee about the order in which [fn] is applied. *)
  let rec fold fn start t = 
    match t with 
    | Empty -> start
    | Body(_) -> fn start t
    | Cell(_, _, _, _, sts) -> 
        Array.fold_left
          (fold fn) (fn start t) sts

(** [fold_w_abort fn start t] folds [fn] over [t] (with initial value
   [start]).  [fn] is applied to each tree-node before it is applied
   to the sub-nodes.  If [fn] returns [(true, value)], the fold
   continues into the sub-nodes with [value]; if it returns
   [(false, value)], [value] is returned as the result of the fold for
   this entire branch and the recursion stops there. *)
  let rec fold_w_abort fn start t = 
    match t with 
    | Empty -> start
    | Body(_) -> snd (fn start t)
    | Cell(_, _, _, _, sts) -> 
        let (cont, new_start) = fn start t in 
        if cont then 
          Array.fold_left (fold_w_abort fn) new_start sts
        else
          new_start

(** [contains b t] returns [true] if [t] contains [b]. *)
  let rec contains b t = 
    let q = B.q b in 
    match t with 
    | Empty -> false
    | Body(b2) -> b == b2
    | Cell(_, _, bds, _, sts) -> 
        if in_bounds bds q then begin 
          let n = Array.length sts in 
          let rec loop i = 
            if i >= n then 
              false
            else if contains b sts.(i) then 
              true
            else
              loop (i + 1) in 
          loop 0 
        end else
          false
end
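Here is one possible answer to the fold_w_abort exercise above. It is a self-contained sketch rather than a drop-in addition to the functor, and it makes several assumptions. The tree type is a trimmed copy (no bounds, and a body is just a position and a mass), G = 1, and there is no softening. The opening criterion, size^2 < theta^2 * r^2 with a hypothetical default theta = 0.5, is assumed here. The step function returns false to stop the walk below a cell whose monopole is accurate enough, and true to open it.

```ocaml
(* Minimal stand-ins for the functor's types (assumed, not the originals). *)
type body = { pos : float array; mass : float }

type tree =
  | Empty
  | Body of body
  | Cell of float * float array * float * tree array
      (* total mass, centre of mass, cell size squared, sub-cells *)

(* Same contract as fold_w_abort in the functor. *)
let rec fold_w_abort fn start t =
  match t with
  | Empty -> start
  | Body _ -> snd (fn start t)
  | Cell (_, _, _, sts) ->
      let cont, acc = fn start t in
      if cont then Array.fold_left (fold_w_abort fn) acc sts else acc

let dist2 x y =
  let s = ref 0.0 in
  Array.iteri (fun i xi -> let d = y.(i) -. xi in s := !s +. d *. d) x;
  !s

(* Add the acceleration due to mass [m] at [q] on a test point at [x];
   a zero separation is skipped as a self-interaction. *)
let add_monopole x acc m q =
  let r2 = dist2 x q in
  if r2 = 0.0 then acc
  else
    let f = m /. (r2 *. sqrt r2) in
    Array.mapi (fun i ai -> ai +. f *. (q.(i) -. x.(i))) acc

(* Barnes-Hut acceleration at [x]: a cell with size^2 < theta^2 * r^2 is
   replaced by its monopole and not opened; otherwise the walk descends. *)
let accel ?(theta = 0.5) x t =
  let step acc node =
    match node with
    | Empty -> (false, acc)
    | Body b -> (false, add_monopole x acc b.mass b.pos)
    | Cell (m, com, size2, _) ->
        if size2 < theta *. theta *. dist2 x com then
          (false, add_monopole x acc m com)
        else (true, acc)
  in
  fold_w_abort step (Array.make (Array.length x) 0.0) t
```

For a distant test point the two-body cell below is never opened, and the result is exactly its monopole. For a nearby point it is opened, and the result is the direct sum over both bodies.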

18 August 2006

One More Example of Python Generators in Scheme

I've had some spare time at work these past few days waiting for modestly long-running data analysis procedures to complete. Unfortunately, this spare time comes in ~5 min chunks, so it's not really worth a serious context switch into other work-related stuff. So I've been doing a lot of little things, like reading on the web about Python-style generators in Scheme. I'm unable to resist posting my own implementation of generators here. The interesting thing about these (as opposed to the two I've linked to above) is that they allow multiple-value yields. They're probably not terribly efficient (I think I could get away with let/ec, for example), but they're fun nonetheless.

(module generators mzscheme
  (provide define-generator)
  
  (define-syntax define-generator
    (lambda (stx)
      (syntax-case stx ()
        ((define-generator (name arg ...) body0 body1 ...)
         (with-syntax ((yield (datum->syntax-object 
                               (syntax body0) 
                               'yield)))
           (syntax
            (define (name arg ...)
              (letrec ((continue-k #f)
                       (return-k #f)
                       (yield
                        (lambda args
                          (let/cc cont
                            (set! continue-k cont)
                            (call-with-values 
                             (lambda () (apply values args)) 
                             return-k)))))
                (lambda ()
                  (let/cc ret
                    (set! return-k ret)
                    (if continue-k
                        (continue-k '())
                        (begin
                          body0 body1 ...
                          (error 'name "reached end of generator values"))))))))))))))

Examples of use:

> (require generators)
> (define-generator (nums-from n)
    (let loop ((i n))
      (yield i)
      (loop (+ i 1))))
> (define ten-up (nums-from 10))
> (ten-up)
10
> (ten-up)
11
> (ten-up)
12
> (ten-up)
13
> (ten-up)
14
and
> (define-generator (next-two-from n)
    (let loop ((i n))
      (yield i (+ i 1))
      (loop (+ i 2))))
> (define ten-by-twos (next-two-from 10))
> (let-values (((a b)
                (ten-by-twos)))
    (+ a b))
21
> (ten-by-twos)
12
13

02 August 2006

SRFI-11 For Bigloo

Tired of using multiple-value-bind all the time in Bigloo, I've packaged up SRFI-11 for Bigloo. You can find it here. It even builds the _e library target for the interpreter's new library-load command. Enjoy your let-values.

10 July 2006

Yet Another Sudoku Solver

While flying home from vacation, I composed the following Sudoku solver. It seems to solve all the puzzles I throw at it, but it is definitely not the most efficient algorithm possible. Basically, it applies the Sudoku constraints (no repeated numbers in a row, column, or 3x3 block) to the puzzle until it cannot narrow the possibilities further. The puzzle is represented as a vector of vectors, with lists for squares which may still hold more than one number. If some squares still have more than one possible number, it makes an ambiguous choice among the possibilities for one of them. Given this choice, the algorithm again tries to narrow the remaining possibilities and, if ambiguity remains, chooses ambiguously again, and so on. Once no square has multiple possibilities left, the algorithm checks the Sudoku constraints one more time. If they are not met, we fail and backtrack to a different (amb ...) choice.

A primary goal here for reasonable performance was to minimize (within reasonable coding effort) the number of continuations which have to be captured by the amb operator---hence the repeated narrowing of possibilities before calls to amb. Here's the code (including a nifty macro which allows specifying a puzzle easily):

(module sudoku mzscheme
  (require (all-except (lib "43.ss" "srfi") vector-fill! vector->list)
           (only (lib "1.ss" "srfi") filter take drop)
           (lib "extra.ss" "swindle"))
  
  (provide sudoku-board solve write-sudoku-board)
  
  ;; Sudoku board is eventually represented as a vector of rows
  ;; which are themselves vectors of length nine.  The syntax
  ;; sudoku-board produces such a board from a list of numbers or 
  ;; '?, which represents an unknown square.
  
  (define (any) (list 1 2 3 4 5 6 7 8 9))
  
  (define (list->board lst)
    (apply vector 
           (let loop ((lst lst) (list-of-nines '()))
             (if (null? lst)
                 (reverse list-of-nines)
                 (loop (drop lst 9) 
                       (cons (apply vector (take lst 9)) list-of-nines))))))
  
  (define-syntax sudoku-board
    (syntax-rules ()
      ((sudoku-board elt ...)
       (list->board 
        (process-elts (processed ) elt ...)))))
  
  (define-syntax process-elts
    (syntax-rules (processed ?)
      ((process-elts (processed pelt ...) ?)
       (reverse (list (any) pelt ...)))
      ((process-elts (processed pelt ...) elt)
       (reverse (list elt pelt ...)))
      ((process-elts (processed pelt ...) ? elt ...)
       (process-elts (processed (any) pelt ...) elt ...))
      ((process-elts (processed pelt ...) elt elt2 ...)
       (process-elts (processed elt pelt ...) elt2 ...))))
  
  (define (cubant i)
    (floor (/ i 3)))
  
  (define (same-cubant? ii jj i j)
    (and (= (cubant i)
            (cubant ii))
         (= (cubant j)
            (cubant jj))))
  
  (define (known? elt)
    (number? elt))
  (define (unknown? elt)
    (list? elt))
  
  (define (board-map fn board)
    (vector-map 
     (lambda (i row)
       (vector-map 
        (lambda (j elt)
          (fn i j elt))
        row))
     board))
  
  (define (board-fold fn start board)
    (vector-fold
     (lambda (i start row)
       (vector-fold
        (lambda (j start elt)
          (fn i j start elt))
        start
        row))
     start 
     board))
  
  (define (board-ref b i j)
    (vector-ref (vector-ref b i) j))
  
  (define (prune ii jj number board)
    (board-map 
     (lambda (i j elt)
       (if (known? elt)
           elt
           (cond
             ((= i ii)
              (filter (lambda (elt) (not (= elt number))) elt))
             ((= j jj)
              (filter (lambda (elt) (not (= elt number))) elt))
             ((same-cubant? ii jj i j)
              (filter (lambda (elt) (not (= elt number))) elt))
             (else elt))))
     board))
  
  (define (singleton? elt)
    (and (pair? elt)
         (null? (cdr elt))))
  
  (define (expand-singletons board)
    (board-map
     (lambda (i j elt)
       (if (singleton? elt)
           (car elt)
           elt))
     board))
  
  (define (prune-all board)
    (let ((new-board
           (expand-singletons
            (board-fold
             (lambda (i j nb elt)
               (if (known? elt)
                   (prune i j elt nb)
                   nb))
             board
             board))))
      (if (equal? new-board board)
          new-board
          (prune-all new-board))))
  
  (define (amb-list list)
    (if (null? list)
        (amb)
        (amb (car list) (amb-list (cdr list)))))

  (define (amb-board board)
    (board-fold
     (lambda (i j nb elt)
       (let ((elt (board-ref nb i j)))
         (if (known? elt)
             nb
             (let ((choice (amb-list elt)))
               (prune-all (board-map 
                           (lambda (ii jj elt) (if (and (= i ii)
                                                        (= j jj))
                                                   choice
                                                   elt))
                                     nb))))))
     board
     board))
     
  (define (board-assertions board)
    (board-fold
     (lambda (i j dummy elt1)
       (board-fold
        (lambda (ii jj dummy elt2)
          (cond
            ((and (= ii i)
                  (= jj j)))
            ((= ii i)
             (amb-assert (not (= elt1 elt2))))
            ((= jj j)
             (amb-assert (not (= elt1 elt2))))
            ((same-cubant? ii jj i j)
             (amb-assert (not (= elt1 elt2))))))
        '()
        board))
     '()
     board))
  
  (define (solve board)
    (let ((new-board (prune-all board)))
      (let ((final-board (amb-board new-board)))
        (board-assertions final-board)
        final-board)))
  
  (define (write-sudoku-board board)
    (printf "(sudoku-board")
    (board-fold
     (lambda (i j dummy elt)
       (if (= j 0)
           (newline))
       (if (pair? elt)
           (printf " ?")
           (printf " ~a" elt))
       (if (and (= i 8)
                (= j 8))
           (printf ")~%")))
     '()
     board)))

Example of use:

  (write-sudoku-board
   (solve
    (sudoku-board
     ? ? 8 ? ? ? 1 5 ?
     ? ? ? ? ? 1 8 ? ? 
     3 ? 5 4 ? ? ? ? 9
     5 ? ? ? ? 9 ? ? ?
     ? 9 ? 2 3 4 ? 7 ?
     ? ? ? 1 ? ? ? ? 8
     4 ? ? ? ? 5 9 ? 1
     ? ? 6 7 ? ? ? ? ?
     ? 5 3 ? ? ? 2 ? ?)))
which prints
  (sudoku-board
   7 4 8 3 9 2 1 5 6
   2 6 9 5 7 1 8 4 3
   3 1 5 4 8 6 7 2 9
   5 7 4 8 6 9 3 1 2
   8 9 1 2 3 4 6 7 5
   6 3 2 1 5 7 4 9 8
   4 8 7 6 2 5 9 3 1
   9 2 6 7 1 3 5 8 4
   1 5 3 9 4 8 2 6 7)
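For comparison, here is a sketch of the same narrow-then-guess strategy in OCaml, with explicit recursive backtracking standing in for amb. This is a swapped-in technique, not the post's code. The board is an int array array with 0 for unknown squares, and candidates, propagate, solve and of_string are hypothetical names. Like the Scheme version, it narrows as far as possible before every guess. It also guesses on the square with the fewest candidates, which keeps the search small.

```ocaml
(* Numbers still allowed at (i, j): 1-9 minus the row, column and block. *)
let candidates b i j =
  let used = Array.make 10 false in
  for k = 0 to 8 do
    used.(b.(i).(k)) <- true;   (* same row *)
    used.(b.(k).(j)) <- true    (* same column *)
  done;
  let bi = 3 * (i / 3) and bj = 3 * (j / 3) in
  for r = bi to bi + 2 do
    for c = bj to bj + 2 do
      used.(b.(r).(c)) <- true  (* same 3x3 block *)
    done
  done;
  List.filter (fun n -> not used.(n)) [1; 2; 3; 4; 5; 6; 7; 8; 9]

let copy b = Array.map Array.copy b

(* Narrowing: fill every square that has exactly one candidate, repeating
   until nothing changes.  Returns false if some square has none left. *)
let rec propagate b =
  let changed = ref false and ok = ref true in
  for i = 0 to 8 do
    for j = 0 to 8 do
      if b.(i).(j) = 0 then begin
        match candidates b i j with
        | [] -> ok := false
        | [n] -> b.(i).(j) <- n; changed := true
        | _ -> ()
      end
    done
  done;
  if !ok && !changed then propagate b else !ok

(* Narrow, then guess on the square with the fewest candidates and
   recurse; a failed guess backtracks to the next candidate. *)
let rec solve board =
  let b = copy board in
  if not (propagate b) then None
  else begin
    let best = ref None in
    for i = 0 to 8 do
      for j = 0 to 8 do
        if b.(i).(j) = 0 then begin
          let cs = candidates b i j in
          match !best with
          | Some (_, _, bcs) when List.length bcs <= List.length cs -> ()
          | _ -> best := Some (i, j, cs)
        end
      done
    done;
    match !best with
    | None -> Some b  (* no unknown squares left *)
    | Some (i, j, cs) ->
        let rec try_each = function
          | [] -> None
          | n :: rest ->
              let guess = copy b in
              guess.(i).(j) <- n;
              (match solve guess with
               | Some s -> Some s
               | None -> try_each rest)
        in
        try_each cs
  end

(* Read 81 squares from a string: digits 1-9, '.' for unknown. *)
let of_string s =
  let b = Array.make_matrix 9 9 0 and k = ref 0 in
  String.iter
    (fun c ->
      if c = '.' || (c >= '1' && c <= '9') then begin
        b.(!k / 9).(!k mod 9) <- (if c = '.' then 0 else Char.code c - Char.code '0');
        incr k
      end)
    s;
  b
```

Because the search order differs from the amb version, this sketch is only guaranteed to reproduce the solution printed above if the puzzle's solution is unique.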

05 July 2006

Defining your own Lie Group

Today I was asked how to define your own Lie group using my functional differential geometry software (see this post and subsequent posts). Basically, you have to provide five things:
  1. Chi : G -> R^n. This is the coordinate function on the group manifold.
  2. Chi^-1 : R^n -> G. The inverse of Chi.
  3. e \in G. The identity element.
  4. inverse: G -> G. The inverse function (on the group manifold: g -> g^-1).
  5. *: GxG -> G. The group multiplication function.
1&2 are provided through a <chart> object. Note carefully the signatures of these functions; in particular, note that only Chi and Chi^-1 deal with R^n while all others deal directly with group elements.

See lie-group-SO3.ss for an example of how this is done. It's a bit tricky, because the <chart> (which contains Chi and Chi^-1) must take and produce objects of the <lie-group-element> class, which require the <lie-group> class for one of the slots; but the <lie-group> class requires the <chart> object, so they must be recursively defined.
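The five ingredients can also be sketched outside the Scheme library. Below, a hypothetical OCaml record bundles them for SO(2); it is not the library's <chart> and <lie-group> classes. A rotation is stored as (cos theta, sin theta), with the angle theta as its single coordinate. As in the list above, only chi and chi_inv touch R^n, while identity, inverse and mul work on group elements directly. Because these elements are plain pairs rather than objects that point back to their group, the recursive definition needed for SO3 does not arise here.

```ocaml
(* Hypothetical record of the five ingredients; field names are
   illustrative, not the Scheme library's API. *)
type 'g lie_group = {
  chi : 'g -> float array;       (* 1. coordinate function G -> R^n *)
  chi_inv : float array -> 'g;   (* 2. its inverse R^n -> G *)
  identity : 'g;                 (* 3. e in G *)
  inverse : 'g -> 'g;            (* 4. g -> g^-1 *)
  mul : 'g -> 'g -> 'g;          (* 5. group multiplication *)
}

(* SO(2): a rotation stored as (cos theta, sin theta); the chart
   returns theta in (-pi, pi]. *)
let so2 = {
  chi = (fun (c, s) -> [| atan2 s c |]);
  chi_inv = (fun x -> (cos x.(0), sin x.(0)));
  identity = (1.0, 0.0);
  inverse = (fun (c, s) -> (c, -. s));
  mul = (fun (c1, s1) (c2, s2) -> (c1 *. c2 -. s1 *. s2, s1 *. c2 +. c1 *. s2));
}
```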

27 June 2006

Figured it Out

The bit at the end of the last post, where I worried that the commutator (the Lie bracket in the Lie algebra of vector fields, not of vectors at the identity) of two left-invariant vector fields was not itself left-invariant, was just a mistake on my part. I shouldn't expect [extend(d/dtheta), extend(d/dphi)] = d/dpsi, but rather [extend(d/dtheta), extend(d/dphi)] = extend(d/dpsi)! With that correction, it works out:
(test-case 
 "Extended tangent vectors are really left-invariant"
 (check-tuple-close? 
  1e-6
  ((vector-field->component-field 
    (- ((lie-algebra-bracket SO3) d/dphi d/dtheta)
       ((natural-extension SO3) d/dpsi))
    SO3-rectangular-chart)
   ((slot-ref SO3-rectangular-chart 'chiinv)
    (up 0.02345453 0.0349587 0.0435897)))
  (up 0 0 0)))
passes with flying colors. (In case the above code is not completely clear---irony of the century---it's computing the components of the vector field ([extend(d/dtheta),extend(d/dphi)]-extend(d/dpsi)) at some arbitrary point in SO3 and verifying that they're zero.)
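The reason the correction works fits in two lines. Write \tilde X for a left-invariant field, meaning (L_g)_* \tilde X = \tilde X for every g in G. The Lie bracket commutes with pushforward by any diffeomorphism, so the bracket of two left-invariant fields is again left-invariant. A left-invariant field is fixed by its value at the identity, so for a basis \tilde X_a of left-invariant fields the bracket is a constant-coefficient combination of the same basis, with the structure constants as coefficients:

```latex
\begin{aligned}
(L_g)_*\,[\tilde X, \tilde Y] &= [(L_g)_*\tilde X,\ (L_g)_*\tilde Y] = [\tilde X, \tilde Y]
  \quad \text{for all } g \in G, \\
[\tilde X_a, \tilde X_b] &= C^{c}{}_{ab}\,\tilde X_c ,
  \qquad C^{c}{}_{ab} \text{ constant on } G .
\end{aligned}
```

So the bracket of two extended coordinate vectors has to be compared with an extended vector such as extend(d/dpsi), not with the coordinate field d/dpsi, which in general is not left-invariant.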

26 June 2006

Lie Groups Mostly Working

It's working! (Mostly.) See:
;; Welcome to DrScheme, version 350.1-svn21jun2006.
;; Language: Swindle.

;; Require some modules 
(require "lie-group-SO3.ss")
(require "lie-groups.ss")
(require "manifolds.ss")
(require "tuples.ss")
;;

;;; Compute the structure constants for SO3 in the 
;;; x,y,z coordinate system.  Exactly what you 
;;; would expect.
(structure-constants SO3 SO3-rectangular-chart)
;#<down: elts=#(#<down: elts=#(#<up: elts=#(0 0 0)> 
;                              #<up: elts=#(0 0 1)> 
;                              #<up: elts=#(0 -1 0)>)> 
;               #<down: elts=#(#<up: elts=#(0 0 -1)> 
;                              #<up: elts=#(0 0 0)> 
;                              #<up: elts=#(1 0 0)>)> 
;               #<down: elts=#(#<up: elts=#(0 1 0)>
;                              #<up: elts=#(-1 0 0)> 
;                              #<up: elts=#(0 0 0)>)>)>
;; Name the euler angle coordinates
(define-named-coordinates (theta phi psi) 
  SO3-euler-angles-chart)

;; Note that coordinate vectors commute
((vector-field->component-field 
  (lie-bracket d/dtheta d/dphi) 
  SO3-euler-angles-chart) 
 (slot-ref SO3 'identity))
;; #<up: elts=#(0 0 0)>

;; While the extensions of coordinate vectors under 
;; left-multiplication do not.
((vector-field->component-field 
  ((lie-algebra-bracket SO3) d/dtheta d/dphi) 
  SO3-euler-angles-chart)
 (slot-ref SO3 'identity))
;; #<up: elts=#(0 0 1)>
[Edited: this used to say that I didn't understand why [natural-extension(d/dtheta), natural-extension(d/dphi)] <> d/dpsi, but now I do; see "Figured it Out" above.]
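Reading the (i, j) entry of the structure-constant output above as the components of [e_i, e_j], the rectangular-chart result says [e_x, e_y] = e_z, [e_y, e_z] = e_x, and [e_z, e_x] = e_y. One independent way to see why that is "exactly what you would expect" is to take the standard 3x3 generators of rotations about the x, y, and z axes and compute their matrix commutators, which obey the same relations. The sketch below does this with plain lists of rows; `mat*`, `commutator`, `Lx`, `Ly`, and `Lz` are illustrative names, not part of the library, and a different convention (for example right- instead of left-invariant fields) could flip the overall sign.

```scheme
;; Matrices are lists of rows.
(define (transpose m) (apply map list m))

(define (mat* a b)
  (let ((b-columns (transpose b)))
    (map (lambda (row)
           (map (lambda (col) (apply + (map * row col))) b-columns))
         a)))

(define (mat- a b)
  (map (lambda (ra rb) (map - ra rb)) a b))

;; [A, B] = AB - BA
(define (commutator a b)
  (mat- (mat* a b) (mat* b a)))

;; Generators of rotations about x, y, z: (L_i)_{jk} = -epsilon_{ijk}.
(define Lx '((0 0 0) (0 0 -1) (0 1 0)))
(define Ly '((0 0 1) (0 0 0) (-1 0 0)))
(define Lz '((0 -1 0) (1 0 0) (0 0 0)))

;; (equal? (commutator Lx Ly) Lz) => #t, and likewise cyclically.
```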

24 June 2006

Functional Differential Geometry in PLT Scheme

I've been working a bit lately with Gerry Sussman and Jack Wisdom to extend their software for functional differential geometry to handle Lie Groups (which, after all, are just manifolds). My thesis will probably have to do with treating General Relativity as an SO(3,1) gauge theory, so I'm particularly interested in the Special Orthogonal Lie Groups. (Jack, as a solar-system dynamicist, is also particularly interested in the Special Orthogonal groups because they represent rotations---in addition to being interested in GR as a gauge theory. Gerry is just interested in everything.)

Unfortunately, the software they have for doing this runs in MIT Scheme, and I own a PowerBook running OS X. It's not really a problem to SSH into a computer which runs their system, but it's nicer to have access to it on my own computer. So: I've coded up a bare-bones version of the scmutils/differential-geometry system for myself in PLT Scheme. It does no symbolic manipulation (only numerical calculations allowed), and doesn't have much of the nifty stuff that comes with scmutils proper, but it is able to compute on arbitrary manifolds and to handle Lie groups. I've been using SchemeUnit to run tests on the code as I go along, so there's a pretty good chance that it doesn't have major bugs.

I've posted my current darcs repository here; a darcs get http://web.mit.edu/farr/www/SchemeCode should get it for you, if you're interested. I haven't really written any documentation yet, but the test files should give some examples of how the functions are meant to be used. Comments are, of course, welcome at farr@mit.edu.

I'm not really sure why I'm posting this for the wider world right now (since it's so incomplete and "internal"). I hope someone finds it useful or interesting.

20 June 2006

SRFI-78 for Bigloo

I've just posted an implementation of SRFI-78: Lightweight Testing for Bigloo. You can get it here. It leaves out the check-ec form because there's no implementation of SRFI-42 for Bigloo yet, but otherwise it's complete. The framework is very lightweight---it took about 10 minutes to port.

I want the framework so that I can test my new SRFI-4 implementation for Bigloo before I post it, so stay tuned for that! It feels good to be back in Scheme---much of my recent scientific coding has been in OCaml (not that there's anything wrong with OCaml, but I have missed Scheme).
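SRFI-78's core form is (check expr => expected), with (check expr (=> equal) expected) to choose the comparison, and check-report to print a summary. To show why a port can be so quick, here is a minimal stand-in for just the check form, written with syntax-rules and a pair of counters. It is a sketch, not the SRFI-78 reference implementation; `checks-passed`, `checks-failed`, and `report-checks` are hypothetical names.

```scheme
;; Running totals (hypothetical names, not SRFI-78's).
(define checks-passed 0)
(define checks-failed 0)

;; (check expr => expected) compares with equal?;
;; (check expr (=> equal) expected) compares with EQUAL.
(define-syntax check
  (syntax-rules (=>)
    ((_ expr (=> equal) expected)
     (let ((actual expr)
           (want expected))
       (if (equal actual want)
           (set! checks-passed (+ checks-passed 1))
           (begin
             (set! checks-failed (+ checks-failed 1))
             (display "FAILED: ")
             (write 'expr)
             (display " => ")
             (write actual)
             (display "; expected ")
             (write want)
             (newline)))))
    ((_ expr => expected)
     (check expr (=> equal?) expected))))

(define (report-checks)
  (display checks-passed)
  (display " passed, ")
  (display checks-failed)
  (display " failed")
  (newline))

;; Usage:
;; (check (+ 1 2) => 3)
;; (check (list 1 2) (=> equal?) (list 1 2))
;; (report-checks)
```

The real SRFI adds reporting modes and the check-ec form, but nothing here depends on anything beyond R5RS-level Scheme.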

08 June 2006

Fun With Streams

Or: a bit of Haskell in PLT Scheme.

In a one-off version of the n-body integrator I've been working on (seemingly in perpetuity, but probably finishing soon with a couple of papers), I had occasion to iterate a function to convergence. (i.e. to compute (f (f (f ... (f x)))) until the result converges.) Well, it's pretty easy to write:

(define (iter-to-convergence f)
 (lambda (x)
   (let loop ((last (f x)))
     (let ((next (f last)))
       (if (= last next)
           next
           (loop next))))))
but where's the fun in that? It's much more fun to use streams for this.

So, using the streams module appended below (shamelessly stolen from SICP), I can write

(define (iter-to-convergence f)
 (lambda (x)
   (letrec ((vals
             (stream-cons (f x) (stream-map f vals))))
     (stream-take-converged vals))))

This idiom is very common in Haskell (and other lazy languages)---the canonical example is, of course, (define ones (stream-cons 1 ones)). The fact that this is so easy to add to Scheme is a real testament to the language's power. What other language lets you (lazily) have your cake and (eventually) eat it, too?
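One caveat applies to both versions of iter-to-convergence: they stop only when two successive iterates are exactly equal under =. That is fine for exact arithmetic or for a floating-point iteration that settles on a fixed point, but a floating-point iteration can also end up alternating between two neighboring values, in which case the loop never terminates. A minimal variant of the loop version, assuming an extra tolerance argument tol (a hypothetical addition, not in the original):

```scheme
;; Like the loop version of iter-to-convergence, but stops when successive
;; iterates differ by at most TOL times max(1, |next|): a relative test for
;; iterates larger than 1 in magnitude, an absolute one otherwise.
(define (iter-to-tolerance f tol)
  (lambda (x)
    (let loop ((last (f x)))
      (let ((next (f last)))
        (if (<= (abs (- next last)) (* tol (max 1.0 (abs next))))
            next
            (loop next))))))

;; Usage: Newton's iteration for the square root of 2.
(define newton-sqrt2
  (iter-to-tolerance (lambda (y) (/ (+ y (/ 2.0 y)) 2.0)) 1e-12))

;; (newton-sqrt2 1.0) is approximately 1.4142135623730951
```

With tol set to 0 this reduces to the exact-equality test of the original.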

Here's the code for the stream module (feel free to steal it if you want):

(module streams mzscheme
 (require (only (lib "1.ss" "srfi") any))

 (provide stream? stream-car stream-cdr stream-ref stream-cons
          (rename my-stream stream) stream-map stream-filter stream-take
          stream-converged? stream-take-converged)
  
 (define-struct stream (a b) #f)

 (define stream-car stream-a)

 (define (stream-cdr s)
   (force (stream-b s)))

 (define (stream-ref s n)
   (if (= n 0)
       (stream-car s)
       (stream-ref (stream-cdr s) (- n 1))))

 (define-syntax stream-cons
   (syntax-rules ()
     ((cons-stream car cdr)
      (make-stream car (delay cdr)))))

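 ;; (my-stream a b ... tail) works like list*: the last argument is the
 ;; tail stream.  (my-stream a) builds a one-element stream ending in '().
 (define-syntax my-stream
   (syntax-rules ()
     ((stream a)
      (stream-cons a '()))
     ((stream a b)
      (stream-cons a b))
     ((stream a b ...)
      (stream-cons a (my-stream b ...)))))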

 (define (stream-map f . ss)
   (stream-cons (apply f (map stream-car ss)) (apply stream-map f (map stream-cdr ss))))

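 ;; No end-of-stream test: on the infinite streams used here this runs
 ;; forever, so it is only useful for its side effects.
 (define (stream-for-each f . ss)
   (apply f (map stream-car ss))
   (apply stream-for-each f (map stream-cdr ss)))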

 (define (stream-filter test? s)
   (let ((s1 (stream-car s)))
     (if (test? s1)
         (stream-cons s1 (stream-filter test? (stream-cdr s)))
         (stream-filter test? (stream-cdr s)))))

 (define (stream-take n s)
   (if (= n 0)
       '()
       (cons (stream-car s) (stream-take (- n 1) (stream-cdr s)))))

 (define (stream-converged? s)
   (= (stream-car s) (stream-ref s 1)))

 (define (stream-take-converged s)
   (if (stream-converged? s)
       (stream-car s)
       (stream-take-converged (stream-cdr s)))))

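Here's a more involved example of what you can do with this: a stream of primes defined in terms of itself (SICP 3.5.2). Strictly speaking, this is trial division by the primes found so far rather than the sieve of Eratosthenes proper.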

(define (first-n-primes n)
 (stream-take
  n
  (letrec ((integers-from (lambda (n) (stream-cons n (integers-from (+ n 1)))))
           (primes (stream-cons 2 (stream-filter prime? (integers-from 3))))
           (prime? (lambda (x)
                     (let iter ((ps primes))
                       (let ((p (stream-car ps)))
                         (cond
                           ((= (remainder x p) 0) #f)
                           ((> (* p p) x) #t)
                           (else (iter (stream-cdr ps)))))))))
    primes)))
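For comparison, SICP 3.5.2 also builds the primes with an explicit sieve, which strikes out the multiples of each prime as it is found instead of testing each candidate against the primes so far. The sketch below is self-contained, so it defines its own minimal SICP-style cons-stream with delay and force rather than using the streams module above; `s-head`, `s-tail`, `filter-stream`, and `take-stream` are hypothetical helper names.

```scheme
;; Minimal delayed streams (SICP-style) so this block stands alone.
(define-syntax cons-stream
  (syntax-rules ()
    ((_ a b) (cons a (delay b)))))
(define (s-head s) (car s))
(define (s-tail s) (force (cdr s)))

(define (integers-starting-from n)
  (cons-stream n (integers-starting-from (+ n 1))))

;; Keep the elements of S that satisfy KEEP?.
(define (filter-stream keep? s)
  (if (keep? (s-head s))
      (cons-stream (s-head s) (filter-stream keep? (s-tail s)))
      (filter-stream keep? (s-tail s))))

;; The sieve proper: emit the head, drop all of its multiples, recur.
(define (sieve s)
  (cons-stream
   (s-head s)
   (sieve (filter-stream
           (lambda (x) (not (= 0 (remainder x (s-head s)))))
           (s-tail s)))))

(define (take-stream n s)
  (if (= n 0)
      '()
      (cons (s-head s) (take-stream (- n 1) (s-tail s)))))

(define primes (sieve (integers-starting-from 2)))

;; (take-stream 10 primes) => (2 3 5 7 11 13 17 19 23 29)
```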

08 March 2006

Tidbits

My god! It's been about a month since I last posted (about PLT Scheme's new JIT). Just to prime the pump (hopefully more to follow more quickly):
  • There's a very interesting discussion on the OCaml mailing list about Software Transactional Memory for concurrent programming. In particular, see the referenced papers here and here. The basic idea is to wrap up all accesses to shared memory in an STM monad, whose actions are eventually run atomically as a single transaction. The neat thing about doing this is that the concurrency mechanisms compose; see the papers for how it all works out. Now, there's no reason another language couldn't do this, but Haskell is ideal because the pure semantics and the type system conspire to prevent any side effects inside the atomic section other than accesses to "registered" STM slots (i.e. the atomic sequencer won't thread parts of the IO monad through the computation---only STM actions are allowed). Pretty slick stuff.
  • The nbody integrator paper I'm working on is coming along nicely (I'm a physicist, though you wouldn't know it to look at the blog); just some quick (i.e. a day or two long) simulations to run and then a lot of editing before it's ready for submission. Some time soon I may post a quick synopsis of the idea; for now you can read about what it's not (though similar) in Jack Wisdom's Symplectic Maps for the N-Body Problem (with Holman). If that tickles your fancy, have a look at Jack's Symplectic Correctors (also with Holman). If that doesn't make your jaw drop then you must not be much of a dynamicist!

01 February 2006

PLT Gets JIT

According to this message to the PLT Scheme mailing list, PLT Scheme just got a JIT compiler. It only works on the non-3m (i.e. Boehm conservative collector) version of PLT Scheme, but it seems to provide some speedups. Subsequent messages suggest that they're working on a "proper" JIT compiler targeting LLVM which should provide even more speed. Very nice!

Update: I've been using the JIT version for the last day (my normal preference is for the 3m version), and it's really nice! I see speedups of nearly a factor of 2 for much of my code (which is pretty loop-heavy, so maybe that helps), and the whole experience of using DrScheme feels much snappier!