package org.maltparser.ml.lib;

import java.io.Serializable;

import org.maltparser.core.helper.Util;

import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;


/**
 * <p>This class borrows code from libsvm.svm.java in the Java implementation of the libsvm package.
 * MaltLibsvmModel stores the model obtained from the training procedure. In addition to the original code, the model is more tightly
 * integrated with MaltParser: instead of copying features from MaltParser's internal data structures into libsvm's data structure,
 * it applies MaltParser's data structure directly to the model. A usage sketch appears at the end of this file.</p>
 *
 * @author Johan Hall
 *
 */
public class MaltLibsvmModel implements Serializable, MaltLibModel {
	private static final long serialVersionUID = 7526471155622776147L;
	public svm_parameter param;	// parameter
	public int nr_class;		// number of classes, = 2 in regression/one-class SVM
	public int l;			// total #SV
	public svm_node[][] SV;		// SVs (SV[l])
	public double[][] sv_coef;	// coefficients for SVs in decision functions (sv_coef[k-1][l])
	public double[] rho;		// constants in decision functions (rho[k*(k-1)/2])

	// for classification only
	public int[] label;		// label of each class (label[k])
	public int[] nSV;		// number of SVs for each class (nSV[k])
					// nSV[0] + nSV[1] + ... + nSV[k-1] = l
	public int[] start;		// start[i] = index of the first support vector of class i

	public MaltLibsvmModel(svm_model model, svm_problem problem) {
		this.param = model.param;
		this.nr_class = model.nr_class;
		this.l = model.l;
		this.SV = model.SV;
		this.sv_coef = model.sv_coef;
		this.rho = model.rho;
		this.label = model.label;
		this.nSV = model.nSV;
		start = new int[nr_class];
		start[0] = 0;
		for (int i = 1; i < nr_class; i++) {
			start[i] = start[i-1] + nSV[i-1];
		}
	}

	/**
	 * Classifies the feature vector x by one-vs-one voting over all class pairs and
	 * returns the class labels sorted by decreasing number of votes.
	 */
	public int[] predict(MaltFeatureNode[] x) {
		final double[] dec_values = new double[nr_class*(nr_class-1)/2];
		final double[] kvalue = new double[l];
		final int[] vote = new int[nr_class];
		int i;
		// Kernel value between the input vector and every support vector.
		for (i = 0; i < l; i++) {
			kvalue[i] = MaltLibsvmModel.k_function(x, SV[i], param);
		}
		for (i = 0; i < nr_class; i++) {
			vote[i] = 0;
		}

		// One decision function per class pair (i,j); its sign casts the vote.
		int p = 0;
		for (i = 0; i < nr_class; i++) {
			for (int j = i+1; j < nr_class; j++) {
				double sum = 0;
				int si = start[i];
				int sj = start[j];
				int ci = nSV[i];
				int cj = nSV[j];

				int k;
				double[] coef1 = sv_coef[j-1];
				double[] coef2 = sv_coef[i];
				for (k = 0; k < ci; k++)
					sum += coef1[si+k] * kvalue[si+k];
				for (k = 0; k < cj; k++)
					sum += coef2[sj+k] * kvalue[sj+k];
				sum -= rho[p];
				dec_values[p] = sum;

				if (dec_values[p] > 0)
					++vote[i];
				else
					++vote[j];
				p++;
			}
		}

		// Selection sort: order the labels by decreasing vote count, keeping
		// the vote array and the label array in step.
		final int[] predictionList = Util.copyOf(label, nr_class);
		int tmpLabel;
		int largest;
		int tmpVote;
		final int nc = nr_class-1;
		for (i = 0; i < nc; i++) {
			largest = i;
			for (int j = i; j < nr_class; j++) {
				if (vote[j] > vote[largest]) {
					largest = j;
				}
			}
			tmpVote = vote[largest];
			vote[largest] = vote[i];
			vote[i] = tmpVote;
			tmpLabel = predictionList[largest];
			predictionList[largest] = predictionList[i];
			predictionList[i] = tmpLabel;
		}
		return predictionList;
	}
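
	/*
	 * Worked example of the voting loop above (illustrative): for nr_class = 3,
	 * p runs over the class pairs (0,1), (0,2), (1,2) in that order; a positive
	 * dec_values[p] casts a vote for the first class of the pair, otherwise the
	 * second. The selection sort then ranks the labels by vote count, so
	 * predict(x)[0] is the winning class label.
	 */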

	/** Sparse dot product between a MaltParser feature vector and a libsvm support vector. */
	static double dot(MaltFeatureNode[] x, svm_node[] y) {
		double sum = 0;
		final int xlen = x.length;
		final int ylen = y.length;
		int i = 0;
		int j = 0;
		while (i < xlen && j < ylen) {
			if (x[i].index == y[j].index)
				sum += x[i++].value * y[j++].value;
			else {
				if (x[i].index > y[j].index)
					++j;
				else
					++i;
			}
		}
		return sum;
	}

	/** Computes base^times by repeated squaring. */
	static double powi(double base, int times) {
		double tmp = base, ret = 1.0;

		for (int t = times; t > 0; t /= 2) {
			if (t % 2 == 1) ret *= tmp;
			tmp = tmp * tmp;
		}
		return ret;
	}

	/** Kernel value between a feature vector and a support vector for the configured kernel type. */
	static double k_function(MaltFeatureNode[] x, svm_node[] y, svm_parameter param) {
		switch (param.kernel_type) {
		case svm_parameter.LINEAR:
			return dot(x, y);
		case svm_parameter.POLY:
			return powi(param.gamma*dot(x, y)+param.coef0, param.degree);
		case svm_parameter.RBF:
		{
			// Squared Euclidean distance over the two sparse representations.
			double sum = 0;
			int xlen = x.length;
			int ylen = y.length;
			int i = 0;
			int j = 0;
			while (i < xlen && j < ylen) {
				if (x[i].index == y[j].index) {
					double d = x[i++].value - y[j++].value;
					sum += d*d;
				} else if (x[i].index > y[j].index) {
					sum += y[j].value * y[j].value;
					++j;
				} else {
					sum += x[i].value * x[i].value;
					++i;
				}
			}

			while (i < xlen) {
				sum += x[i].value * x[i].value;
				++i;
			}

			while (j < ylen) {
				sum += y[j].value * y[j].value;
				++j;
			}

			return Math.exp(-param.gamma*sum);
		}
		case svm_parameter.SIGMOID:
			return Math.tanh(param.gamma*dot(x, y)+param.coef0);
		case svm_parameter.PRECOMPUTED:
			return x[(int)(y[0].value)].value;
		default:
			return 0;	// satisfies the compiler; unknown kernel types do not occur
		}
	}

	/** Returns a copy of the class labels, or null if the model has none. */
	public int[] getLabels() {
		if (label != null) {
			final int[] labels = new int[nr_class];
			for (int i = 0; i < nr_class; i++) {
				labels[i] = label[i];
			}
			return labels;
		}
		return null;
	}
}
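
/*
 * Usage sketch (illustrative, not part of MaltParser): trains a toy libsvm model
 * and routes a prediction through MaltLibsvmModel.predict. The two-argument
 * MaltFeatureNode(index, value) constructor is an assumption about the
 * MaltFeatureNode API; substitute the real construction if it differs.
 */
class MaltLibsvmModelUsageSketch {
	public static void main(String[] args) {
		// Toy two-class problem on one feature: small values are class 0, large are class 1.
		libsvm.svm_problem prob = new libsvm.svm_problem();
		prob.l = 4;
		prob.y = new double[] {0, 0, 1, 1};
		prob.x = new libsvm.svm_node[prob.l][];
		double[] xs = {1.0, 2.0, 8.0, 9.0};
		for (int i = 0; i < prob.l; i++) {
			libsvm.svm_node n = new libsvm.svm_node();
			n.index = 1;
			n.value = xs[i];
			prob.x[i] = new libsvm.svm_node[] {n};
		}

		// Standard C-SVC setup with an RBF kernel.
		libsvm.svm_parameter param = new libsvm.svm_parameter();
		param.svm_type = libsvm.svm_parameter.C_SVC;
		param.kernel_type = libsvm.svm_parameter.RBF;
		param.gamma = 0.5;
		param.C = 1.0;
		param.cache_size = 32;
		param.eps = 0.001;

		MaltLibsvmModel model = new MaltLibsvmModel(libsvm.svm.svm_train(prob, param), prob);

		// predict returns the class labels ordered by decreasing vote count.
		MaltFeatureNode[] query = { new MaltFeatureNode(1, 8.5) };	// hypothetical constructor
		System.out.println("best label: " + model.predict(query)[0]);
	}
}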