(* common-mlton.sml *)

(* A dual number used for forward-mode differentiation.
   - Base r            : a plain real number r.
   - Bundle (k, p, t)  : a value with primal part p and perturbation
                         (tangent) part t, tagged with the integer k that
                         `derivative` allocates for each invocation.
   Both components of a Bundle are themselves dual numbers, so
   perturbations can be nested. *)
datatype dual_number =
    Bundle of int * dual_number * dual_number
  | Base of real

(* Raised by replace_ith when the requested index is not in the list. *)
exception error

(* replace_ith xs i y: a copy of xs whose element at 0-based position i is
   replaced by y. Raises `error` when xs has no element at position i
   (this includes every negative i). *)
fun replace_ith [] _ _ = raise error
  | replace_ith (_ :: rest) 0 y = y :: rest
  | replace_ith (head :: rest) k y = head :: replace_ith rest (k - 1) y

(* The perturbation tag of a dual number; a plain Base value reports 0. *)
fun epsilon d =
    case d of
        Bundle (tag, _, _) => tag
      | Base _ => 0

(* primal e d: the part of d that carries no perturbation tagged e.
   A bundle whose tag is below e is returned whole; otherwise its primal
   component is returned. Base values are returned unchanged. *)
fun primal e d =
    case d of
        Base _ => d
      | Bundle (tag, p, _) => if tag >= e then p else d

