22 August 2006

Adventures in OCaml Land: Barnes-Hut Trees

I'm presently working on a simple cosmological gravitational n-body code (if you don't know what a "gravitational n-body code" is, you can get a really great introduction to gravitational simulation here). The simplest way to evaluate the forces on a body (the fundamental part of calculating its trajectory in an n-body simulation) is to iterate over all other bodies summing the standard Newtonian gravitational force. Unfortunately, this is O(n) work for each body, resulting in an algorithm which scales as O(n^2) to advance the entire system. When n ~ 1e6 to 1e9, this isn't going to work very well.

Fortunately, Josh Barnes and Piet Hut had a better idea: recursively divide the simulation space into an oct-tree (oct- because we live in 3 dimensions, so splitting in two on each dimension gives 8 sub-trees for each tree) until each cell contains only one body. To compute the force on a particular body, iterate over the tree; for each cell, if the size of that cell is "small" compared with the distance to the body, approximate the forces due to all bodies in the cell by the force from a pseudo-body which resides at their center of mass and has a mass equal to the total mass of the bodies in the cell. (Whew---that's awkward to say. Hopefully it makes sense.) If the cell is not sufficiently "small", then consider its sub-cells recursively. In general (assuming you don't have to walk the entire tree to determine the force on each body), this process is O(log(n)) for each body's force, for a total time O(n*log(n)) to advance the system. Constructing the tree is also O(n*log(n)), so the whole process is O(n*log(n)). Much better!

