;;; common-chicken.sc
;; Compiler declarations for this file.
;; NOTE(review): meanings below follow the CHICKEN/Gambit declaration docs;
;; confirm against the compiler actually used.
;;   (block)                   - top-level definitions are not redefined elsewhere
;;   (standard-bindings)       - standard procedures are assumed not redefined
;;   (extended-bindings)       - same, for the implementation's extensions
;;   (not safe)                - runtime type checks are omitted, so e.g. `car`
;;                               on a non-pair is not caught
;;   (not interrupts-enabled)  - interrupt checks are omitted
(declare (block)
(standard-bindings)
(extended-bindings)
(not safe)
(not interrupts-enabled))
;; Ordering on perturbation tags (epsilons).  Tags are the integers produced
;; by `derivative`'s counter, so the plain numeric `<` is used directly.
(define <_e <)
;; Predicate: is OBJ a dual number, i.e. a list headed by the symbol
;; `dual-number` (the representation built by `dual-number`)?
;; `pair?` is captured once in a local binding.
(define dual-number?
  (let ((is-pair? pair?))
    (lambda (obj)
      (if (is-pair? obj)
          (eq? (car obj) 'dual-number)
          #f))))
;; Construct the dual number  x + x' * epsilon_e  with perturbation tag E,
;; represented as the list (dual-number e x x-prime).
;; If the tangent X-PRIME is exactly zero, the result is just the primal X,
;; so no dual number with a zero tangent is ever allocated.
;;
;; The tangent counts as zero only when it is a plain (non-dual) number equal
;; to zero.  A dual-number tangent is never zero, because every dual number
;; built here carries a nonzero tangent.  Testing the innermost primal instead
;; (e.g. via `dzero?`/`primal*`) would discard tangents such as
;; (dual-number 1 0 1) and break nested derivatives.  For example,
;; d/dx [ d/dy (x*y) at y=1 ] at x=0 must be 1.
(define (dual-number e x x-prime)
  (if (and (not (dual-number? x-prime)) (zero? x-prime))
      x
      (list 'dual-number e x x-prime)))
;; Perturbation tag of a dual number: the second element of its list
;; representation (dual-number e x x-prime).
(define (epsilon p)
  (car (cdr p)))
;; Primal (value) part of P with respect to tag E.
;; A non-dual P, or a dual number whose tag is below E, is constant with
;; respect to E and is returned unchanged; otherwise return its x component.
(define (primal e p)
  (cond ((not (dual-number? p)) p)
        ((<_e (epsilon p) e) p)
        (else (caddr p))))
;; Tangent (perturbation) part of P with respect to tag E.
;; A non-dual P, or a dual number whose tag is below E, has zero tangent with
;; respect to E; otherwise return its x-prime component.
(define (perturbation e p)
  (cond ((not (dual-number? p)) 0)
        ((<_e (epsilon p) e) 0)
        (else (cadddr p))))
;; Lift a unary real function F with derivative DF/DX to dual numbers.
;; For a dual argument P with tag e, the result is
;;   (dual-number e (lifted (primal e P))
;;                  (d* (df/dx (primal e P)) (perturbation e P)))
;; The recursion on the primal handles perturbations with lower tags.
;; Non-dual arguments are passed straight to F.
(define (lift-real->real f df/dx)
  (define (lifted p)
    (if (not (dual-number? p))
        (f p)
        (let* ((e (epsilon p))
               (x (primal e p)))
          (dual-number e
                       (lifted x)
                       (d* (df/dx x) (perturbation e p))))))
  lifted)
;; Lift a binary real function F, with partial derivatives DF/DX1 and DF/DX2
;; (each a procedure of both primal arguments), to dual numbers.
;; When either argument is dual, work with respect to the highest tag e
;; present.  The result is
;;   (dual-number e (lifted x1 x2)
;;                  (+ (* df/dx1 dp1) (* df/dx2 dp2)))
;; where x1, x2 are the primals and dp1, dp2 the perturbations w.r.t. e.
(define (lift-real*real->real f df/dx1 df/dx2)
  ;; Highest tag among the dual arguments (at least one is dual here).
  (define (top-tag p1 p2)
    (cond ((not (dual-number? p1)) (epsilon p2))
          ((and (dual-number? p2) (<_e (epsilon p1) (epsilon p2))) (epsilon p2))
          (else (epsilon p1))))
  (define (lifted p1 p2)
    (if (not (or (dual-number? p1) (dual-number? p2)))
        (f p1 p2)
        (let* ((e (top-tag p1 p2))
               (x1 (primal e p1))
               (x2 (primal e p2)))
          (dual-number
           e
           (lifted x1 x2)
           (d+ (d* (df/dx1 x1 x2) (perturbation e p1))
               (d* (df/dx2 x1 x2) (perturbation e p2)))))))
  lifted)
;; Strip every layer of perturbation from P and return the underlying value.
(define (primal* p)
  (let strip ((q p))
    (if (dual-number? q)
        (strip (primal (epsilon q) q))
        q)))
;; Lift an n-ary real predicate F so it accepts dual numbers.
;; The predicate sees only the fully unperturbed values (via `primal*`).
(define (lift-real^n->boolean f)
  (lambda operands
    (let ((reals (map primal* operands)))
      (apply f reals))))
;; Predicate: is X an ordinary pair, i.e. a pair that is not a dual number?
;; `pair?` is captured once in a local binding.
(define dpair?
  (let ((is-pair? pair?))
    (lambda (x)
      (cond ((not (is-pair? x)) #f)
            ((dual-number? x) #f)
            (else #t)))))
;; Lifted + - *.  Each partial derivative is a procedure of both primal
;; arguments (a b).
(define d+
  (lift-real*real->real
   +
   (lambda (a b) 1)      ; d(a+b)/da
   (lambda (a b) 1)))    ; d(a+b)/db
(define d-
  (lift-real*real->real
   -
   (lambda (a b) 1)      ; d(a-b)/da
   (lambda (a b) -1)))   ; d(a-b)/db
(define d*
  (lift-real*real->real
   *
   (lambda (a b) b)      ; d(a*b)/da
   (lambda (a b) a)))    ; d(a*b)/db
;; Lifted division.  The partials are themselves written with lifted
;; operations, so they can be differentiated again.
(define d/
  (lift-real*real->real
   /
   (lambda (a b) (d/ 1 b))                   ; d(a/b)/da = 1/b
   (lambda (a b) (d- 0 (d/ a (d* b b))))))   ; d(a/b)/db = -a/b^2
;; Lifted elementary functions.  Each derivative is expressed with lifted
;; operations, so it can itself be differentiated.
(define dsqrt
  (lift-real->real sqrt (lambda (u) (d/ 1 (d* 2 (dsqrt u))))))   ; 1/(2 sqrt u)
(define dexp
  (lift-real->real exp (lambda (u) (dexp u))))                   ; exp u
(define dlog
  (lift-real->real log (lambda (u) (d/ 1 u))))                   ; 1/u
(define dsin
  (lift-real->real sin (lambda (u) (dcos u))))                   ; cos u
(define dcos
  (lift-real->real cos (lambda (u) (d- 0 (dsin u)))))            ; -sin u
;; Lifted two-argument arctangent: (datan y x) corresponds to Scheme's
;; (atan y x), the angle of the point (x, y).  With r^2 = y^2 + x^2:
;;   d/dy atan(y, x) =  x / r^2     (x1 = y, x2 = x)
;;   d/dx atan(y, x) = -y / r^2
;; Sanity check: (atan y 1) = atan(y), whose derivative at y = 0 is 1.
(define datan
  (lift-real*real->real
   atan
   (lambda (x1 x2) (d/ x2 (d+ (d* x1 x1) (d* x2 x2))))
   (lambda (x1 x2) (d/ (d- 0 x1) (d+ (d* x1 x1) (d* x2 x2))))))
;; Lifted comparisons and numeric predicates.  Each applies the standard
;; procedure to the fully unperturbed values of its arguments
;; (lift-real^n->boolean maps primal* over them), so results never depend on
;; tangents.
(define d= (lift-real^n->boolean =))
(define d< (lift-real^n->boolean <))
(define d> (lift-real^n->boolean >))
(define d<= (lift-real^n->boolean <=))
(define d>= (lift-real^n->boolean >=))
(define dzero? (lift-real^n->boolean zero?))
(define dpositive? (lift-real^n->boolean positive?))
(define dnegative? (lift-real^n->boolean negative?))
(define dreal? (lift-real^n->boolean real?))
;; (derivative f) returns a procedure computing f'(x) by forward-mode AD.
;; All derivative procedures share one tag counter.  Each call:
;;   1. increments it to get a fresh tag, higher than any active tag, so
;;      nested calls get increasing tags;
;;   2. seeds X with tangent 1 under that tag and applies F;
;;   3. reads back the tangent of the result under that tag;
;;   4. decrements the counter again.
(define derivative
  (let ((tag 0))
    (lambda (f)
      (lambda (x)
        (set! tag (d+ tag 1))
        (let* ((seeded (dual-number tag x 1))
               (y (f seeded))
               (dy (perturbation tag y)))
          (set! tag (d- tag 1))
          dy)))))
;; Write X and a newline to the current output port, then return X, so the
;; call can wrap any expression for tracing.
(define (my-write x)
  (let ((value x))
    (write value)
    (newline)
    value))
;; Lifted square: x * x.
(define sqr
  (lambda (x)
    (d* x x)))
;; ((map-n f) n) => (list (f 0) (f 1) ... (f (- n 1))), using lifted
;; arithmetic for the index.
(define (map-n f)
  (lambda (n)
    (let build ((i 0))
      (if (d= i n)
          '()
          (cons (f i) (build (d+ i 1)))))))
;; ((reduce f i) l) is a right fold: (f l0 (f l1 (... (f lk i)))).
;; An empty list yields I.
(define (reduce f i)
  (letrec ((fold (lambda (l)
                   (if (null? l)
                       i
                       (f (car l) (fold (cdr l)))))))
    fold))
;; Elementwise vector operations on lists of (possibly dual) numbers.
(define (v+ u v)
  (map (lambda (ui vi) (d+ ui vi)) u v))
(define (v- u v)
  (map (lambda (ui vi) (d- ui vi)) u v))
(define (k*v k v)
  (let ((scale (lambda (vi) (d* k vi))))
    (map scale v)))
;; Euclidean norms and distances on lists of (possibly dual) numbers.
;; magnitude-squared sums the squares starting from 0.0.
;; NOTE(review): `magnitude` shadows the standard procedure of the same name
;; while the file declares (standard-bindings); confirm the compiler binds
;; call sites to this definition rather than the builtin.
(define (magnitude-squared x)
  (let ((squares (map sqr x)))
    ((reduce d+ 0.0) squares)))
(define (magnitude x)
  (dsqrt (magnitude-squared x)))
(define (distance-squared u v)
  (let ((delta (v- v u)))
    (magnitude-squared delta)))
(define (distance u v)
  (dsqrt (distance-squared u v)))
;; Return a copy of list X with element I replaced by XI.  Cells before
;; index I are copied; the tail after it is shared.
(define (replace-ith x i xi)
  (if (not (dzero? i))
      (let ((rest (replace-ith (cdr x) (d- i 1) xi)))
        (cons (car x) rest))
      (cons xi (cdr x))))
;; (gradient f) returns a procedure mapping a point X (a list) to the list of
;; partial derivatives of F at X.  Partial i is computed as the derivative of
;; F restricted to coordinate i, evaluated at X's i-th component.
(define (gradient f)
  (lambda (x)
    (let ((partial
           (lambda (i)
             (let ((f-along-i (lambda (xi) (f (replace-ith x i xi)))))
               ((derivative f-along-i) (list-ref x i))))))
      ((map-n partial) (length x)))))
;; Gradient descent with an adaptive step size, starting at X with eta = 1e-5.
;; The search stops and returns the current point when either
;;   - the gradient norm is <= 1e-5, or
;;   - the proposed step moves the point by <= 1e-5.
;; An improving step is accepted.  After 10 consecutive accepted steps, eta
;; doubles and the streak counter resets.  A non-improving step is rejected,
;; eta halves and the counter resets.
(define (multivariate-argmin f x)
  (let ((g (gradient f)))
    (let descend ((pt x) (f-pt (f x)) (grad (g x)) (eta 1e-5) (streak 0))
      (cond
       ((d<= (magnitude grad) 1e-5) pt)
       ((d= streak 10) (descend pt f-pt grad (d* 2.0 eta) 0))
       (else
        (let ((candidate (v- pt (k*v eta grad))))
          (cond
           ((d<= (distance pt candidate) 1e-5) pt)
           (else
            (let ((f-candidate (f candidate)))
              (if (d< f-candidate f-pt)
                  (descend candidate f-candidate (g candidate) eta (d+ streak 1))
                  (descend pt f-pt grad (d/ eta 2.0) 0)))))))))))
;; Point maximizing F, found by minimizing 0.0 - F from starting point X.
(define (multivariate-argmax f x)
  (let ((neg-f (lambda (y) (d- 0.0 (f y)))))
    (multivariate-argmin neg-f x)))
;; Value of F at the point returned by multivariate-argmax from X.
(define (multivariate-max f x)
  (let ((best (multivariate-argmax f x)))
    (f best)))
;;; Generated by GNU enscript 1.6.4.