In the code below, we build a function representing the second derivative of (x + 2)(x + 1) and evaluate it at x = 3; the expected result is 2.

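For background, see Automatic Differentiation. The code uses forward-mode differentiation with dual numbers a + bε, where ε² = 0. Evaluating a function at x + 1ε yields f(x) + f′(x)ε, and nesting Dual inside Dual yields higher derivatives.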

This example will not work with Scala release 2.7.0 or earlier; build 2.7.0.r14548-b20080408080028 or later is required.

 
/**
 * Minimal arithmetic interface shared by the scalar types in this example.
 *
 * The type parameter is F-bounded (`T <: Numeric[T]`) so that `+` and `*`
 * take and return the concrete implementing type, not the trait itself.
 */
trait Numeric[T <: Numeric[T]] {
    /** Product of this value and `y`. */
    def *(y: T): T

    /** Sum of this value and `y`. */
    def +(y: T): T

    // Room for further operations (-, /, exp, pow, log, trig functions);
    // each one would need a matching implementation in every subclass.
}
 
/**
 * A real number wrapped so it can act as the scalar type `T` of `Dual[T]`
 * and of the differentiation helpers.
 */
case class Real(a: Double) extends Numeric[Real] {
    /** Ordinary addition of the wrapped doubles. */
    def +(that: Real): Real = new Real(a + that.a)

    /** Ordinary multiplication of the wrapped doubles. */
    def *(that: Real): Real = new Real(a * that.a)
}
 
/**
 * Dual number `a + b·ε` over the scalar type `T`, where `ε² = 0`.
 *
 * When a function is evaluated at `x + 1·ε`, the `b` component of the result
 * carries the derivative. `T` may itself be a `Dual`, which allows nesting
 * for higher derivatives.
 */
case class Dual[T <: Numeric[T]](a: T, b: T) extends Numeric[Dual[T]] {
    /** Component-wise sum: (a + bε) + (c + dε) = (a + c) + (b + d)ε. */
    def +(that: Dual[T]): Dual[T] =
        new Dual[T](a + that.a, b + that.b)

    /** Product rule: (a + bε)(c + dε) = ac + (bc + ad)ε, because ε² = 0. */
    def *(that: Dual[T]): Dual[T] = {
        val real = a * that.a
        val infinitesimal = b * that.a + a * that.b
        new Dual[T](real, infinitesimal)
    }
}
 
/**
 * Implicit conversions from Double constants, plus the forward-mode
 * differentiation operator `diff`.
 */
object Numeric {
    /** Lifts a Double into a Real. */
    implicit def double2Real(x: Double): Real = Real(x)

    /**
     * Lifts a Double into a constant dual number. The value part is `x`
     * converted to `T` through `v`, and the derivative part is zero,
     * because a constant does not vary with the input.
     */
    implicit def double2Dual[T <: Numeric[T]](x: Double)(implicit v: Double => T): Dual[T] =
        Dual[T](v(x), v(0.0))

    /**
     * Derivative of `fn` at `x`.
     *
     * Evaluates `fn` at the dual number `x + 1·ε` (seed derivative dx/dx = 1)
     * and returns the ε-coefficient of the result. `T` may itself be a
     * `Dual[...]`, so calls to `diff` can be nested for higher derivatives.
     *
     * @param fn function to differentiate, evaluated over dual numbers
     * @param x  point at which to differentiate
     * @param v  conversion used to build the seed value 1 in type `T`
     * @return   the derivative of `fn` at `x`
     */
    def diff[T <: Numeric[T]](fn: Dual[T] => Dual[T])(x: T)(implicit v: Double => T): T =
        fn(Dual[T](x, v(1.0))).b
}
 
/**
 * Self-checking example: differentiates f(x) = (x + 2)(x + 1) once and
 * twice using nested dual numbers, and asserts the results.
 */
object Test {
    import Numeric._

    // an arbitrary function: (x + 2)(x + 1) = x^2 + 3x + 2
    def f[T <: Numeric[T]](x: T)(implicit v: Double => T): T =
        (x + 2.0) * (x + 1.0)

    // derivative of f, i.e. 2x + 3 (f is evaluated over Dual[T])
    def g[T <: Numeric[T]](x: T)(implicit v: Double => T): T =
        diff[T](f)(x)

    // second derivative of f, i.e. 2 (diff is nested, so f runs over Dual[Dual[T]])
    def h[T <: Numeric[T]](x: T)(implicit v: Double => T): T =
        diff[T](g)(x)

    /** Fails with a descriptive message unless `actual` is within a small tolerance of `expected`. */
    def check(actual: Real, expected: Real): Unit = {
        val tolerance = 0.0001
        val error = java.lang.Math.abs(actual.a - expected.a)
        assert(error < tolerance, "expected " + expected.a + " but got " + actual.a)
    }

    // An explicit entry point replaces `extends Application`, whose body ran
    // inside the object initializer and which later Scala releases
    // deprecated and removed.
    def main(args: Array[String]): Unit = {
        check(f[Real](3.0), 20.0)           // (3 + 2)(3 + 1)
        check(g[Real](3.0), 9.0)            // 2*3 + 3
        check(diff[Real](f)(20.0), 43.0)    // 2*20 + 3
        check(h[Real](3.0), 2.0)            // constant second derivative
    }
}
 
code/automatic-differentiation.txt · Last modified: 2008/04/08 16:18 by ericwilligers
 
Recent changes RSS feed Valid XHTML 1.0 Driven by DokuWiki