Real cosmological codes are much more complicated than this prescription (in fact, they often use a different method for computing the really-long-range forces based on using Fourier transforms to solve the Poisson equation for the potential on a grid, but that's a story for another day). They also expend lots of work so they can use these trees in a distributed computation (if you want 1e9 particles, at 100 bytes/particle, just writing down your system state takes 100 GB, so you'd better distribute the computation, or it won't fit in main memory). I'm fortunately not involved in that type of work. When I say "simple" cosmological n-body code, I mean one which can simulate maybe 1e6 bodies on a workstation. The code I've posted below (released under the GPL) is a Barnes-Hut functor in OCaml. It forms the basis of my code, and, as you can see in the early comments, performs respectably, even with 1e6 bodies.

By the way, I'll leave it as an exercise for the reader to formulate the force accumulation algorithm in terms of fold_w_abort :). Also note that this code works (I think) in any number of dimensions (only tested in 3).

(** Barnes-Hut tree functor.  Takes any structure matching [BODY],
   i.e. one which defines [dim], [m] and [q].

   native code time for 1M-body tree construction: 82s.  Memory usage
   (just for construction): 223M.  PowerBook G4 800 MHz, 1GB ram, Mac
   OS 10.4.7, 18 August 2006.

   Copyright (C) 2006 Will M. Farr

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
   
   @author Will M. Farr *)

(** Signature required of the bodies stored in a Barnes-Hut tree. *)
module type BODY =
sig
  (** The type of a single body. *)
  type body

  (** Number of spatial dimensions.  Every position returned by [q]
      should have exactly [dim] coordinates. *)
  val dim : int

  (** [m b] is the mass of body [b]. *)
  val m : body -> float

  (** [q b] is the position of body [b] (an array of [dim] coordinates). *)
  val q : body -> float array
end

(** [Make (B)] builds Barnes-Hut trees over bodies of type [B.body] in
    [B.dim] dimensions.  Each internal cell is split in half along every
    dimension, giving [2^B.dim] sub-cells. *)
module Make (B : BODY) =
struct

  (** [max a b] is a local, float-specialized version of the polymorphic
      [max] function. *)
  let max (a : float) (b : float) =
    if a > b then a else b

  (** [min a b] is a local, float-specialized version of the polymorphic
      [min] function. *)
  let min (a : float) (b : float) =
    if a < b then a else b

  (** Bounds are stored (for efficiency) as a flat array
      [\[|low0; high0; low1; high1; ...|\]].  Each bound is a half-open
      interval: [lowi <= xi < highi]. *)
  type bounds = float array

  (** A Barnes-Hut tree.  [Cell (mass, com, bds, size2, subs)] holds, in
      order, the total mass, the center of mass, the cell bounds, the
      squared cell size (as computed by [bounds_size_squared]) and the
      sub-cells. *)
  type tree =
    | Empty
    | Body of B.body
    | Cell of float * float array * bounds * float * tree array

  (** [m t] returns the mass of the tree [t] ([0.0] for [Empty]). *)
  let m = function
    | Empty -> 0.0
    | Body b -> B.m b
    | Cell (mass, _, _, _, _) -> mass

  (** [q t] returns the center of mass of the tree [t].  For a [Body] this
      is [B.q] of the body itself (not copied); for [Empty] it is a fresh
      zero vector of length [B.dim]. *)
  let q = function
    (* Size by B.dim, not a fixed 3, so the functor works in any
       dimension. *)
    | Empty -> Array.make B.dim 0.0
    | Body b -> B.q b
    | Cell (_, com, _, _, _) -> com

  (** [low_bound bds i] returns the (inclusive) lower bound on dimension
      [i] of [bds]. *)
  let low_bound (bds : bounds) i = bds.(2 * i)

  (** [high_bound bds i] returns the (exclusive) upper bound on dimension
      [i] of [bds]. *)
  let high_bound (bds : bounds) i = bds.(2 * i + 1)

  (** [in_bounds bs v] is [true] iff every coordinate of [v] lies in the
      corresponding half-open interval of [bs]. *)
  let in_bounds (bs : bounds) v =
    let n = Array.length v in
    let rec loop i =
      i >= n
      || (v.(i) >= low_bound bs i && v.(i) < high_bound bs i && loop (i + 1))
    in
    loop 0

  (** [even i] is [true] iff [i] is even; [odd i] is its negation. *)
  let even i = i mod 2 = 0
  let odd i = not (even i)

  (** [pow_int i n] computes [i]^[n] by repeated squaring, using
      O(log n) multiplications.
      @raise Invalid_argument if [n < 0]. *)
  let rec pow_int i n =
    if n < 0 then invalid_arg "pow_int: negative exponent"
    else if n = 0 then 1
    else
      (* Compute the half power once; recursing on it twice would make the
         whole computation O(n) instead of O(log n). *)
      let half = pow_int i (n / 2) in
      if even n then half * half else i * half * half

  (** [bit_set i n] is [true] if bit [i] of the integer [n] (bit 0 is the
      least-significant bit of [n]) is 1 and [false] otherwise. *)
  let bit_set i n =
    (* Compare with 0 rather than [> 0]: when [i] is the sign bit the
       masked value is negative. *)
    n land (1 lsl i) <> 0

  (** [sub_bounds bs] returns the [2^B.dim] sub-bounds of [bs] obtained by
      splitting every dimension in half.  Bit [j] of a sub-bound's index
      selects the upper ([1]) or lower ([0]) half along dimension [j]. *)
  let sub_bounds bs =
    let nb = 2 * B.dim and nsb = 1 lsl B.dim in
    Array.init nsb (fun i ->
      let sb = Array.make nb 0.0 in
      for j = 0 to B.dim - 1 do
        let lo = low_bound bs j and hi = high_bound bs j in
        let mid = (lo +. hi) /. 2.0 in
        if bit_set j i then begin
          sb.(2 * j) <- mid;
          sb.(2 * j + 1) <- hi
        end else begin
          sb.(2 * j) <- lo;
          sb.(2 * j + 1) <- mid
        end
      done;
      sb)

  (** [make_null_bounds ()] returns fresh bounds which enclose {b no}
      possible point (every interval is [\[infinity, neg_infinity)]). *)
  let make_null_bounds () =
    Array.init (2 * B.dim) (fun i -> if even i then infinity else neg_infinity)

  (** [expand x] returns a value strictly greater than [x] for every finite
      [x], so that an exclusive upper bound [expand xmax] still contains
      [xmax]. *)
  let expand =
    let factor = sqrt epsilon_float in
    fun x ->
      let x' = x +. abs_float x *. factor in
      (* For x = 0 (or a tiny x whose relative bump underflows) the
         relative expansion leaves x unchanged, which would give an empty
         interval [x, x); fall back to an absolute bump. *)
      if x' > x then x' else x +. factor

  (** [bounds_of_bodies bs] returns bounds which enclose every body of
      [bs]: per dimension, the minimum coordinate and [expand] of the
      maximum coordinate.  An empty list gives the null bounds. *)
  let bounds_of_bodies bs =
    let bds = make_null_bounds () in
    List.iter
      (fun b ->
        Array.iteri
          (fun i x ->
            bds.(2 * i) <- min bds.(2 * i) x;
            bds.(2 * i + 1) <- max bds.(2 * i + 1) (expand x))
          (B.q b))
      bs;
    bds

  (** [bounds_size_squared bds] returns the squared size of the given
      bounds, i.e. the sum over dimensions of the squared interval
      widths. *)
  let bounds_size_squared bds =
    let size = ref 0.0 in
    for i = 0 to Array.length bds / 2 - 1 do
      let w = high_bound bds i -. low_bound bds i in
      size := !size +. w *. w
    done;
    !size

  (** [mass_and_com sts] returns the total mass of the trees [sts] and
      their mass-weighted center of mass.  [Empty] entries are skipped.
      If the total mass is [0.0], the center of mass is the zero vector
      (instead of NaN from a division by zero). *)
  let mass_and_com sts =
    let mass = Array.fold_left (fun acc t -> acc +. m t) 0.0 sts in
    let com = Array.make B.dim 0.0 in
    Array.iter
      (fun t ->
        match t with
        | Empty -> ()
        | Body _ | Cell _ ->
            let mt = m t and qt = q t in
            for i = 0 to Array.length qt - 1 do
              com.(i) <- com.(i) +. qt.(i) *. mt
            done)
      sts;
    (* Divide once at the end, and only by a non-zero mass. *)
    if mass <> 0.0 then
      for i = 0 to B.dim - 1 do
        com.(i) <- com.(i) /. mass
      done;
    (mass, com)

  (** [tree_of_body_list bs] constructs a tree containing exactly the
      bodies in [bs].  Bounds are recomputed tightly from the bodies at
      every level.  If every body of a cell falls into a single sub-cell,
      recursing would recompute identical bounds and never terminate.
      This happens for coincident (or nearly coincident) bodies, which are
      then stored as individual [Body] leaves directly under that cell. *)
  let rec tree_of_body_list = function
    | [] -> Empty
    | [b] -> Body b
    | bs ->
        let bds = bounds_of_bodies bs in
        let s = bounds_size_squared bds in
        let n = List.length bs in
        let groups =
          Array.map
            (fun bd -> List.filter (fun b -> in_bounds bd (B.q b)) bs)
            (sub_bounds bds) in
        let stuck =
          Array.fold_left (fun acc g -> acc || List.length g = n) false groups
        in
        let sub_trees =
          if stuck then Array.of_list (List.map (fun b -> Body b) bs)
          else Array.map tree_of_body_list groups in
        let mass, com = mass_and_com sub_trees in
        Cell (mass, com, bds, s, sub_trees)

  (** [tree_of_bodies bs] returns the tree which contains the bodies of the
      array [bs]. *)
  let tree_of_bodies bs =
    tree_of_body_list (Array.to_list bs)

  (** [fold fn start t] is the fundamental tree iterator: [fn] is applied
      to every [Body] and [Cell] node (a cell before its sub-cells).
      [Empty] nodes are skipped.  No particular order is guaranteed. *)
  let rec fold fn start t =
    match t with
    | Empty -> start
    | Body _ -> fn start t
    | Cell (_, _, _, _, sts) -> Array.fold_left (fold fn) (fn start t) sts

  (** [fold_w_abort fn start t] folds [fn] over [t] (with initial value
      [start]).  [fn] is applied to each node before its sub-nodes.  If
      [fn] returns [(false, value)] on a cell, [value] is returned as the
      result for that entire branch and its sub-nodes are not visited.
      For a [Body] only the value component of [fn]'s result is used. *)
  let rec fold_w_abort fn start t =
    match t with
    | Empty -> start
    | Body _ -> snd (fn start t)
    | Cell (_, _, _, _, sts) ->
        let (cont, new_start) = fn start t in
        if cont then Array.fold_left (fold_w_abort fn) new_start sts
        else new_start

  (** [contains b t] returns [true] if [t] contains the body [b].  Bodies
      are compared by physical identity ([==]), not structural equality;
      only cells whose bounds contain [B.q b] are searched. *)
  let contains b t =
    let qb = B.q b in
    let rec search = function
      | Empty -> false
      | Body b2 -> b == b2
      | Cell (_, _, bds, _, sts) ->
          let n = Array.length sts in
          let rec any i = i < n && (search sts.(i) || any (i + 1)) in
          in_bounds bds qb && any 0
    in
    search t
end

No comments: