001package org.maltparser.ml.lib; 002 003import java.io.Serializable; 004 005import libsvm.svm_model; 006import libsvm.svm_node; 007import libsvm.svm_parameter; 008import libsvm.svm_problem; 009 010 011/** 012 * <p>This class borrows code from libsvm.svm.java of the Java implementation of the libsvm package. 013 * MaltLibsvmModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to 014 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 015 * structure directly on the model. </p> 016 * 017 * @author Johan Hall 018 * 019 */ 020public class MaltLibsvmModel implements Serializable, MaltLibModel { 021 private static final long serialVersionUID = 7526471155622776147L; 022 public svm_parameter param; // parameter 023 public int nr_class; // number of classes, = 2 in regression/one class svm 024 public int l; // total #SV 025 public svm_node[][] SV; // SVs (SV[l]) 026 public double[][] sv_coef; // coefficients for SVs in decision functions (sv_coef[k-1][l]) 027 public double[] rho; // constants in decision functions (rho[k*(k-1)/2]) 028 029 // for classification only 030 public int[] label; // label of each class (label[k]) 031 public int[] nSV; // number of SVs for each class (nSV[k]) 032 // nSV[0] + nSV[1] + ... + nSV[k-1] = l 033 public int[] start; 034 035 public MaltLibsvmModel(svm_model model, svm_problem problem) { 036 this.param = model.param; 037 this.nr_class = model.nr_class; 038 this.l = model.l; 039 this.SV = model.SV; 040 this.sv_coef = model.sv_coef; 041 this.rho = model.rho; 042 this.label = model.label; 043 this.nSV = model.nSV; 044 start = new int[nr_class]; 045 start[0] = 0; 046 for(int i=1;i<nr_class;i++) { 047 start[i] = start[i-1]+nSV[i-1]; 048 } 049 } 050 051 public int[] predict(MaltFeatureNode[] x) { 052 final double[] dec_values = new double[nr_class*(nr_class-1)/2]; 053 final double[] kvalue = new double[l]; 054 final int[] vote = new int[nr_class]; 055 int i; 056 for(i=0;i<l;i++) { 057 kvalue[i] = MaltLibsvmModel.k_function(x,SV[i],param); 058 } 059 for(i=0;i<nr_class;i++) { 060 vote[i] = 0; 061 } 062 063 int p=0; 064 for(i=0;i<nr_class;i++) { 065 for(int j=i+1;j<nr_class;j++) { 066 double sum = 0; 067 int si = start[i]; 068 int sj = start[j]; 069 int ci = nSV[i]; 070 int cj = nSV[j]; 071 072 int k; 073 double[] coef1 = sv_coef[j-1]; 074 double[] coef2 = sv_coef[i]; 075 for(k=0;k<ci;k++) 076 sum += coef1[si+k] * kvalue[si+k]; 077 for(k=0;k<cj;k++) 078 sum += coef2[sj+k] * kvalue[sj+k]; 079 sum -= rho[p]; 080 dec_values[p] = sum; 081 082 if(dec_values[p] > 0) 083 ++vote[i]; 084 else 085 ++vote[j]; 086 p++; 087 } 088 } 089 090 final int[] predictionList = new int[nr_class]; 091 System.arraycopy(label, 0, predictionList, 0, nr_class); 092 int tmp; 093 int iMax; 094 final int nc = nr_class-1; 095 for (i=0; i < nc; i++) { 096 iMax = i; 097 for (int j=i+1; j < nr_class; j++) { 098 if (vote[j] > vote[iMax]) { 099 iMax = j; 100 } 101 } 102 if (iMax != i) { 103 tmp = vote[iMax]; 104 vote[iMax] = vote[i]; 105 vote[i] = tmp; 106 tmp = predictionList[iMax]; 107 predictionList[iMax] = predictionList[i]; 108 predictionList[i] = tmp; 109 } 110 } 111 return predictionList; 112 } 113 114 115 public int predict_one(MaltFeatureNode[] x) { 116 final double[] dec_values = new double[nr_class*(nr_class-1)/2]; 117 final double[] kvalue = new double[l]; 118 final int[] vote = new int[nr_class]; 119 int i; 120 for(i=0;i<l;i++) { 121 kvalue[i] = MaltLibsvmModel.k_function(x,SV[i],param); 122 } 123 for(i=0;i<nr_class;i++) { 124 vote[i] = 0; 125 } 126 127 int p=0; 128 for(i=0;i<nr_class;i++) { 129 for(int j=i+1;j<nr_class;j++) { 130 double sum = 0; 131 int si = start[i]; 132 int sj = start[j]; 133 int ci = nSV[i]; 134 int cj = nSV[j]; 135 136 int k; 137 double[] coef1 = sv_coef[j-1]; 138 double[] coef2 = sv_coef[i]; 139 for(k=0;k<ci;k++) 140 sum += coef1[si+k] * kvalue[si+k]; 141 for(k=0;k<cj;k++) 142 sum += coef2[sj+k] * kvalue[sj+k]; 143 sum -= rho[p]; 144 dec_values[p] = sum; 145 146 if(dec_values[p] > 0) 147 ++vote[i]; 148 else 149 ++vote[j]; 150 p++; 151 } 152 } 153 154 155 int max = vote[0]; 156 int max_index = 0; 157 for (i = 1; i < vote.length; i++) { 158 if (vote[i] > max) { 159 max = vote[i]; 160 max_index = i; 161 } 162 } 163 164 return label[max_index]; 165 } 166 167 static double dot(MaltFeatureNode[] x, svm_node[] y) { 168 double sum = 0; 169 final int xlen = x.length; 170 final int ylen = y.length; 171 int i = 0; 172 int j = 0; 173 while(i < xlen && j < ylen) 174 { 175 if(x[i].index == y[j].index) 176 sum += x[i++].value * y[j++].value; 177 else 178 { 179 if(x[i].index > y[j].index) 180 ++j; 181 else 182 ++i; 183 } 184 } 185 return sum; 186 } 187 188 static double powi(double base, int times) { 189 double tmp = base, ret = 1.0; 190 191 for(int t=times; t>0; t/=2) 192 { 193 if(t%2==1) ret*=tmp; 194 tmp = tmp * tmp; 195 } 196 return ret; 197 } 198 199 static double k_function(MaltFeatureNode[] x, svm_node[] y, svm_parameter param) { 200 switch(param.kernel_type) 201 { 202 case svm_parameter.LINEAR: 203 return dot(x,y); 204 case svm_parameter.POLY: 205 return powi(param.gamma*dot(x,y)+param.coef0,param.degree); 206 case svm_parameter.RBF: 207 { 208 double sum = 0; 209 int xlen = x.length; 210 int ylen = y.length; 211 int i = 0; 212 int j = 0; 213 while(i < xlen && j < ylen) 214 { 215 if(x[i].index == y[j].index) 216 { 217 double d = x[i++].value - y[j++].value; 218 sum += d*d; 219 } 220 else if(x[i].index > y[j].index) 221 { 222 sum += y[j].value * y[j].value; 223 ++j; 224 } 225 else 226 { 227 sum += x[i].value * x[i].value; 228 ++i; 229 } 230 } 231 232 while(i < xlen) 233 { 234 sum += x[i].value * x[i].value; 235 ++i; 236 } 237 238 while(j < ylen) 239 { 240 sum += y[j].value * y[j].value; 241 ++j; 242 } 243 244 return Math.exp(-param.gamma*sum); 245 } 246 case svm_parameter.SIGMOID: 247 return Math.tanh(param.gamma*dot(x,y)+param.coef0); 248 case svm_parameter.PRECOMPUTED: 249 return x[(int)(y[0].value)].value; 250 default: 251 return 0; // java 252 } 253 } 254 255 public int[] getLabels() { 256 if (label != null) { 257 final int[] labels = new int[nr_class]; 258 for(int i=0;i<nr_class;i++) { 259 labels[i] = label[i]; 260 } 261 return labels; 262 } 263 return null; 264 } 265}