// levenberg-marquardt in java // // To use this, implement the functions in the LMfunc interface. // // This library uses simple matrix routines from the JAMA java matrix package, // which is in the public domain. Reference: // http://math.nist.gov/javanumerics/jama/ // (JAMA has a matrix object class. An earlier library JNL, which is no longer // available, represented matrices as low-level arrays. Several years // ago the performance of JNL matrix code was better than that of JAMA, // though improvements in java compilers may have fixed this by now.) // // One further recommendation would be to use an inverse based // on Choleski decomposition, which is easy to implement and // suitable for the symmetric inverse required here. There is a choleski // routine at idiom.com/~zilla. // // If you make an improved version, please consider adding your // name to it ("modified by ...") and send it back to me // (and put it on the web). // // ---------------------------------------------------------------- // // This library is free software; you can redistribute it and/or // modify it under the terms of the GNU Library General Public // License as published by the Free Software Foundation; either // version 2 of the License, or (at your option) any later version. // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Library General Public License for more details. // // You should have received a copy of the GNU Library General Public // License along with this library; if not, write to the // Free Software Foundation, Inc., 59 Temple Place - Suite 330, // Boston, MA 02111-1307, USA. 
// // initial author contact info: // jplewis www.idiom.com/~zilla zilla # computer.org, #=at // // Improvements by: // dscherba www.ncsa.uiuc.edu/~dscherba // Jonathan Jackson j.jackson # ucl.ac.uk package ZS.Solve; // see comment above import Jama.*; /** * Levenberg-Marquardt, implemented from the general description * in Numerical Recipes (NR), then tweaked slightly to mostly * match the results of their code. * Use for nonlinear least squares assuming Gaussian errors. * * TODO this holds some parameters fixed by simply not updating them. * this may be ok if the number if fixed parameters is small, * but if the number of varying parameters is larger it would * be more efficient to make a smaller hessian involving only * the variables. * * The NR code assumes a statistical context, e.g. returns * covariance of parameter errors; we do not do this. */ public final class LM { /** * calculate the current sum-squared-error * (Chi-squared is the distribution of squared Gaussian errors, * thus the name) */ static double chiSquared(double[][] x, double[] a, double[] y, double[] s, LMfunc f) { int npts = y.length; double sum = 0.; for( int i = 0; i < npts; i++ ) { double d = y[i] - f.val(x[i], a); d = d / s[i]; sum = sum + (d*d); } return sum; } //chiSquared /** * Minimize E = sum {(y[k] - f(x[k],a)) / s[k]}^2 * The individual errors are optionally scaled by s[k]. * Note that LMfunc implements the value and gradient of f(x,a), * NOT the value and gradient of E with respect to a! * * @param x array of domain points, each may be multidimensional * @param y corresponding array of values * @param a the parameters/state of the model * @param vary false to indicate the corresponding a[k] is to be held fixed * @param s2 sigma^2 for point i * @param lambda blend between steepest descent (lambda high) and * jump to bottom of quadratic (lambda zero). * Start with 0.001. 
* @param termepsilon termination accuracy (0.01) * @param maxiter stop and return after this many iterations if not done * @param verbose set to zero (no prints), 1, 2 * * @return the new lambda for future iterations. * Can use this and maxiter to interleave the LM descent with some other * task, setting maxiter to something small. */ public static double solve(double[][] x, double[] a, double[] y, double[] s, boolean[] vary, LMfunc f, double lambda, double termepsilon, int maxiter, int verbose) throws Exception { int npts = y.length; int nparm = a.length; assert s.length == npts; assert x.length == npts; if (verbose > 0) { System.out.print("solve x["+x.length+"]["+x[0].length+"]" ); System.out.print(" a["+a.length+"]"); System.out.println(" y["+y.length+"]"); } double e0 = chiSquared(x, a, y, s, f); //double lambda = 0.001; boolean done = false; // g = gradient, H = hessian, d = step to minimum // H d = -g, solve for d double[][] H = new double[nparm][nparm]; double[] g = new double[nparm]; //double[] d = new double[nparm]; double[] oos2 = new double[s.length]; for( int i = 0; i < npts; i++ ) oos2[i] = 1./(s[i]*s[i]); int iter = 0; int term = 0; // termination count test do { ++iter; // hessian approximation for( int r = 0; r < nparm; r++ ) { for( int c = 0; c < nparm; c++ ) { for( int i = 0; i < npts; i++ ) { if (i == 0) H[r][c] = 0.; double[] xi = x[i]; H[r][c] += (oos2[i] * f.grad(xi, a, r) * f.grad(xi, a, c)); } //npts } //c } //r // boost diagonal towards gradient descent for( int r = 0; r < nparm; r++ ) H[r][r] *= (1. 
+ lambda); // gradient for( int r = 0; r < nparm; r++ ) { for( int i = 0; i < npts; i++ ) { if (i == 0) g[r] = 0.; double[] xi = x[i]; g[r] += (oos2[i] * (y[i]-f.val(xi,a)) * f.grad(xi, a, r)); } } //npts // scale (for consistency with NR, not necessary) if (false) { for( int r = 0; r < nparm; r++ ) { g[r] = -0.5 * g[r]; for( int c = 0; c < nparm; c++ ) { H[r][c] *= 0.5; } } } // solve H d = -g, evaluate error at new location //double[] d = DoubleMatrix.solve(H, g); double[] d = (new Matrix(H)).lu().solve(new Matrix(g, nparm)).getRowPackedCopy(); //double[] na = DoubleVector.add(a, d); double[] na = (new Matrix(a, nparm)).plus(new Matrix(d, nparm)).getRowPackedCopy(); double e1 = chiSquared(x, na, y, s, f); if (verbose > 0) { System.out.println("\n\niteration "+iter+" lambda = "+lambda); System.out.print("a = "); (new Matrix(a, nparm)).print(10, 2); if (verbose > 1) { System.out.print("H = "); (new Matrix(H)).print(10, 2); System.out.print("g = "); (new Matrix(g, nparm)).print(10, 2); System.out.print("d = "); (new Matrix(d, nparm)).print(10, 2); } System.out.print("e0 = " + e0 + ": "); System.out.print("moved from "); (new Matrix(a, nparm)).print(10, 2); System.out.print("e1 = " + e1 + ": "); if (e1 < e0) { System.out.print("to "); (new Matrix(na, nparm)).print(10, 2); } else { System.out.println("move rejected"); } } // termination test (slightly different than NR) if (Math.abs(e1-e0) > termepsilon) { term = 0; } else { term++; if (term == 4) { System.out.println("terminating after " + iter + " iterations"); done = true; } } if (iter >= maxiter) done = true; // in the C++ version, found that changing this to e1 >= e0 // was not a good idea. See comment there. 
// if (e1 > e0 || Double.isNaN(e1)) { // new location worse than before lambda *= 10.; } else { // new location better, accept new parameters lambda *= 0.1; e0 = e1; // simply assigning a = na will not get results copied back to caller for( int i = 0; i < nparm; i++ ) { if (vary[i]) a[i] = na[i]; } } } while(!done); return lambda; } //solve //---------------------------------------------------------------- /** * solve for phase, amplitude and frequency of a sinusoid */ static class LMSineTest implements LMfunc { static final int PHASE = 0; static final int AMP = 1; static final int FREQ = 2; public double[] initial() { double[] a = new double[3]; a[PHASE] = 0.; a[AMP] = 1.; a[FREQ] = 1.; return a; } //initial public double val(double[] x, double[] a) { return a[AMP] * Math.sin(a[FREQ]*x[0] + a[PHASE]); } //val public double grad(double[] x, double[] a, int a_k) { if (a_k == AMP) return Math.sin(a[FREQ]*x[0] + a[PHASE]); else if (a_k == FREQ) return a[AMP] * Math.cos(a[FREQ]*x[0] + a[PHASE]) * x[0]; else if (a_k == PHASE) return a[AMP] * Math.cos(a[FREQ]*x[0] + a[PHASE]); else { assert false; return 0.; } } //grad public Object[] testdata() { double[] a = new double[3]; a[PHASE] = 0.111; a[AMP] = 1.222; a[FREQ] = 1.333; int npts = 10; double[][] x = new double[npts][1]; double[] y = new double[npts]; double[] s = new double[npts]; for( int i = 0; i < npts; i++ ) { x[i][0] = (double)i / npts; y[i] = val(x[i], a); s[i] = 1.; } Object[] o = new Object[4]; o[0] = x; o[1] = a; o[2] = y; o[3] = s; return o; } //test } //SineTest //---------------------------------------------------------------- /** * quadratic (p-o)'S'S(p-o) * solve for o, S * S is a single scale factor */ static class LMQuadTest implements LMfunc { public double val(double[] x, double[] a) { assert a.length == 3; assert x.length == 2; double ox = a[0]; double oy = a[1]; double s = a[2]; double sdx = s*(x[0] - ox); double sdy = s*(x[1] - oy); return sdx*sdx + sdy*sdy; } //val /** * z = (p-o)'S'S(p-o) * 
dz/dp = 2S'S(p-o) * * z = (s*(px-ox))^2 + (s*(py-oy))^2 * dz/dox = -2(s*(px-ox))*s * dz/ds = 2*s*[(px-ox)^2 + (py-oy)^2] * z = (s*dx)^2 + (s*dy)^2 * dz/ds = 2(s*dx)*dx + 2(s*dy)*dy */ public double grad(double[] x, double[] a, int a_k) { assert a.length == 3; assert x.length == 2; assert a_k < 3: "a_k="+a_k; double ox = a[0]; double oy = a[1]; double s = a[2]; double dx = (x[0] - ox); double dy = (x[1] - oy); if (a_k == 0) return -2.*s*s*dx; else if (a_k == 1) return -2.*s*s*dy; else return 2.*s*(dx*dx + dy*dy); } //grad public double[] initial() { double[] a = new double[3]; a[0] = 0.05; a[1] = 0.1; a[2] = 1.0; return a; } //initial public Object[] testdata() { Object[] o = new Object[4]; int npts = 25; double[][] x = new double[npts][2]; double[] y = new double[npts]; double[] s = new double[npts]; double[] a = new double[3]; a[0] = 0.; a[1] = 0.; a[2] = 0.9; int i = 0; for( int r = -2; r <= 2; r++ ) { for( int c = -2; c <= 2; c++ ) { x[i][0] = c; x[i][1] = r; y[i] = val(x[i], a); System.out.println("Quad "+c+","+r+" -> "+y[i]); s[i] = 1.; i++; } } System.out.print("quad x= "); (new Matrix(x)).print(10, 2); System.out.print("quad y= "); (new Matrix(y,npts)).print(10, 2); o[0] = x; o[1] = a; o[2] = y; o[3] = s; return o; } //testdata } //LMQuadTest //---------------------------------------------------------------- /** * Replicate the example in NR, fit a sum of Gaussians to data. * y(x) = \sum B_k exp(-((x - E_k) / G_k)^2) * minimize chisq = \sum { y[j] - \sum B_k exp(-((x_j - E_k) / G_k)^2) }^2 * * B_k, E_k, G_k are stored in that order * * Works, results are close to those from the NR example code. 
*/ static class LMGaussTest implements LMfunc { static double SPREAD = 0.001; // noise variance public double val(double[] x, double[] a) { assert x.length == 1; assert (a.length%3) == 0; int K = a.length / 3; int i = 0; double y = 0.; for( int j = 0; j < K; j++ ) { double arg = (x[0] - a[i+1]) / a[i+2]; double ex = Math.exp(- arg*arg); y += (a[i] * ex); i += 3; } return y; } //val /** *
     * y(x) = \sum B_k exp(-((x - E_k) / G_k)^2)
     * arg  =  (x-E_k)/G_k
     * ex   =  exp(-arg*arg)
     * fac =   B_k * ex * 2 * arg
     * 
     * d/dB_k = exp(-((x - E_k) / G_k)^2)
     *
     * d/dE_k = B_k exp(-((x - E_k) / G_k)^2) . -2((x - E_k) / G_k) . -1/G_k
     *        = 2 * B_k * ex * arg / G_k
     *   d/dE_k[-((x - E_k) / G_k)^2] = -2((x - E_k) / G_k) d/dE_k[(x-E_k)/G_k]
     *   d/dE_k[(x-E_k)/G_k] = -1/G_k
     *
     * d/dG_k = B_k exp(-((x - E_k) / G_k)^2) . -2((x - E_k) / G_k) . -(x-E_k)/G_k^2
     *       = B_k ex -2 arg -arg / G_k
     *       = fac arg / G_k
     *   d/dx[1/x] = d/dx[x^-1] = -[x^-2]
     */
    public double grad(double[] x, double[] a, int a_k)
    {
      assert x.length == 1;

      // slot of B_k for the Gaussian triple (B_k, E_k, G_k) owning a_k
      final int base = 3 * (a_k / 3);

      final double arg = (x[0] - a[base + 1]) / a[base + 2];
      final double ex = Math.exp(- arg*arg);
      final double fac = a[base] * ex * 2. * arg;

      // position of a_k inside its triple: 0 = B_k, 1 = E_k, 2 = G_k
      switch (a_k - base) {
        case 0:
          return ex;			// d/dB_k

        case 1:
          return fac / a[base + 2];	// d/dE_k

        case 2:
          return fac * arg / a[base + 2];	// d/dG_k

        default:
          // reached for an a_k that does not land on a slot of its triple
          System.err.println("bad a_k");
          return 1.;
      }

    } //grad


    public double[] initial()
    {
      // starting guess for two Gaussians, each stored as (B, E, G)
      return new double[] {
        4.5, 2.2, 2.8,
        2.5, 4.9, 2.8
      };
    } //initial


    public Object[] testdata()
    {
      final int npts = 100;

      // parameters used to generate the samples: the values
      // returned by initial should be fairly close to these
      double[] a = { 5.0, 2.0, 3.0, 2.0, 5.0, 3.0 };

      double[][] x = new double[npts][1];
      double[] y = new double[npts];
      double[] s = new double[npts];

      int i = 0;
      while (i < npts) {
        x[i][0] = 0.1*(i+1);	// NR always counts from 1
        y[i] = val(x[i], a);
        s[i] = SPREAD * y[i];	// sigma proportional to the sample value
        System.out.println(i+": x,y= "+x[i][0]+", "+y[i]);
        i++;
      }

      // packed as {domain points, generating parameters, values, sigmas}
      return new Object[] { x, a, y, s };
    } //testdata

  } //LMGaussTest

  //----------------------------------------------------------------

  // test program
  /**
   * Test program: fits the LMQuadTest model to its own test data,
   * starting from the model's initial guess, with verbose output,
   * then prints the parameters the data was generated from.
   * Exits with status 1 if the solver throws, 0 otherwise.
   *
   * @param cmdline command-line arguments (unused)
   */
  public static void main(String[] cmdline)
  {

    LMfunc f = new LMQuadTest();
    //LMfunc f = new LMSineTest();	// works
    //LMfunc f = new LMGaussTest();	// works

    double[] aguess = f.initial();
    Object[] test = f.testdata();
    double[][] x = (double[][])test[0];
    double[] areal = (double[])test[1];
    double[] y = (double[])test[2];
    double[] s = (double[])test[3];
    boolean[] vary = new boolean[aguess.length];
    for( int i = 0; i < aguess.length; i++ ) vary[i] = true;
    assert aguess.length == areal.length;

    try {
      solve( x, aguess, y, s, vary, f, 0.001, 0.01, 100, 2);
    }
    catch(Exception ex) {
      // getMessage() alone is null for many exceptions and hides the
      // exception type; report the full exception and where it came from
      System.err.print("Exception caught: ");
      ex.printStackTrace();
      System.exit(1);
    }

    System.out.print("desired solution ");
    (new Matrix(areal, areal.length)).print(10, 2);

    System.exit(0);
  } //main

} //LM