(* perturbation e d: the coefficient of the perturbation tagged e in d.
   A bundle whose tag is below e, and any Base value, contribute Base 0.0;
   otherwise the bundle's tangent component is returned. *)
fun perturbation e d =
    case d of
        Base _ => Base 0.0
      | Bundle (tag, _, t) => if tag >= e then t else Base 0.0

(* Lift a unary real function f, with derivative dfdx, to dual numbers.
   `op *` is the multiplication used for the chain-rule product.
   Base r maps to Base (f r). Bundle (k, p, t) maps to
   Bundle (k, lift p, dfdx p * t), recursing into the primal so that nested
   perturbations are handled. *)
fun lift_real_to_real f dfdx op * =
    let fun self (Base r) = Base (f r)
          | self (Bundle (tag, p, t)) = Bundle (tag, self p, dfdx p * t)
    in self end

(* Lift a binary real function f to dual numbers.
   dfdx1 / dfdx2 give the partial derivatives with respect to the first and
   second argument (curried, applied to the two primals). `op +` and `op *`
   combine the tangent terms.
   Two Base values are combined directly with f. Otherwise the working tag
   is the larger bundle tag (or the only bundle's tag). The result is a
   bundle with that tag whose primal is the lifted f of both primals and
   whose tangent is
     dfdx1 p1 p2 * t1 + dfdx2 p1 p2 * t2,
   where pi / ti are the primal / perturbation of each argument at that tag. *)
fun lift_real_cross_real_to_real f dfdx1 dfdx2 op + op * =
    let fun self (Base r1, Base r2) = Base (f (r1, r2))
          | self (u as Bundle (k1, _, _), v as Bundle (k2, _, _)) =
            combine (Int.max (k1, k2)) u v
          | self (u as Bundle (k1, _, _), v as Base _) = combine k1 u v
          | self (u as Base _, v as Bundle (k2, _, _)) = combine k2 u v
        (* Build the result bundle for working tag k. *)
        and combine k u v =
            let val p1 = primal k u
                val p2 = primal k v
            in
                Bundle (k,
                        self (p1, p2),
                        dfdx1 p1 p2 * perturbation k u
                        + dfdx2 p1 p2 * perturbation k v)
            end
    in self end

(* Lift a real comparison f to dual numbers. Every perturbation is
   discarded: each argument is reduced to its innermost Base value, and f
   is applied to the two resulting reals. *)
fun lift_real_cross_real_to_bool f =
    let fun strip (Base r) = r
          | strip (Bundle (_, p, _)) = strip p
        fun self (u, v) = f (strip u, strip v)
    in self end

(* Counter that gives each active call of `derivative` a fresh perturbation
   tag. Note: the later `open Real.Math` rebinds the name `e` (Real.Math
   exports Euler's constant `e`), so code after that line cannot refer to
   this ref by name. *)
val e = ref 0

(* derivative f x: the derivative of f at x, computed in forward mode.
   A fresh tag k is allocated, f is evaluated on Bundle (k, x, Base 1.0),
   and the k-perturbation of the result is returned.
   The counter is restored on exit. This includes the case where f raises:
   the counter is decremented and the same exception is re-raised, so the
   shared state does not drift. *)
fun derivative f x =
    let val () = e := !e + 1
        val tag = !e
        val result =
            perturbation tag (f (Bundle (tag, x, Base 1.0)))
            handle exn => (e := !e - 1; raise exn)
    in
        e := !e - 1;
        result
    end

(* Print the innermost real of a dual number, followed by a newline, and
   return the argument unchanged. All perturbation components are ignored. *)
fun write d =
    let fun innermost (Base r) = r
          | innermost (Bundle (_, p, _)) = innermost p
    in
        print (Real.toString (innermost d));
        print "\n";
        d
    end

open Real.Math

(* Rebind +, -, *, /, sqrt, < and <= to work on dual_number.
   The real-valued originals are captured first. Each new operator is then
   the lifted real operator paired with its partial derivatives. The
   definitions are mutually recursive ("and"), so the derivative expressions
   (e.g. 1/v for division, 1/(2 sqrt u) for sqrt) use the dual-number
   operators themselves. Comparisons ignore perturbations. *)
val (op +, op -, op *, op /, sqrt, op <, op <=) =
    let val real_add = op +
        val real_sub = op -
        val real_mul = op *
        val real_div = op /
        val real_sqrt = sqrt
        val real_lt = op <
        val real_le = op <=
        (* Constant partial derivatives shared by + and -. *)
        fun const_one _ _ = Base 1.0
        fun const_minus_one _ _ = Base ~1.0
        fun op + (a, b) =
            lift_real_cross_real_to_real
                real_add const_one const_one op + op * (a, b)
        and op - (a, b) =
            lift_real_cross_real_to_real
                real_sub const_one const_minus_one op + op * (a, b)
        and op * (a, b) =
            lift_real_cross_real_to_real
                real_mul
                (fn _ => fn v => v)
                (fn u => fn _ => u)
                op + op * (a, b)
        and op / (a, b) =
            lift_real_cross_real_to_real
                real_div
                (fn _ => fn v => Base 1.0 / v)
                (fn u => fn v => Base 0.0 - u / (v * v))
                op + op * (a, b)
        and sqrt a =
            lift_real_to_real
                real_sqrt
                (fn u => Base 1.0 / (sqrt u + sqrt u))
                op * a
        and op < (a, b) = lift_real_cross_real_to_bool real_lt (a, b)
        and op <= (a, b) = lift_real_cross_real_to_bool real_le (a, b)
    in (op +, op -, op *, op /, sqrt, op <, op <=) end

(* Square of a dual number, using the overloaded op *. *)
fun sqr v = v * v

(* Componentwise sum of two vectors. As with ListPair.map, extra elements
   of the longer list are ignored. *)
fun vplus u v = ListPair.map (fn (a, b) => a + b) (u, v)

(* Componentwise difference u - v of two vectors. As with ListPair.map,
   extra elements of the longer list are ignored. *)
fun vminus u v = ListPair.map (fn (a, b) => a - b) (u, v)

(* Multiply every component of a vector by the scalar k. *)
fun ktimesv k v = List.map (fn c => k * c) v

(* Sum of the squared components of a vector; Base 0.0 for the empty vector. *)
fun magnitude_squared v =
    foldl (fn (c, acc) => sqr c + acc) (Base 0.0) v

(* Euclidean length of a vector. *)
fun magnitude v =
    let val squared = magnitude_squared v
    in sqrt squared end

(* Squared Euclidean distance between two vectors, computed from v - u. *)
fun distance_squared u v =
    let val diff = vminus v u
    in magnitude_squared diff end

(* Euclidean distance between two vectors. *)
fun distance u v =
    let val d2 = distance_squared u v
    in sqrt d2 end

(* gradient f x: the list of partial derivatives of f at point x, one per
   coordinate. Each entry is obtained with `derivative` by varying that
   coordinate alone while the others stay fixed. Entries are computed in
   index order 0 .. length x - 1. *)
fun gradient f x =
    let fun partial i =
            let fun vary xi = f (replace_ith x i xi)
            in derivative vary (List.nth (x, i)) end
    in List.tabulate (length x, partial) end

(* Minimise f by gradient descent with an adaptive step size, starting at x.
   - The step eta starts at 1e-5.
   - eta is halved (and the success streak reset) whenever a step fails to
     decrease f.
   - eta is doubled (and the streak reset) once ten consecutive steps have
     succeeded.
   - The search stops, returning the current point, when the gradient
     magnitude or the proposed step length is at most 1e-5. *)
fun multivariate_argmin f x =
    let val grad = gradient f
        val tolerance = Base 1e~5
        val initial_eta = Base 1e~5
        val streak_limit = Base 10.0
        fun descend point value slope eta streak =
            if magnitude slope <= tolerance then point
            else if streak <= streak_limit andalso streak_limit <= streak then
                (* streak = 10; dual_number holds reals and so has no
                   equality, hence the two-sided comparison *)
                descend point value slope (Base 2.0 * eta) (Base 0.0)
            else
                let val candidate = vminus point (ktimesv eta slope)
                in
                    if distance point candidate <= tolerance then point
                    else
                        let val candidate_value = f candidate
                        in
                            if candidate_value < value
                            then descend candidate candidate_value
                                         (grad candidate) eta
                                         (streak + Base 1.0)
                            else descend point value slope
                                         (eta / Base 2.0) (Base 0.0)
                        end
                end
    in descend x (f x) (grad x) initial_eta (Base 0.0) end

(* Maximise f by minimising its negation, Base 0.0 - f. *)
fun multivariate_argmax f x =
    let fun negated y = Base 0.0 - f y
    in multivariate_argmin negated x end

(* The value of f at the point found by multivariate_argmax from x. *)
fun multivariate_max f x =
    let val best = multivariate_argmax f x
    in f best end

(* Generated by GNU enscript 1.6.4. *)