(* common-ocaml.ml *)

(* A dual number for nested forward-mode automatic differentiation.
   [Base r] is an ordinary real number.
   [Bundle (e, x, x')] stands for x + x' * eps_e: the int [e] is the tag
   identifying which perturbation eps_e the bundle belongs to, [x] is the
   primal part and [x'] the perturbation (tangent) part.  Both parts are
   themselves dual numbers, which is what allows derivatives to nest. *)
type dual_number =
  | Bundle of int * dual_number * dual_number
  | Base of float

(* [epsilon p] is the perturbation tag of [p]; a plain [Base] value is
   treated as carrying tag 0. *)
let epsilon = function
  | Bundle (tag, _, _) -> tag
  | Base _ -> 0

(* [primal e p] is the primal part of [p] with respect to tag [e].
   A bundle tagged [e] or higher is unpacked to its primal component.
   A bundle with a smaller tag, or a [Base] value, is returned as-is,
   because it does not depend on perturbation [e]. *)
let primal e = function
  | Bundle (tag, inner, _) as b -> if tag >= e then inner else b
  | Base _ as b -> b

(* [perturbation e p] is the coefficient of eps_e in [p].
   A bundle tagged [e] or higher yields its tangent component.
   Anything else (a bundle with a smaller tag, or a [Base]) is constant
   with respect to [e], so its perturbation is [Base 0.0]. *)
let perturbation e = function
  | Bundle (tag, _, dx) -> if tag >= e then dx else Base 0.0
  | Base _ -> Base 0.0

(* [lift_real_to_real f dfdx ( *. )] lifts the unary real function [f] to
   dual numbers.
   - [dfdx] is the derivative of [f], itself defined on dual numbers.
   - [( *. )] is the dual-number multiplication used for the chain rule.
   For a bundle x + x' eps_e the result is f(x) + dfdx(x) * x' eps_e.
   The primal part is lifted recursively, so nested perturbations are
   propagated. *)
let lift_real_to_real f dfdx ( *. ) =
  let rec go = function
    | Base r -> Base (f r)
    | Bundle (tag, x, dx) ->
        (* At the bundle's own tag, [primal] and [perturbation] simply
           return [x] and [dx], so they are unpacked directly here. *)
        Bundle (tag, go x, dfdx x *. dx)
  in
  go

(* [lift_real_cross_real_to_real f dfdx1 dfdx2 ( +. ) ( *. )] lifts the
   binary real function [f] to dual numbers.
   - [dfdx1] and [dfdx2] are the partial derivatives of [f] with respect to
     its first and second argument, themselves defined on dual numbers.
   - [( +. )] and [( *. )] are the dual-number addition and multiplication
     used to combine them by the chain rule.
   If both arguments are [Base], [f] is applied directly.  Otherwise the
   result is a bundle tagged with the larger of the bundle tags present.
   An argument that has no component at that tag contributes a zero
   perturbation. *)
let lift_real_cross_real_to_real f dfdx1 dfdx2 ( +. ) ( *. ) =
  let rec go p1 p2 =
    (* Result bundle for tag [e]: the primal part is [go] on the primal
       parts; the tangent is dfdx1 * p1' + dfdx2 * p2'. *)
    let bundle_at e =
      let a = primal e p1 and b = primal e p2 in
      Bundle (e, go a b,
              dfdx1 a b *. perturbation e p1 +. dfdx2 a b *. perturbation e p2)
    in
    match p1, p2 with
    | Base x1, Base x2 -> Base (f x1 x2)
    | Bundle (e1, _, _), Base _ -> bundle_at e1
    | Base _, Bundle (e2, _, _) -> bundle_at e2
    | Bundle (e1, _, _), Bundle (e2, _, _) ->
        bundle_at (if e1 < e2 then e2 else e1)
  in
  go

(* [lift_real_cross_real_to_bool f] lifts the real predicate [f] to dual
   numbers.  Both arguments are stripped down to their innermost [Base]
   primal values, and [f] is applied to those.  Tags and perturbations
   are ignored, so comparisons see only the real parts. *)
let lift_real_cross_real_to_bool f =
  let rec go p1 p2 =
    match p1, p2 with
    | Base x1, Base x2 -> f x1 x2
    | Bundle (_, x1, _), Bundle (_, x2, _) -> go x1 x2
    | Bundle (_, x1, _), (Base _ as b) -> go x1 b
    | (Base _ as b), Bundle (_, x2, _) -> go b x2
  in
  go

(* Current nesting depth of active [derivative] calls.  Each call tags its
   input with a value one greater than the enclosing call's, so nested
   derivatives use distinct tags and [primal]/[perturbation] can tell
   their perturbations apart. *)
let e = ref 0

(* [derivative f x] is the derivative of [f] at [x], computed by forward
   mode: [f] is applied to a bundle with a fresh tag, primal [x] and
   perturbation [Base 1.0], and the perturbation of the result for that
   tag is extracted.

   The tag is read once into [tag] rather than re-reading [!e] in several
   argument positions, so its value does not depend on argument
   evaluation order.  The counter [e] is restored even when [f] raises,
   so a failed call cannot leave the nesting depth permanently bumped for
   every later [derivative]. *)
let derivative f x =
  incr e;
  let tag = !e in
  let result =
    try perturbation tag (f (Bundle (tag, x, Base 1.0)))
    with exn ->
      decr e;
      raise exn
  in
  decr e;
  result

(* [write p] follows the primal components of [p] down to the innermost
   [Base] value and prints that float with 18 significant digits.  Tags
   and perturbations are ignored.  Returns [p] unchanged, so the call can
   be chained. *)
let rec write p =
  (match p with
   | Base r -> Printf.printf "%.18g\n" r
   | Bundle (_, inner, _) -> ignore (write inner));
  p

(* Rebind the float operators +. -. *. /., [sqrt], and the comparisons
   [<] and [<=] at top level to versions on [dual_number].  Every later
   definition in the file therefore computes on dual numbers.  The
   comparisons look only at the real parts.  Other comparison operators
   are NOT rebound and keep their usual meaning.

   Within the right-hand side below, the operator names still denote the
   original float primitives.  The lifted versions are built under fresh
   names (d_add, ...) and reach the operator names only through the
   pattern on the left.  Each derivative rule is expressed with the lifted
   operations, so derivatives can themselves be differentiated. *)
let (( +. ), ( -. ), ( *. ), ( /. ), sqrt, ( < ), ( <= )) =
  let rec d_add a b =
    lift_real_cross_real_to_real ( +. )
      (fun _ _ -> Base 1.0)
      (fun _ _ -> Base 1.0)
      d_add d_mul a b
  and d_sub a b =
    lift_real_cross_real_to_real ( -. )
      (fun _ _ -> Base 1.0)
      (fun _ _ -> Base (-1.0))
      d_add d_mul a b
  and d_mul a b =
    (* d(a*b)/da = b, d(a*b)/db = a *)
    lift_real_cross_real_to_real ( *. )
      (fun _ y -> y)
      (fun x _ -> x)
      d_add d_mul a b
  and d_div a b =
    (* d(a/b)/da = 1/b, d(a/b)/db = -a/b^2 *)
    lift_real_cross_real_to_real ( /. )
      (fun _ y -> d_div (Base 1.0) y)
      (fun x y -> d_sub (Base 0.0) (d_div x (d_mul y y)))
      d_add d_mul a b
  and d_sqrt a =
    (* d(sqrt a)/da = 1 / (2 sqrt a), written as 1 / (sqrt a + sqrt a) *)
    lift_real_to_real sqrt
      (fun x -> d_div (Base 1.0) (d_add (d_sqrt x) (d_sqrt x)))
      d_mul a
  and d_lt a b = lift_real_cross_real_to_bool ( < ) a b
  and d_le a b = lift_real_cross_real_to_bool ( <= ) a b
  in
  (d_add, d_sub, d_mul, d_div, d_sqrt, d_lt, d_le)

open List

(* [sqr v] is v * v, using the dual-number multiplication in scope. *)
let sqr v = ( *. ) v v

(* [map_n f n] is [[f 0; f 1; ...; f (n-1)]], and [[]] when [n <= 0].
   The list is built back to front with an accumulator, so the function
   is tail-recursive; [f] is applied to indices from [n-1] down to [0].
   NOTE: [>=] is used deliberately.  [<] and [<=] are rebound in this file
   to comparisons on dual numbers and do not accept ints. *)
let map_n f n =
  let rec loop i acc =
    if i >= 0 then loop (i - 1) (f i :: acc) else acc
  in
  loop (n - 1) []

(* [vplus u v] is the element-wise sum of two equal-length vectors. *)
let vplus u v = map2 (fun a b -> a +. b) u v

(* [vminus u v] is the element-wise difference u - v of two equal-length
   vectors. *)
let vminus u v = map2 (fun a b -> a -. b) u v

(* [ktimesv k v] scales every element of vector [v] by [k]. *)
let ktimesv k v = map (( *. ) k) v

(* [magnitude_squared v] is the sum of the squares of the elements of [v],
   accumulated left to right starting from [Base 0.0]. *)
let magnitude_squared v =
  fold_left (fun acc vi -> acc +. sqr vi) (Base 0.0) v

(* [magnitude v] is the Euclidean norm of [v]. *)
let magnitude v =
  let m2 = magnitude_squared v in
  sqrt m2

(* [distance_squared u v] is the squared Euclidean distance between
   [u] and [v], computed as |v - u|^2. *)
let distance_squared u v =
  let diff = vminus v u in
  magnitude_squared diff

(* [distance u v] is the Euclidean distance between [u] and [v]. *)
let distance u v =
  let d2 = distance_squared u v in
  sqrt d2

(* [replace_ith xs i xi] is [xs] with its [i]-th element (0-based)
   replaced by [xi]; all other elements are kept in order.
   @raise Invalid_argument if [i] is not a valid index into [xs]. *)
let rec replace_ith xs i xi =
  match xs with
  | [] -> invalid_arg "replace_ith: index out of range"
  | xh :: xt ->
      if i = 0 then xi :: xt else xh :: replace_ith xt (i - 1) xi

(* [gradient f x] is the vector of partial derivatives of [f] at point [x].
   Each component [i] is the one-dimensional [derivative] of [f] along
   coordinate [i], holding the other coordinates of [x] fixed. *)
let gradient f x =
  let partial i =
    let along_i xi = f (replace_ith x i xi) in
    derivative along_i (nth x i)
  in
  map_n partial (length x)

(* [multivariate_argmin f x] minimises [f] by gradient descent from [x]
   with an adaptive step size [eta] (initially 1e-5).
   - It stops when the gradient norm, or the length of a proposed step,
     is at most 1e-5.
   - A step that lowers [f] is accepted.  After 10 consecutive accepted
     steps at the same [eta], [eta] is doubled.
   - A step that does not lower [f] is rejected and [eta] is halved. *)
let multivariate_argmin f x =
  let g = gradient f in
  let tolerance = Base 1e-5 in
  let rec descend x fx gx eta successes =
    if magnitude gx <= tolerance then x
    else if successes = 10 then descend x fx gx (Base 2.0 *. eta) 0
    else
      let x' = vminus x (ktimesv eta gx) in
      if distance x x' <= tolerance then x
      else
        let fx' = f x' in
        if fx' < fx then descend x' fx' (g x') eta (successes + 1)
        else descend x fx gx (eta /. Base 2.0) 0
  in
  descend x (f x) (g x) (Base 1e-5) 0

(* [multivariate_argmax f x] maximises [f] starting from [x] by minimising
   its negation with [multivariate_argmin]. *)
let multivariate_argmax f x =
  let negated y = Base 0.0 -. f y in
  multivariate_argmin negated x

(* [multivariate_max f x] is the value of [f] at the point found by
   [multivariate_argmax f x]. *)
let multivariate_max f x =
  let best = multivariate_argmax f x in
  f best

(* Generated by GNU enscript 1.6.4. *)