001package org.maltparser.ml.lib;
002
003import java.io.Serializable;
004
005import libsvm.svm_model;
006import libsvm.svm_node;
007import libsvm.svm_parameter;
008import libsvm.svm_problem;
009
010
011/**
012 * <p>This class borrows code from libsvm.svm.java of the Java implementation of the libsvm package.
013 * MaltLibsvmModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
014 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 
015 * structure directly on the model. </p> 
016 * 
017 * @author Johan Hall
018 *
019 */
020public class MaltLibsvmModel implements Serializable, MaltLibModel {
021        private static final long serialVersionUID = 7526471155622776147L;
022    public svm_parameter param;     // parameter                                                                                                                                                                                        
023    public int nr_class;            // number of classes, = 2 in regression/one class svm                                                                                                                                               
024        public int l;                   // total #SV                                                                                                                                                                                        
025    public svm_node[][] SV; // SVs (SV[l])                                                                                                                                                                                              
026    public double[][] sv_coef;      // coefficients for SVs in decision functions (sv_coef[k-1][l])                                                                                                                                     
027    public double[] rho;            // constants in decision functions (rho[k*(k-1)/2])                                                                                                                                                 
028
029    // for classification only
030    public int[] label;             // label of each class (label[k])                                                                                                                                                                   
031    public int[] nSV;               // number of SVs for each class (nSV[k])                                                                                                                                                            
032                                // nSV[0] + nSV[1] + ... + nSV[k-1] = l   
033    public int[] start;
034    
035    public MaltLibsvmModel(svm_model model, svm_problem problem) {
036        this.param = model.param;
037        this.nr_class = model.nr_class;
038        this.l = model.l;
039        this.SV = model.SV;
040        this.sv_coef = model.sv_coef;
041        this.rho = model.rho;
042        this.label = model.label;
043        this.nSV = model.nSV;
044                start = new int[nr_class];
045                start[0] = 0;
046                for(int i=1;i<nr_class;i++) {
047                        start[i] = start[i-1]+nSV[i-1];
048                }
049    }
050    
051    public int[] predict(MaltFeatureNode[] x) { 
052        final double[] dec_values = new double[nr_class*(nr_class-1)/2];
053                final double[] kvalue = new double[l];
054                final int[] vote = new int[nr_class];
055                int i;
056                for(i=0;i<l;i++) {
057                        kvalue[i] = MaltLibsvmModel.k_function(x,SV[i],param);
058                }
059                for(i=0;i<nr_class;i++) {
060                        vote[i] = 0;
061                }
062                
063                int p=0;
064                for(i=0;i<nr_class;i++) {
065                        for(int j=i+1;j<nr_class;j++) {
066                                double sum = 0;
067                                int si = start[i];
068                                int sj = start[j];
069                                int ci = nSV[i];
070                                int cj = nSV[j];
071                        
072                                int k;
073                                double[] coef1 = sv_coef[j-1];
074                                double[] coef2 = sv_coef[i];
075                                for(k=0;k<ci;k++)
076                                        sum += coef1[si+k] * kvalue[si+k];
077                                for(k=0;k<cj;k++)
078                                        sum += coef2[sj+k] * kvalue[sj+k];
079                                sum -= rho[p];
080                                dec_values[p] = sum;                                    
081
082                                if(dec_values[p] > 0)
083                                        ++vote[i];
084                                else
085                                        ++vote[j];
086                                p++;
087                        }
088                }
089                
090        final int[] predictionList = new int[nr_class];
091        System.arraycopy(label, 0, predictionList, 0, nr_class);
092                int tmp;
093                int iMax;
094                final int nc =  nr_class-1;
095                for (i=0; i < nc; i++) {
096                        iMax = i;
097                        for (int j=i+1; j < nr_class; j++) {
098                                if (vote[j] > vote[iMax]) {
099                                        iMax = j;
100                                }
101                        }
102                        if (iMax != i) {
103                                tmp = vote[iMax];
104                                vote[iMax] = vote[i];
105                                vote[i] = tmp;
106                                tmp = predictionList[iMax];
107                                predictionList[iMax] = predictionList[i];
108                                predictionList[i] = tmp;
109                        }
110                }
111                return predictionList;
112    }
113    
114    
115    public int predict_one(MaltFeatureNode[] x) { 
116        final double[] dec_values = new double[nr_class*(nr_class-1)/2];
117                final double[] kvalue = new double[l];
118                final int[] vote = new int[nr_class];
119                int i;
120                for(i=0;i<l;i++) {
121                        kvalue[i] = MaltLibsvmModel.k_function(x,SV[i],param);
122                }
123                for(i=0;i<nr_class;i++) {
124                        vote[i] = 0;
125                }
126                
127                int p=0;
128                for(i=0;i<nr_class;i++) {
129                        for(int j=i+1;j<nr_class;j++) {
130                                double sum = 0;
131                                int si = start[i];
132                                int sj = start[j];
133                                int ci = nSV[i];
134                                int cj = nSV[j];
135                        
136                                int k;
137                                double[] coef1 = sv_coef[j-1];
138                                double[] coef2 = sv_coef[i];
139                                for(k=0;k<ci;k++)
140                                        sum += coef1[si+k] * kvalue[si+k];
141                                for(k=0;k<cj;k++)
142                                        sum += coef2[sj+k] * kvalue[sj+k];
143                                sum -= rho[p];
144                                dec_values[p] = sum;                                    
145
146                                if(dec_values[p] > 0)
147                                        ++vote[i];
148                                else
149                                        ++vote[j];
150                                p++;
151                        }
152                }
153                
154                
155        int max = vote[0];
156        int max_index = 0;
157                for (i = 1; i < vote.length; i++) {
158                        if (vote[i] > max) {
159                                max = vote[i];
160                                max_index = i;
161                        }
162                }
163
164                return label[max_index];
165    }
166    
167        static double dot(MaltFeatureNode[] x, svm_node[] y) {
168                double sum = 0;
169                final int xlen = x.length;
170                final int ylen = y.length;
171                int i = 0;
172                int j = 0;
173                while(i < xlen && j < ylen)
174                {
175                        if(x[i].index == y[j].index)
176                                sum += x[i++].value * y[j++].value;
177                        else
178                        {
179                                if(x[i].index > y[j].index)
180                                        ++j;
181                                else
182                                        ++i;
183                        }
184                }
185                return sum;
186        }
187        
188        static double powi(double base, int times) {
189                double tmp = base, ret = 1.0;
190
191                for(int t=times; t>0; t/=2)
192                {
193                        if(t%2==1) ret*=tmp;
194                        tmp = tmp * tmp;
195                }
196                return ret;
197        }
198        
199        static double k_function(MaltFeatureNode[] x, svm_node[] y, svm_parameter param) {
200                switch(param.kernel_type)
201                {
202                        case svm_parameter.LINEAR:
203                                return dot(x,y);
204                        case svm_parameter.POLY:
205                                return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
206                        case svm_parameter.RBF:
207                        {
208                                double sum = 0;
209                                int xlen = x.length;
210                                int ylen = y.length;
211                                int i = 0;
212                                int j = 0;
213                                while(i < xlen && j < ylen)
214                                {
215                                        if(x[i].index == y[j].index)
216                                        {
217                                                double d = x[i++].value - y[j++].value;
218                                                sum += d*d;
219                                        }
220                                        else if(x[i].index > y[j].index)
221                                        {
222                                                sum += y[j].value * y[j].value;
223                                                ++j;
224                                        }
225                                        else
226                                        {
227                                                sum += x[i].value * x[i].value;
228                                                ++i;
229                                        }
230                                }
231                
232                                while(i < xlen)
233                                {
234                                        sum += x[i].value * x[i].value;
235                                        ++i;
236                                }
237                
238                                while(j < ylen)
239                                {
240                                        sum += y[j].value * y[j].value;
241                                        ++j;
242                                }
243                
244                                return Math.exp(-param.gamma*sum);
245                        }
246                        case svm_parameter.SIGMOID:
247                                return Math.tanh(param.gamma*dot(x,y)+param.coef0);
248                        case svm_parameter.PRECOMPUTED:
249                                return  x[(int)(y[0].value)].value;
250                        default:
251                                return 0;       // java
252                }
253        }
254        
255        public int[] getLabels() {
256                if (label != null) {
257                        final int[] labels = new int[nr_class];
258                        for(int i=0;i<nr_class;i++) {
259                                labels[i] = label[i];
260                        }
261                        return labels;
262                }
263                return null;
264        }
265}