-- common-ghc.hs
module Common where
-- | A pair of values of the same numeric type.
--
-- The arithmetic instances in this module treat the first field as the
-- ordinary value and update the second field with the matching
-- differentiation rule (for example the product rule in '*').
-- 'derivative' seeds the second field with 1 and reads it back from the
-- result.
--
-- The type itself carries no class constraint: datatype contexts need the
-- deprecated @DatatypeContexts@ extension, and every instance and function
-- that uses 'Bundle' already states the constraints it requires.
data Bundle a = B a a
-- | Renders a bundle as @(B first second)@, using 'show' on each field.
instance (Num a, Show a) => Show (Bundle a) where
  show (B primal tangent) =
    concat ["(B ", show primal, " ", show tangent, ")"]
-- | Embeds a plain value into a 'Bundle' whose second field is zero.
lift :: Num a => a -> Bundle a
lift primal = B primal 0
-- | Arithmetic on bundles: the first field follows the ordinary operation,
-- the second field follows the corresponding differentiation rule.
instance (Num a) => Num (Bundle a) where
  -- Integer literals become bundles with a zero second field.
  fromInteger = lift . fromInteger

  -- Sum and difference act field by field.
  (+) (B u u') (B v v') = B (u + v) (u' + v')
  (-) (B u u') (B v v') = B (u - v) (u' - v')

  -- Product rule for the second field.
  (*) (B u u') (B v v') = B (u * v) (u * v' + u' * v)

  negate (B u u') = B (negate u) (negate u')

  -- Both fields are scaled by the sign of the first field.
  abs (B u u') = B (s * u) (s * u')
    where
      s = signum u

  -- The sign depends on the first field only; the second field is zero.
  signum (B u _) = lift (signum u)
-- | Reciprocal and rational literals for bundles.  Division is not defined
-- here, so it comes from the class default.
instance Fractional a => Fractional (Bundle a) where
  -- Rational literals become bundles with a zero second field.
  fromRational = lift . fromRational

  -- Second field: minus the incoming second field times the squared
  -- reciprocal of the first field.
  recip (B u u') = B inv (negate (u' * inv * inv))
    where
      inv = recip u
-- | Equality looks at the first field only; the second field is ignored.
instance (Num a, Eq a) => Eq (Bundle a) where
  (==) (B u _) (B v _) = u == v
-- | Ordering looks at the first field only; the second field is ignored.
instance (Num a, Ord a) => Ord (Bundle a) where
  compare (B u _) (B v _) = compare u v
-- | Elementary functions on bundles.
--
-- Each method applies the function to the first field and multiplies the
-- incoming second field by the function's derivative at the first field.
-- Methods not listed here (such as @tan@, @tanh@, @**@ and @logBase@) are
-- left to the class defaults.
instance Floating a => Floating (Bundle a) where
  pi = lift pi

  -- d/dx exp x = exp x; the value is computed once and shared.
  exp (B x x') = let y = exp x in B y (x' * y)

  -- d/dx log x = 1 / x
  log (B x x') = B (log x) (x' / x)

  -- d/dx sqrt x = 1 / (2 * sqrt x)
  sqrt (B x x') = let y = sqrt x in B y (x' / (2 * y))

  -- d/dx sin x = cos x
  sin (B x x') = B (sin x) (x' * cos x)

  -- d/dx cos x = -sin x
  cos (B x x') = B (cos x) (x' * negate (sin x))

  -- d/dx asin x = 1 / sqrt (1 - x^2)
  asin (B x x') = B (asin x) (x' / sqrt (1 - x * x))

  -- d/dx atan x = 1 / (1 + x^2)
  atan (B x x') = B (atan x) (x' / (1 + x * x))

  -- d/dx acos x = -1 / sqrt (1 - x^2)
  acos (B x x') = B (acos x) (negate x' / sqrt (1 - x * x))

  -- d/dx sinh x = cosh x
  sinh (B x x') = B (sinh x) (x' * cosh x)

  -- d/dx cosh x = sinh x
  cosh (B x x') = B (cosh x) (x' * sinh x)

  -- d/dx asinh x = 1 / sqrt (x^2 + 1)
  asinh (B x x') = B (asinh x) (x' / sqrt (x * x + 1))

  -- d/dx atanh x = 1 / (1 - x^2)
  atanh (B x x') = B (atanh x) (x' / (1 - x * x))

  -- d/dx acosh x = 1 / sqrt (x^2 - 1)
  acosh (B x x') = B (acosh x) (x' / sqrt (x * x - 1))
-- | Enumeration through the first field.  Only these four methods are
-- defined; the remaining 'Enum' methods use the class defaults.
instance (Num a, Enum a) => Enum (Bundle a) where
  -- The result has a zero second field.
  toEnum = lift . toEnum

  -- The second field is ignored.
  fromEnum (B primal _) = fromEnum primal

  -- Successor and predecessor are bundle addition/subtraction of literal 1.
  succ b = b + 1
  pred b = b - 1
-- | Conversion to 'Rational' uses the first field; the second is ignored.
instance (Num a, Ord a, Real a) => Real (Bundle a) where
  toRational bundle =
    case bundle of
      B primal _ -> toRational primal
-- | Evaluates @f@ on the bundle @B x 1@ and returns the second field of the
-- result, i.e. the derivative of @f@ at @x@ as propagated by the 'Bundle'
-- instances.
derivative :: Num a => (Bundle a -> Bundle a) -> a -> a
derivative f x =
  case f (B x 1) of
    B _ tangent -> tangent
