001 package org.maltparser.ml.lib;
002
003 import java.io.Serializable;
004
005 import org.maltparser.core.helper.Util;
006
007 import libsvm.svm_model;
008 import libsvm.svm_node;
009 import libsvm.svm_parameter;
010 import libsvm.svm_problem;
011
012
/**
 * <p>This class borrows code from libsvm.svm.java of the Java implementation of the libsvm package.
 * MaltLibsvmModel stores the model obtained from the training procedure. Compared with the original code, the model is more tightly
 * integrated with MaltParser: instead of moving features from MaltParser's internal data structures to libsvm's data structures,
 * it applies the model directly to MaltParser's data structures.</p>
 *
 * @author Johan Hall
 *
 */
022 public class MaltLibsvmModel implements Serializable, MaltLibModel {
023 private static final long serialVersionUID = 7526471155622776147L;
024 public svm_parameter param; // parameter
025 public int nr_class; // number of classes, = 2 in regression/one class svm
026 public int l; // total #SV
027 public svm_node[][] SV; // SVs (SV[l])
028 public double[][] sv_coef; // coefficients for SVs in decision functions (sv_coef[k-1][l])
029 public double[] rho; // constants in decision functions (rho[k*(k-1)/2])
030
031 // for classification only
032 public int[] label; // label of each class (label[k])
033 public int[] nSV; // number of SVs for each class (nSV[k])
034 // nSV[0] + nSV[1] + ... + nSV[k-1] = l
035 public int[] start;
036
037 public MaltLibsvmModel(svm_model model, svm_problem problem) {
038 this.param = model.param;
039 this.nr_class = model.nr_class;
040 this.l = model.l;
041 this.SV = model.SV;
042 this.sv_coef = model.sv_coef;
043 this.rho = model.rho;
044 this.label = model.label;
045 this.nSV = model.nSV;
046 start = new int[nr_class];
047 start[0] = 0;
048 for(int i=1;i<nr_class;i++) {
049 start[i] = start[i-1]+nSV[i-1];
050 }
051 }
052
053 public int[] predict(MaltFeatureNode[] x) {
054 final double[] dec_values = new double[nr_class*(nr_class-1)/2];
055 final double[] kvalue = new double[l];
056 final int[] vote = new int[nr_class];
057 int i;
058 for(i=0;i<l;i++) {
059 kvalue[i] = MaltLibsvmModel.k_function(x,SV[i],param);
060 }
061 for(i=0;i<nr_class;i++) {
062 vote[i] = 0;
063 }
064
065 int p=0;
066 for(i=0;i<nr_class;i++) {
067 for(int j=i+1;j<nr_class;j++) {
068 double sum = 0;
069 int si = start[i];
070 int sj = start[j];
071 int ci = nSV[i];
072 int cj = nSV[j];
073
074 int k;
075 double[] coef1 = sv_coef[j-1];
076 double[] coef2 = sv_coef[i];
077 for(k=0;k<ci;k++)
078 sum += coef1[si+k] * kvalue[si+k];
079 for(k=0;k<cj;k++)
080 sum += coef2[sj+k] * kvalue[sj+k];
081 sum -= rho[p];
082 dec_values[p] = sum;
083
084 if(dec_values[p] > 0)
085 ++vote[i];
086 else
087 ++vote[j];
088 p++;
089 }
090 }
091
092 final int[] predictionList = Util.copyOf(label, nr_class);
093 int tmpObj;
094 int lagest;
095 int tmpint;
096 final int nc = nr_class-1;
097 for (i=0; i < nc; i++) {
098 lagest = i;
099 for (int j=i; j < nr_class; j++) {
100 if (vote[j] > vote[lagest]) {
101 lagest = j;
102 }
103 }
104 tmpint = vote[lagest];
105 vote[lagest] = vote[i];
106 vote[i] = tmpint;
107 tmpObj = predictionList[lagest];
108 predictionList[lagest] = predictionList[i];
109 predictionList[i] = tmpObj;
110 }
111 return predictionList;
112 }
113
114 static double dot(MaltFeatureNode[] x, svm_node[] y) {
115 double sum = 0;
116 final int xlen = x.length;
117 final int ylen = y.length;
118 int i = 0;
119 int j = 0;
120 while(i < xlen && j < ylen)
121 {
122 if(x[i].index == y[j].index)
123 sum += x[i++].value * y[j++].value;
124 else
125 {
126 if(x[i].index > y[j].index)
127 ++j;
128 else
129 ++i;
130 }
131 }
132 return sum;
133 }
134
135 static double powi(double base, int times) {
136 double tmp = base, ret = 1.0;
137
138 for(int t=times; t>0; t/=2)
139 {
140 if(t%2==1) ret*=tmp;
141 tmp = tmp * tmp;
142 }
143 return ret;
144 }
145
146 static double k_function(MaltFeatureNode[] x, svm_node[] y, svm_parameter param) {
147 switch(param.kernel_type)
148 {
149 case svm_parameter.LINEAR:
150 return dot(x,y);
151 case svm_parameter.POLY:
152 return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
153 case svm_parameter.RBF:
154 {
155 double sum = 0;
156 int xlen = x.length;
157 int ylen = y.length;
158 int i = 0;
159 int j = 0;
160 while(i < xlen && j < ylen)
161 {
162 if(x[i].index == y[j].index)
163 {
164 double d = x[i++].value - y[j++].value;
165 sum += d*d;
166 }
167 else if(x[i].index > y[j].index)
168 {
169 sum += y[j].value * y[j].value;
170 ++j;
171 }
172 else
173 {
174 sum += x[i].value * x[i].value;
175 ++i;
176 }
177 }
178
179 while(i < xlen)
180 {
181 sum += x[i].value * x[i].value;
182 ++i;
183 }
184
185 while(j < ylen)
186 {
187 sum += y[j].value * y[j].value;
188 ++j;
189 }
190
191 return Math.exp(-param.gamma*sum);
192 }
193 case svm_parameter.SIGMOID:
194 return Math.tanh(param.gamma*dot(x,y)+param.coef0);
195 case svm_parameter.PRECOMPUTED:
196 return x[(int)(y[0].value)].value;
197 default:
198 return 0; // java
199 }
200 }
201
202 public int[] getLabels() {
203 if (label != null) {
204 final int[] labels = new int[nr_class];
205 for(int i=0;i<nr_class;i++) {
206 labels[i] = label[i];
207 }
208 return labels;
209 }
210 return null;
211 }
212 }