(* common-ocaml.ml *)
(** A number that may carry derivative information.

    - [Bundle (tag, primal, tangent)] pairs a primal value with a
      perturbation; the integer tag identifies which differentiation
      introduced the perturbation. Both components are themselves
      [dual_number]s, so bundles can nest.
    - [Base x] is a plain float with no perturbation attached. *)
type dual_number =
  | Bundle of int * dual_number * dual_number
  | Base of float
(** [epsilon p] is the tag stored in the outermost bundle of [p], or [0]
    when [p] is a [Base] number. *)
let epsilon = function
  | Base _ -> 0
  | Bundle (tag, _, _) -> tag
(** [primal e p] is the primal part of [p] with respect to tag [e].

    A bundle whose own tag is smaller than [e] is returned whole, as is a
    [Base] number; any other bundle yields its primal component. *)
let primal e = function
  | Base _ as p -> p
  | Bundle (tag, x, _) as p -> if tag < e then p else x
(** [perturbation e p] is the perturbation part of [p] with respect to
    tag [e].

    It is [Base 0.0] for a [Base] number and for a bundle whose own tag is
    smaller than [e]; any other bundle yields its perturbation component. *)
let perturbation e p =
  let zero = Base 0.0 in
  match p with
  | Base _ -> zero
  | Bundle (tag, _, x') -> if tag < e then zero else x'
(** [lift_real_to_real f dfdx mul] lifts the unary float function [f] to
    [dual_number]s.

    - [f] is applied directly to the float inside a [Base].
    - [dfdx] gives the derivative of [f] as a function on [dual_number]s.
    - The third argument is the multiplication used to scale the incoming
      perturbation by [dfdx] of the primal.

    For a bundle the result keeps the same tag: its primal is the lifted
    function applied recursively to the primal part, and its perturbation
    is [dfdx primal] times the perturbation part. *)
let lift_real_to_real f dfdx ( *. ) =
  let rec self = function
    | Base x -> Base (f x)
    | Bundle (e, _, _) as p ->
      let x = primal e p in
      let x' = perturbation e p in
      Bundle (e, self x, dfdx x *. x')
  in
  self
(** [lift_real_cross_real_to_real f dfdx1 dfdx2 add mul] lifts the binary
    float function [f] to [dual_number]s.

    - [f] is applied directly when both operands are [Base] numbers.
    - [dfdx1] and [dfdx2] give the partial derivatives of [f] with respect
      to its first and second argument, as functions on [dual_number]s.
    - The last two arguments are the addition and multiplication used to
      combine the two partial-derivative contributions.

    When at least one operand is a bundle, the result is a bundle tagged
    with that operand's tag, or with the larger of the two tags when both
    are bundles. Its primal is the lifted function applied recursively to
    both primal parts, and its perturbation is
    [dfdx1 * perturbation of p1 + dfdx2 * perturbation of p2], the
    partials being evaluated at the primal parts. *)
let lift_real_cross_real_to_real f dfdx1 dfdx2 ( +. ) ( *. ) =
  let rec self p1 p2 =
    (* Build the result bundle for tag [e]; shared by every case in which
       at least one operand is a bundle. *)
    let bundle_at e =
      let a = primal e p1 in
      let b = primal e p2 in
      Bundle
        ( e,
          self a b,
          (dfdx1 a b *. perturbation e p1)
          +. (dfdx2 a b *. perturbation e p2) )
    in
    match p1, p2 with
    | Base x1, Base x2 -> Base (f x1 x2)
    | Bundle (e1, _, _), Base _ -> bundle_at e1
    | Base _, Bundle (e2, _, _) -> bundle_at e2
    | Bundle (e1, _, _), Bundle (e2, _, _) -> bundle_at (max e1 e2)
  in
  self
(** [lift_real_cross_real_to_bool f] lifts the binary float predicate [f]
    to [dual_number]s.

    Tags and perturbations play no part in the result: each operand is
    reduced to the float found by following primal components down to a
    [Base], and [f] is applied to those two floats. *)
let lift_real_cross_real_to_bool f =
  let rec innermost_primal = function
    | Base x -> x
    | Bundle (_, x, _) -> innermost_primal x
  in
  fun p1 p2 -> f (innermost_primal p1) (innermost_primal p2)
(** Nesting depth of the [derivative] calls currently running; the depth
    reached on entry to a call is used as that call's perturbation tag. *)
let e = ref 0

(** [derivative f x] is the derivative of [f] at [x].

    [f] is applied to [x] bundled with the unit perturbation [Base 1.0]
    under a tag one greater than the current depth, and the perturbation
    the result carries for that tag is returned.

    The tag is captured locally and the depth counter is restored on every
    exit path, so an exception escaping [f] (or caught inside it after a
    nested call failed) cannot leave the counter out of step with the tag
    used to read the result. Exceptions raised by [f] are re-raised. *)
let derivative f x =
  let tag = !e + 1 in
  e := tag;
  let y =
    try f (Bundle (tag, x, Base 1.0))
    with exn ->
      (* Undo the increment before propagating the failure. *)
      (e := tag - 1;
       raise exn)
  in
  e := tag - 1;
  perturbation tag y
(** [write p] prints the innermost primal float of [p] on its own line
    with 18 significant digits, then returns [p] unchanged so the call can
    be dropped into an expression. *)
let rec write p =
  (match p with
   | Base x -> Printf.printf "%.18g\n" x
   | Bundle (_, x, _) -> ignore (write x));
  p
(** Rebinds the float arithmetic operators, [sqrt] and the comparisons
    [<] and [<=] so that, from this point on, they operate on
    [dual_number] values.

    Each arithmetic operator is the float primitive lifted together with
    its partial derivatives; the comparisons are the float comparisons
    lifted with [lift_real_cross_real_to_bool]. *)
let (( +. ), ( -. ), ( *. ), ( /. ), sqrt, ( < ), ( <= )) =
  (* Keep handles on the float primitives before their names are
     shadowed by the definitions below. *)
  let float_add = ( +. )
  and float_sub = ( -. )
  and float_mul = ( *. )
  and float_div = ( /. )
  and float_sqrt = sqrt
  and float_lt = ( < )
  and float_le = ( <= ) in
  (* A single recursive group: the derivative rules for division and
     square root are themselves written with the lifted operators. *)
  let rec ( +. ) a b =
    lift_real_cross_real_to_real float_add
      (fun _ _ -> Base 1.0)
      (fun _ _ -> Base 1.0)
      ( +. ) ( *. ) a b
  and ( -. ) a b =
    lift_real_cross_real_to_real float_sub
      (fun _ _ -> Base 1.0)
      (fun _ _ -> Base (-1.0))
      ( +. ) ( *. ) a b
  and ( *. ) a b =
    lift_real_cross_real_to_real float_mul
      (fun _ v -> v)
      (fun u _ -> u)
      ( +. ) ( *. ) a b
  and ( /. ) a b =
    lift_real_cross_real_to_real float_div
      (fun _ v -> Base 1.0 /. v)
      (fun u v -> Base 0.0 -. (u /. (v *. v)))
      ( +. ) ( *. ) a b
  and sqrt a =
    lift_real_to_real float_sqrt
      (fun v -> Base 1.0 /. (sqrt v +. sqrt v))
      ( *. ) a
  and ( < ) a b = lift_real_cross_real_to_bool float_lt a b
  and ( <= ) a b = lift_real_cross_real_to_bool float_le a b in
  (( +. ), ( -. ), ( *. ), ( /. ), sqrt, ( < ), ( <= ))
open List
(** [sqr x] is [x] multiplied by itself, using the multiplication operator
    in scope at this point of the file. *)
let sqr x = ( *. ) x x
(** [map_n f n] is the list [[f 0; f 1; ...; f (n-1)]].

    The list is built with an accumulator from index [n-1] down to [0], so
    the recursion is a tail call and long lists do not exhaust the stack.

    @raise Invalid_argument if [n] is negative. *)
let map_n f n =
  (* Written with [>] on purpose: [<] and [<=] are rebound to dual-number
     comparisons earlier in this file and no longer accept integers. *)
  if 0 > n then invalid_arg "map_n: negative length"
  else
    let rec loop i acc =
      if i = 0 then acc else loop (i - 1) (f (i - 1) :: acc)
    in
    loop n []
(** [vplus u v] is the element-wise sum of the lists [u] and [v], as
    computed by [List.map2]. *)
let vplus u v = List.map2 (fun a b -> a +. b) u v
(** [vminus u v] is the element-wise difference [u - v] of the lists [u]
    and [v], as computed by [List.map2]. *)
let vminus u v = List.map2 (fun a b -> a -. b) u v
(** [ktimesv k v] multiplies every element of the list [v] by the scalar
    [k]. *)
let ktimesv k v = List.map (fun x -> k *. x) v
(** [magnitude_squared x] is the sum of the squares of the elements of the
    list [x], accumulated left to right starting from [Base 0.0]. *)
let magnitude_squared x =
  List.fold_left (fun total xi -> total +. sqr xi) (Base 0.0) x
(** [magnitude x] is the square root of [magnitude_squared x]. *)
let magnitude x =
  let squared = magnitude_squared x in
  sqrt squared
(** [distance_squared u v] is [magnitude_squared] of the element-wise
    difference [v - u]. *)
let distance_squared u v =
  let difference = vminus v u in
  magnitude_squared difference
(** [distance u v] is the square root of [distance_squared u v]. *)
let distance u v =
  let squared = distance_squared u v in
  sqrt squared
(** [replace_ith l i xi] is the list [l] with the element at zero-based
    index [i] replaced by [xi]; every other element is kept in place.

    @raise Invalid_argument if [l] has no element at index [i] (the list
    is too short, or [i] is negative). *)
let rec replace_ith l i xi =
  match l with
  | [] -> invalid_arg "replace_ith: index out of bounds"
  | xh :: xt ->
    if i = 0 then xi :: xt else xh :: replace_ith xt (i - 1) xi
(** [gradient f x] is the list of partial derivatives of [f] at the point
    [x], one entry per element of [x].

    Entry [i] is obtained by differentiating, at the [i]-th element of
    [x], the one-argument function that substitutes its argument for that
    element and applies [f]. *)
let gradient f x =
  let partial i =
    let along_axis xi = f (replace_ith x i xi) in
    derivative along_axis (List.nth x i)
  in
  map_n partial (List.length x)
(** [multivariate_argmin f x] searches for a local minimiser of [f] by
    gradient descent with an adaptive step, starting from the point [x].

    The search returns the current point when the gradient magnitude is at
    most [1e-5], or when the proposed step would move the point by at most
    [1e-5]. A step that lowers [f] is accepted; after ten accepted steps
    in a row the step size is doubled; a step that does not lower [f] is
    rejected and the step size is halved. The initial step size is
    [1e-5]. *)
let multivariate_argmin f x =
  let tolerance = Base 1e-5 in
  let initial_step = Base 1e-5 in
  let grad = gradient f in
  (* [streak] counts the accepted steps taken since the step size last
     changed. *)
  let rec descend point value slope step streak =
    if magnitude slope <= tolerance then point
    else if streak = 10 then
      descend point value slope (Base 2.0 *. step) 0
    else begin
      let candidate = vminus point (ktimesv step slope) in
      if distance point candidate <= tolerance then point
      else begin
        let candidate_value = f candidate in
        if candidate_value < value then
          descend candidate candidate_value (grad candidate) step (streak + 1)
        else descend point value slope (step /. Base 2.0) 0
      end
    end
  in
  descend x (f x) (grad x) initial_step 0
(** [multivariate_argmax f x] searches for a local maximiser of [f] from
    the starting point [x] by running [multivariate_argmin] on the
    function that maps a point to [Base 0.0] minus [f] of that point. *)
let multivariate_argmax f x =
  let negated point = Base 0.0 -. f point in
  multivariate_argmin negated x
(** [multivariate_max f x] is the value of [f] at the point returned by
    [multivariate_argmax f x]. *)
let multivariate_max f x =
  let best = multivariate_argmax f x in
  f best
(* Generated by GNU enscript 1.6.4. *)