-- | Multiplies a value by itself.
sqr :: Num a => a -> a
sqr v = v * v
-- | Element-wise sum of two lists.  The result is as long as the shorter
-- argument.
vplus :: Num a => [a] -> [a] -> [a]
vplus us vs = [u + v | (u, v) <- zip us vs]
-- | Element-wise difference of two lists (first minus second).  The result
-- is as long as the shorter argument.
vminus :: Num a => [a] -> [a] -> [a]
vminus us vs = [u - v | (u, v) <- zip us vs]
-- | Multiplies every element of a list by the scalar @k@ (scalar on the
-- left of each product).
ktimesv :: Num a => a -> [a] -> [a]
ktimesv k v = [k * vi | vi <- v]
-- | Sum of the squares of the elements of a list; zero for the empty list.
magnitude_squared :: Num a => [a] -> a
magnitude_squared = sum . map sqr
-- | Square root of 'magnitude_squared'.
magnitude :: Floating a => [a] -> a
magnitude v = sqrt (magnitude_squared v)
-- | 'magnitude_squared' of the element-wise difference @u - v@.
distance_squared :: Num a => [a] -> [a] -> a
distance_squared u v = magnitude_squared difference
  where
    difference = u `vminus` v
-- | Square root of 'distance_squared'.
distance :: Floating a => [a] -> [a] -> a
distance u = sqrt . distance_squared u
-- | @replace_ith xs i xi@ returns @xs@ with the element at zero-based
-- position @i@ replaced by @xi@.
--
-- Written without an n+k pattern, which Haskell 2010 removed and which
-- needs the @NPlusKPatterns@ extension.  A negative index, or an index at
-- or past the end of the list, raises an error that names this function
-- instead of an anonymous pattern-match failure.
replace_ith :: Integral i => [a] -> i -> a -> [a]
replace_ith (_ : rest) 0 xi = xi : rest
replace_ith (x : rest) i xi
  -- Only positive indices recurse; anything else falls through to the
  -- error clause below.
  | i > 0 = x : replace_ith rest (i - 1) xi
replace_ith _ _ _ = error "replace_ith: index out of range"
-- | Gradient of @f@ at the point @x@: one 'derivative' per coordinate.
--
-- For coordinate @i@, every other coordinate is held at its lifted value
-- and @f@ is differentiated with respect to coordinate @i@ alone.
--
-- The lifted point is built once and shared by all coordinates, and each
-- coordinate value is paired with its index rather than looked up with
-- @!!@.
gradient :: Num a => ([Bundle a] -> Bundle a) -> [a] -> [a]
gradient f x = zipWith partial indices x
  where
    -- Zero-based coordinate positions.
    indices = [0 ..] :: [Int]
    -- The point with every coordinate lifted (zero second field).
    lifted = map lift x
    -- Derivative of f along coordinate i, evaluated at that coordinate's
    -- current value xi.
    partial i xi = derivative (\ xi' -> f (replace_ith lifted i xi')) xi
-- | Runs a bundle-valued function on plain inputs: lifts every input,
-- applies @f@, and returns the first field of the result.
lower_fs :: Num a => ([Bundle a] -> Bundle a) -> [a] -> a
lower_fs f xs =
  case f (map lift xs) of
    B primal _ -> primal
-- | Searches for a minimiser of @f@ by gradient descent, starting at @x@.
--
-- The step size starts at @1e-5@.  A candidate point is the current point
-- minus the step size times the gradient.  The search returns the current
-- point when the gradient magnitude is at most @1e-5@, or when the
-- candidate lies within @1e-5@ of the current point.  A candidate that
-- lowers the objective is accepted; after ten accepted steps in a row the
-- step size is doubled.  A candidate that does not lower the objective is
-- rejected and the step size is halved.
multivariate_argmin :: (Floating a, Ord a) =>
  ([Bundle a] -> Bundle a) -> [a] -> [a]
multivariate_argmin f x = descend x (objective x) (grad x) 1e-5 0
  where
    grad = gradient f
    objective = lower_fs f

    -- descend point value gradientAtPoint stepSize acceptedInARow
    descend p fp gp eta streak
      | magnitude gp <= 1e-5 = p
      -- Ten accepted steps in a row: double the step and reset the count.
      | streak == 10 = descend p fp gp (2 * eta) 0
      | distance p candidate <= 1e-5 = p
      | fCandidate < fp =
          descend candidate fCandidate (grad candidate) eta (streak + 1)
      -- No improvement: stay put, halve the step and reset the count.
      | otherwise = descend p fp gp (eta / 2) 0
      where
        candidate = vminus p (ktimesv eta gp)
        fCandidate = objective candidate
-- | Searches for a maximiser of @f@ by running 'multivariate_argmin' on the
-- negated function from the same starting point.
multivariate_argmax :: (Floating a, Ord a) =>
  ([Bundle a] -> Bundle a) -> [a] -> [a]
multivariate_argmax f x = multivariate_argmin (negate . f) x
-- | Value of @f@ at the point returned by 'multivariate_argmax'.
multivariate_max :: (Floating a, Ord a) =>
  ([Bundle a] -> Bundle a) -> [a] -> a
multivariate_max f x = lower_fs f best
  where
    best = multivariate_argmax f x
-- Generated by GNU enscript 1.6.4.