(* common-smlnj.sml *)
(* A (possibly nested) forward-mode dual number.
   Bundle (e, x, x') has primal part x and perturbation (tangent) part x',
   both themselves dual numbers, and is tagged with the integer e that
   derivative assigns to the perturbation it introduces.
   Base r is a plain real carrying no perturbation. *)
datatype dual_number =
    Bundle of int * dual_number * dual_number
  | Base of real

(* Raised by replace_ith when the index is not a valid position in the list. *)
exception error
(* replace_ith xs i xi: a copy of xs whose 0-based element i is xi.
   Raises error when i is negative or not less than the list length
   (the recursion then runs off the end of the list). *)
fun replace_ith [] _ _ = raise error
  | replace_ith (head :: rest) i xi =
      case i of
          0 => xi :: rest
        | _ => head :: replace_ith rest (i - 1) xi
(* The tag of the outermost Bundle of a dual number, or 0 for a plain Base
   value. *)
fun epsilon d =
    case d of
        Bundle (tag, _, _) => tag
      | Base _ => 0
(* primal e d: d with its outermost tag-e perturbation removed.
   A Bundle whose tag is at least e yields its primal component.  A Bundle
   with a smaller tag has no tag-e perturbation at its outermost level and is
   returned unchanged, as is a plain Base value. *)
fun primal _ (d as Base _) = d
  | primal e (d as Bundle (tag, x, _)) =
      if tag < e then d else x
(* perturbation e d: the coefficient of the tag-e perturbation in d.
   A Bundle whose tag is at least e yields its perturbation component; a
   Bundle with a smaller tag, or a plain Base value, yields Base 0.0. *)
fun perturbation _ (Base _) = Base 0.0
  | perturbation e (Bundle (tag, _, x')) =
      if tag >= e then x' else Base 0.0
(* Lift a unary real function f to dual numbers.
   dfdx   : derivative of f, taking a dual number.
   op *   : dual-number multiplication used for the chain rule.
   A Bundle tagged e maps to a Bundle tagged e whose primal is the lifted
   function applied to the operand's primal and whose perturbation is
   dfdx (primal) * (perturbation); a Base r maps to Base (f r). *)
fun lift_real_to_real f dfdx op * =
    let
      fun self (b as Bundle (e, _, _)) =
            let
              val p = primal e b
            in
              Bundle (e, self p, dfdx p * perturbation e b)
            end
        | self (Base x) = Base (f x)
    in
      self
    end
(* Lift a binary real function f to dual numbers.
   f            : real primitive, applied when both operands are Base.
   dfdx1, dfdx2 : partial derivatives of f with respect to its first and
                  second argument; each takes both primal operands.
   op +, op *   : dual-number addition and multiplication used to combine
                  the chain-rule terms.
   When either operand is a Bundle, the result is a Bundle tagged e, where e
   is the larger tag if both are Bundles, otherwise the Bundle's tag.  Its
   primal is the lifted f on the two tag-e primals p1, p2, and its
   perturbation is
     dfdx1 p1 p2 * perturbation e u + dfdx2 p1 p2 * perturbation e v. *)
fun lift_real_cross_real_to_real f dfdx1 dfdx2 op + op * =
    let
      (* Chain rule for the operand pair (u, v) at tag e.  A Base operand
         passes through primal unchanged and has perturbation Base 0.0, so
         the same formula covers all three Bundle/Base combinations. *)
      fun at_tag e (u, v) =
            let
              val p1 = primal e u
              val p2 = primal e v
            in
              Bundle (e,
                      self (p1, p2),
                      dfdx1 p1 p2 * perturbation e u +
                      dfdx2 p1 p2 * perturbation e v)
            end
      and self (u as Bundle (e1, _, _), v as Bundle (e2, _, _)) =
            at_tag (Int.max (e1, e2)) (u, v)
        | self (u as Bundle (e1, _, _), v as Base _) = at_tag e1 (u, v)
        | self (u as Base _, v as Bundle (e2, _, _)) = at_tag e2 (u, v)
        | self (Base x1, Base x2) = Base (f (x1, x2))
    in
      self
    end
(* Lift a real comparison f to dual numbers.  Each operand is reduced to its
   innermost Base real by following primal components through every Bundle
   level, and f compares those reals; perturbations never affect the result. *)
fun lift_real_cross_real_to_bool f =
    let
      fun strip (Bundle (_, inner, _)) = strip inner
        | strip (Base r) = r
    in
      fn (u, v) => f (strip u, strip v)
    end
(* Counter used to hand out a fresh perturbation tag to each active call of
   derivative; it is incremented on entry and restored on exit, so nested
   calls receive strictly larger tags than the calls enclosing them. *)
val e = ref 0

(* derivative f x: the derivative of f : dual_number -> dual_number at x.
   Evaluates f on x perturbed by Base 1.0 under a fresh tag and extracts the
   coefficient of that tag's perturbation from the result.  The counter is
   restored even if f raises, so a failed evaluation does not leave it
   permanently advanced; the exception is re-raised unchanged. *)
fun derivative f x =
    let
      val saved = !e
      val tag = saved + 1
      val () = e := tag
      val result =
          perturbation tag (f (Bundle (tag, x, Base 1.0)))
          handle ex => (e := saved; raise ex)
    in
      e := saved;
      result
    end
(* Print the innermost Base real of a dual number (following primal
   components through every Bundle level), followed by a newline, and return
   the argument unchanged so write can be used inside expressions. *)
fun write d =
    let
      val () =
          case d of
              Bundle (_, x, _) => ignore (write x)
            | Base r => (print (Real.toString r); print "\n")
    in
      d
    end
open Real.Math
(* Rebind +, -, *, /, sqrt, < and <= at top level so that every later
   declaration uses dual-number versions.  Each arithmetic operator is built
   from its real primitive and its partial derivatives via
   lift_real_cross_real_to_real / lift_real_to_real.  The comparisons look
   only at the innermost Base reals via lift_real_cross_real_to_bool. *)
val (op +, op -, op *, op /, sqrt, op <, op <=) =
    let
      (* Capture the real primitives before they are shadowed.  They are
         named through Real explicitly rather than as `op +` etc., so their
         type does not depend on overload resolution looking ahead to later
         uses; the operators otherwise default to int. *)
      val plus = Real.+
      val minus = Real.-
      val times = Real.*
      val divide = Real./
      val original_sqrt = sqrt   (* Real.Math.sqrt, from `open Real.Math` *)
      val lt = Real.<
      val le = Real.<=
    in
      let
        (* d(x1 + x2) = 1 * dx1 + 1 * dx2 *)
        fun op + (x1, x2) =
              lift_real_cross_real_to_real
                plus
                (fn _ => fn _ => Base 1.0)
                (fn _ => fn _ => Base 1.0)
                op +
                op *
                (x1, x2)
        (* d(x1 - x2) = 1 * dx1 + (-1) * dx2 *)
        and op - (x1, x2) =
              lift_real_cross_real_to_real
                minus
                (fn _ => fn _ => Base 1.0)
                (fn _ => fn _ => Base ~1.0)
                op +
                op *
                (x1, x2)
        (* d(x1 * x2) = x2 * dx1 + x1 * dx2 *)
        and op * (x1, x2) =
              lift_real_cross_real_to_real
                times
                (fn _ => fn x2 => x2)
                (fn x1 => fn _ => x1)
                op +
                op *
                (x1, x2)
        (* d(x1 / x2) = (1 / x2) * dx1 + (-x1 / x2^2) * dx2 *)
        and op / (x1, x2) =
              lift_real_cross_real_to_real
                divide
                (fn _ => fn x2 => Base 1.0 / x2)
                (fn x1 => fn x2 => Base 0.0 - x1 / (x2 * x2))
                op +
                op *
                (x1, x2)
        (* d(sqrt x) = dx / (2 * sqrt x) *)
        and sqrt x =
              lift_real_to_real
                original_sqrt
                (fn x => Base 1.0 / (sqrt x + sqrt x))
                op *
                x
        and op < (x1, x2) = lift_real_cross_real_to_bool lt (x1, x2)
        and op <= (x1, x2) = lift_real_cross_real_to_bool le (x1, x2)
      in
        (op +, op -, op *, op /, sqrt, op <, op <=)
      end
    end
(* Square of a dual number. *)
fun sqr x = x * x

(* Elementwise sum / difference of two vectors (lists of dual numbers);
   ListPair.map stops at the end of the shorter list. *)
fun vplus u v = ListPair.map (fn (a, b) => a + b) (u, v)
fun vminus u v = ListPair.map (fn (a, b) => a - b) (u, v)

(* Scale every component of a vector by k. *)
fun ktimesv k = map (fn component => k * component)

(* Sum of squared components, accumulated from Base 0.0. *)
fun magnitude_squared x =
    foldl (fn (xi, acc) => sqr xi + acc) (Base 0.0) x

(* Euclidean norm of a vector. *)
val magnitude = sqrt o magnitude_squared

(* Squared / plain Euclidean distance between u and v. *)
fun distance_squared u v = magnitude_squared (vminus v u)
fun distance u v = sqrt (distance_squared u v)
(* gradient f x: the list of partial derivatives of f at the point x (a list
   of dual numbers).  Component i is one forward-mode derivative pass in
   which only coordinate i of x is varied. *)
fun gradient f x =
    let
      fun partial i =
            derivative (fn xi => f (replace_ith x i xi)) (List.nth (x, i))
    in
      List.tabulate (length x, partial)
    end
(* multivariate_argmin f x: a local minimizer of f found by gradient descent
   from x, starting with step size eta = 1e~5.
   - Returns the current point once the gradient magnitude, or the length of
     the proposed step, is at most 1e~5.
   - A step that lowers f is accepted and increments the counter i; a step
     that does not is rejected, eta is halved and i resets to 0.
   - When i reaches 10 (ten consecutive accepted steps) eta is doubled and i
     resets to 0.
   The counter and constants are dual numbers because the arithmetic and
   comparison operators in scope here are the dual-number versions. *)
fun multivariate_argmin f x =
    let
      val g = gradient f
      val tolerance = Base 1e~5
      val ten = Base 10.0
      fun loop x fx gx eta i =
            if magnitude gx <= tolerance then x
            else if i <= ten andalso ten <= i then
              loop x fx gx (Base 2.0 * eta) (Base 0.0)
            else
              let
                val x' = vminus x (ktimesv eta gx)
              in
                if distance x x' <= tolerance then x
                else
                  let
                    val fx' = f x'
                  in
                    if fx' < fx
                    then loop x' fx' (g x') eta (i + Base 1.0)
                    else loop x fx gx (eta / Base 2.0) (Base 0.0)
                  end
              end
    in
      loop x (f x) (g x) (Base 1e~5) (Base 0.0)
    end
(* A local maximizer of f, found by minimizing its negation (0 - f). *)
fun multivariate_argmax f x =
    let
      fun negated y = Base 0.0 - f y
    in
      multivariate_argmin negated x
    end

(* The value of f at the local maximizer found from x. *)
fun multivariate_max f x =
    let
      val best = multivariate_argmax f x
    in
      f best
    end