// IRLS.java // // Requires old visualnumerics math library - if you do not have this // (it is no longer available), identify the matrix calls and substitute // another library, such as jama. // Also you will need to make these substitutions: // - change zliberror._assert to assert and delete the zlib import // (or just delete all the assert calls) // - change or delete the matrix.print calls. // This library is free software; you can redistribute it and/or // modify it under the terms of the GNU Library General Public // License as published by the Free Software Foundation; either // version 2 of the License, or (at your option) any later version. // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Library General Public License for more details. // // You should have received a copy of the GNU Library General Public // License along with this library; if not, write to the // Free Software Foundation, Inc., 59 Temple Place - Suite 330, // Boston, MA 02111-1307, USA. // // Primary author contact info: // j.p.lewis www.idiom.com/~zilla zilla@computer.org package ZS; import VisualNumerics.math.*; import zlib.*; public class IRLS { final static int _verbose = 1; /** * At = A transpose * p = desired power, e.g. slightly above 1 * we take At from the caller because the caller may have already * computed it * * The system should be overdetermined (A more rows than columns), * otherwise the solution will typically be zero * (and L1 of zero is no different than L2 zero). */ public static void solve(final double[][] A, final double[][] At, double[] x, final double[] b, double p, int niter) { zliberror._assert(A.length == b.length); zliberror._assert(A[0].length == x.length); double[] x1 = l2solve(A,At, b); refine(A,At,x1,b, p,niter); int nc = x.length; for( int ic = 0; ic < nc; ic++ ) x[ic] = x1[ic]; } //solve //---------------------------------------------------------------- public static void refine(final double[][] A, final double[][] At, double[] x, final double[] b, double p, int niter) { int nr = b.length; int nc = x.length; double[] r = new double[nr]; // residual double[] w = new double[nr]; // weights double pm2 = (p-2.) / 2.; double eps = 100. * Double.MIN_VALUE; System.out.println("eps = "+eps); for( int iter = 0; iter < niter; iter++ ) { if (_verbose > 0) printerr(A,x,b); double[] Ax = DoubleMatrix.multiply(A,x); // residual boolean nonzero = false; for( int ir = 0; ir < nr; ir++ ) { r[ir] = b[ir] - Ax[ir]; if (r[ir] != 0.) nonzero = true; double d = Math.pow( Math.abs(r[ir]) + eps, pm2 ); w[ir] = d; } if (_verbose > 0) matrix.print("residual", r); if (_verbose > 0) matrix.print("weights", w); if (!nonzero) { System.out.println("IRLS returning - solution is exact"); return; } double[][] AtW2 = matrix.diagMul2(At, w); double[][] AtW2A = DoubleMatrix.multiply(AtW2, A); double[] AtW2r = DoubleMatrix.multiply(AtW2, r); double[] dx = l2solve(AtW2A,AtW2A, AtW2r); zliberror._assert(dx.length == nc); for( int ic = 0; ic < nc; ic++ ) x[ic] += dx[ic]; if (_verbose > 0) matrix.print("new x = ", x); } if (_verbose > 0) printerr(A,x,b); } //refine //---------------------------------------------------------------- public static double[] l2solve(final double[][] A, final double[][] At, final double[] b) { try { if (_verbose > 0) System.out.println(" solving "+A.length+"x"+A[0].length); if (A.length == A[0].length) { return VisualNumerics.math.DoubleMatrix.solve(A, b); } else { double[][] AtA = DoubleMatrix.multiply(At, A); double[] Atb = DoubleMatrix.multiply(At, b); return VisualNumerics.math.DoubleMatrix.solve(AtA, Atb); } } catch(Exception ex) { zliberror.die(ex); //System.err.println(ex); //System.exit(1); } return null; } //l2solve //---------------------------------------------------------------- static void printerr(double[][] A, double[] x, double[] b) { double[] Ax = DoubleMatrix.multiply(A,x); double sum2 = 0.; double sum1 = 0.; for( int i = 0; i < b.length; i++ ) { double r = b[i] - Ax[i]; sum2 += (r*r); sum1 += Math.abs(r); } System.out.println(" irls L1="+sum1+" L2="+sum2); } //printerr //---------------------------------------------------------------- /** * test on a line fit */ public static void main(String[] cmdline) { int npts = 50; double a = 0.1; double b = 10.; double[][] A = new double[50][2]; double[] y = new double[50]; for( int i = 0; i < npts; i++ ) { double x = (100. * i) / npts; A[i][0] = x; A[i][1] = 1.; y[i] = a * x + b; y[i] += (10. * (2.0*Math.random()-1.0)); } //matrix.print("A=",A); matrix.print("y=",y); double[][] At = DoubleMatrix.transpose(A); double[] u = new double[2]; u = IRLS.l2solve(A,At, y); System.out.println("desired a,b="+a+","+b+ " L2 recovered a,b="+u[0]+","+u[1]); // on this problem most of the reduction in L1 error happens in // the first ~5 iterations IRLS.solve(A,At, u, y, 1.2, 10); System.out.println("desired a,b="+a+","+b+ " L1 recovered a,b="+u[0]+","+u[1]); System.exit(0); } //main } //IRLS