Calculating π in Clojure (Salamin-Brent)

Took a shot at implementing a PI digit generator in Clojure using a ‘fast’ algorithm.
It seemed like a decent enough exercise for learning something about performance in Clojure.

MacBook Pro – Intel Core 2 Duo 2.26 GHz – 4GB RAM
Java(TM) SE Runtime Environment (build 1.6.0_15-b03-219)
Java HotSpot(TM) 64-Bit Server VM (build 14.1-b02-90, mixed mode)
Clojure 1.1.0-alpha-SNAPSHOT (Aug 20 2009) git commit f1f5ad40984d46bdc314090552b76471ee2b8d01

Clojure roughly matches the performance of Java in this example: in the timings below the two stay within about 10% of each other at 1000 and 10000 digits.

The Clojure code :

(import 'java.lang.Math)
(import 'java.math.MathContext)
(import 'java.math.BigDecimal)


(defn sb-pi
  "Calculates PI digits using the Salamin-Brent algorithm
   and Java's BigDecimal class."
  [places]

  (let [digits (.intValue (+ 10 places)) ;; add some guard digits
        round-mode BigDecimal/ROUND_DOWN]

    (letfn [(big-sqrt[#^BigDecimal num]
             ;; Calculates square root using Newton's method.
             (letfn [(big-sqrt-int 
                      [#^BigDecimal num #^BigDecimal x0 #^BigDecimal x1]
                      ;; aux function for calculating square root
                      (let [#^BigDecimal x0new x1
                            #^BigDecimal x1new (-> num (.divide x0new digits round-mode))
                            #^BigDecimal xsum (+ x1new x0new)
                            #^BigDecimal x1tot (-> xsum (.divide 2M digits round-mode))]
                        (if (= x0 x1)
                          x1tot
                          (recur num x1 x1tot))))]
               (big-sqrt-int
                num 0M (BigDecimal/valueOf
                        (Math/sqrt (. num doubleValue))))))
            (sb-pi-int [#^BigDecimal a #^BigDecimal b 
                        #^BigDecimal t #^BigDecimal x #^BigDecimal y]
             ;; aux function for calculating PI
             (let
                 [#^BigDecimal y1 a
                  #^BigDecimal absum (+ a b)
                  #^BigDecimal a1 (-> absum (.divide 2M digits round-mode))
                  #^BigDecimal b1 (big-sqrt (* b y1))
                  #^BigDecimal ydiff (- y1 a1)
                  #^BigDecimal t1 (- t (* x ydiff ydiff))
                  #^BigDecimal x1 (* x 2M)]               
               (if (== a b)
                 (let [#^BigDecimal absum1 (+ a1 b1)
                       #^BigDecimal absqrd (* absum1 absum1)
                       #^BigDecimal u (* t1 4M)]
                   (-> absqrd
                       (.divide u digits round-mode)
                       (.setScale places round-mode)))
                 (recur a1 b1 t1 x1 y1))))]
      
      (sb-pi-int 1M (-> 1M (.divide #^BigDecimal (big-sqrt 2M) digits round-mode))
                       (/ 1M 4M) 1M nil))))


(time (println (sb-pi (Integer/parseInt (second *command-line-args*)))))


$ time clj pi.clj 1               -->       3.403 msecs
$ time clj pi.clj 10              -->       3.956 msecs
$ time clj pi.clj 100             -->      10.630 msecs
$ time clj pi.clj 1000            -->     141.937 msecs
$ time clj pi.clj 10000           -->    3316.180 msecs

The same algorithm in Java (but using iteration instead of recursion) :


import java.math.BigDecimal;
import static java.math.BigDecimal.*;

class Pi {
  private static final BigDecimal TWO = new BigDecimal(2);
  private static final BigDecimal FOUR = new BigDecimal(4);
  private static final int ROUND_MODE = ROUND_DOWN;

  public static void main(String[] args) {
    long start = System.nanoTime();
    System.out.println(pi(Integer.parseInt(args[0])));
    System.out.println("Elapsed time: " + 
                       ((System.nanoTime() - start) / 1E6) + " msecs");  
  }
    
  // Salamin-Brent Algorithm
  public static BigDecimal pi(final int digits) {
    final int SCALE = 10 + digits;
    BigDecimal a = ONE;
    BigDecimal b = ONE.divide(sqrt(TWO, SCALE), SCALE, ROUND_MODE);
    BigDecimal t = new BigDecimal(0.25);
    BigDecimal x = ONE;
    BigDecimal y;
        
    while (!a.equals(b)) {
      y = a;
      a = a.add(b).divide(TWO, SCALE, ROUND_MODE);
      b = sqrt(b.multiply(y), SCALE);
      t = t.subtract(x.multiply(y.subtract(a).multiply(y.subtract(a))));
      x = x.multiply(TWO);
    }
        
    return a.add(b)
      .multiply(a.add(b))
      .divide(t.multiply(FOUR), SCALE, ROUND_MODE)
      .setScale(digits, ROUND_MODE);
  }
    
  // square root method (Newton's)
  public static BigDecimal sqrt(BigDecimal A, final int SCALE) {
    BigDecimal x0 = new BigDecimal("0");
    BigDecimal x1 = new BigDecimal(Math.sqrt(A.doubleValue()));
        
    while (!x0.equals(x1)) {
      x0 = x1;
      x1 = A.divide(x0, SCALE, ROUND_MODE);
      x1 = x1.add(x0);
      x1 = x1.divide(TWO, SCALE, ROUND_MODE);
    }
        
    return x1;
  }
}


$ time java Pi 1         ---->         2.162 msecs
$ time java Pi 10        ---->         2.425 msecs
$ time java Pi 100       ---->         7.897 msecs
$ time java Pi 1000      ---->       150.610 msecs
$ time java Pi 10000     ---->      3009.705 msecs
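
One detail of the Java version above is worth spelling out: both while loops stop on BigDecimal.equals, which compares scale as well as numeric value. That works here because every iterate comes out of a divide with the same explicit scale. The snippet below is a small sketch of the difference, not part of the benchmark; the class name and the sameValue helper are made up for illustration, and it uses RoundingMode.DOWN rather than the older ROUND_DOWN int constant.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

class ScaleAwareEquals {
    // Value comparison that ignores scale (hypothetical helper name).
    static boolean sameValue(BigDecimal p, BigDecimal q) {
        return p.compareTo(q) == 0;
    }

    public static void main(String[] args) {
        BigDecimal half = new BigDecimal("0.5");         // scale 1
        BigDecimal halfPadded = new BigDecimal("0.50");  // scale 2

        System.out.println(half.equals(halfPadded));     // false: scales differ
        System.out.println(sameValue(half, halfPadded)); // true: same numeric value

        // Dividing with an explicit scale, as the loops above do, gives
        // results that share a scale, so equals() acts as a value test.
        BigDecimal p = BigDecimal.ONE.divide(new BigDecimal(2), 5, RoundingMode.DOWN);
        BigDecimal q = new BigDecimal(4).divide(new BigDecimal(8), 5, RoundingMode.DOWN);
        System.out.println(p.equals(q));                 // true: both are 0.50000
    }
}
```

If the scales could ever differ between iterates, comparing with compareTo would be the safer termination test.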

Another slice of π ?

The previous Common Lisp and Haskell functions to generate the digits of PI were only accurate up to somewhere between 10000 and 20000 digits. The algorithm works with a certain number of extra ‘guard’ digits, which are discarded at the end to get an accurate result.

Some background on how the number of guard digits is calculated: there is ‘wobble’ in the number of contaminated digits depending on the number of input digits. When only the order of magnitude of the rounding error is of interest, ulps (units in the last place) and ε (machine precision) may be used interchangeably, since they differ by at most a factor of β, the base (the ‘wobble’ in the number of contaminated digits). For example, when a floating-point number is in error by n ulps, the number of contaminated digits is about log_β n. [1] [2]
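
To make the rule concrete, here is a rough Java sketch of the guard-digit count that the functions below use, 10 + ceil(log10(digits)). Only that formula comes from the code in this post; the class and method names are invented for illustration, and the logarithm is computed with integer arithmetic (my choice, to sidestep floating-point rounding at exact powers of ten).

```java
class GuardDigits {
    // ceil(log10(n)) for n >= 1, computed with integer arithmetic so that
    // exact powers of ten are not affected by floating-point rounding.
    static int ceilLog10(int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be positive");
        }
        int result = 0;
        long power = 1;              // invariant: power == 10^result
        while (power < n) {
            power *= 10;
            result++;
        }
        return result;
    }

    // Guard digits for a requested number of output digits: a fixed margin
    // of 10 plus a term that follows the log_β n estimate quoted above.
    static int guardDigits(int digits) {
        return 10 + ceilLog10(digits);
    }

    public static void main(String[] args) {
        for (int digits : new int[] {1, 10, 1000, 10000, 20000}) {
            System.out.println(digits + " digits -> "
                               + guardDigits(digits) + " guard digits");
        }
    }
}
```

With this rule 10000 digits get 14 guard digits and 20000 digits get 15, so the margin grows slowly with the size of the request.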

Here are new versions of the functions using a guard digits calculation :

Common Lisp :

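(defun machin-pi (digits) ; defun line is missing in the source; name assumed from the Haskell machin_pi
  "Accurately calculates PI digits using Machin's formula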
   with fixed point arithmetic and guard digits."
  (labels
      ((arccot-recur (xsq n xpower op)
         (let ((term (floor xpower n))
               (opfun (if op #'+ #'-)))
           (if (= term 0)
               0
               (funcall opfun
                        (arccot-recur xsq (+ n 2) (floor xpower xsq) (not op))
                        term))))
       (arccot (x unity)
         (let ((xpower (floor (/ unity x))))
           (arccot-recur (* x x) 1 xpower t))))

    (if (> digits 0)
        (let* ((guard (+ 10 (ceiling (log digits 10))))
               (unity (expt 10 (+ digits guard)))
               (thispi (* 4 (- (* 4 (arccot 5 unity)) (arccot 239 unity)))))
          (floor thispi (expt 10 guard))))))

Haskell :

{-
Accurately calculates PI digits using Machin's formula
with fixed point arithmetic and variable guard digits.
-}

arccot :: Integer -> Integer -> Integer
arccot x unity = arccot' 0 (unity `div` x) 1 1
    where
      arccot' sum xpower n sign
          | xpower `div` n == 0 = sum
          | otherwise           =
              arccot' (sum + sign * (xpower `div` n))
                      (xpower `div` (x * x)) (n + 2) (- sign)

machin_pi :: Integer -> Integer
machin_pi digits =
    if digits < 1 then 0 else
        pi' `div` (10 ^ guard)
            where
              guard = 10 + (ceiling (logBase 10 (fromInteger digits)))
              unity = 10 ^ (digits + guard)
              pi' = 4 * (4 * arccot 5 unity - arccot 239 unity)
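
For comparison with the Java Salamin-Brent code earlier, the same Machin calculation can be sketched in Java with BigInteger. This is a hypothetical port and was not benchmarked for this post; the names MachinPi, arccot, machinPi and ceilLog10 are mine, and ceilLog10 uses integer arithmetic where the versions above use a floating-point logarithm.

```java
import java.math.BigInteger;

class MachinPi {
    // ceil(log10(n)) for n >= 1, using integer arithmetic only.
    static int ceilLog10(int n) {
        int result = 0;
        long power = 1;              // invariant: power == 10^result
        while (power < n) {
            power *= 10;
            result++;
        }
        return result;
    }

    // arccot(x) scaled by unity: the alternating series
    // unity/x - unity/(3 x^3) + unity/(5 x^5) - ... in fixed point.
    static BigInteger arccot(long x, BigInteger unity) {
        BigInteger xsq = BigInteger.valueOf(x * x);
        BigInteger xpower = unity.divide(BigInteger.valueOf(x));
        BigInteger sum = BigInteger.ZERO;
        long n = 1;
        boolean add = true;
        while (true) {
            BigInteger term = xpower.divide(BigInteger.valueOf(n));
            if (term.signum() == 0) {
                break;               // remaining terms are below one unit
            }
            sum = add ? sum.add(term) : sum.subtract(term);
            xpower = xpower.divide(xsq);
            n += 2;
            add = !add;
        }
        return sum;
    }

    // Returns floor(pi * 10^digits) as an integer, or 0 if digits < 1.
    static BigInteger machinPi(int digits) {
        if (digits < 1) {
            return BigInteger.ZERO;
        }
        int guard = 10 + ceilLog10(digits);
        BigInteger unity = BigInteger.TEN.pow(digits + guard);
        BigInteger four = BigInteger.valueOf(4);
        BigInteger pi = four.multiply(
            four.multiply(arccot(5, unity)).subtract(arccot(239, unity)));
        return pi.divide(BigInteger.TEN.pow(guard));   // drop the guard digits
    }

    public static void main(String[] args) {
        System.out.println(machinPi(10));   // prints 31415926535
    }
}
```

As in the Lisp and Haskell versions, the result is an integer holding the leading 3 followed by the requested number of decimal digits.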

“What Every Computer Scientist Should Know About Floating-Point Arithmetic”, by David Goldberg, March 1991

[1] Relative error and Ulps

[2] Guard digits