001    package org.maltparser.ml.lib;
002    
003    import java.io.Serializable;
004    
005    import org.maltparser.core.helper.Util;
006    
007    import libsvm.svm_model;
008    import libsvm.svm_node;
009    import libsvm.svm_parameter;
010    import libsvm.svm_problem;
011    
012    
013    /**
014     * <p>This class borrows code from libsvm.svm.java of the Java implementation of the libsvm package.
015     * MaltLibsvmModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
016     * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 
017     * structure directly on the model. </p> 
018     * 
019     * @author Johan Hall
020     *
021     */
022    public class MaltLibsvmModel implements Serializable, MaltLibModel {
023            private static final long serialVersionUID = 7526471155622776147L;
024        public svm_parameter param;     // parameter                                                                                                                                                                                        
025        public int nr_class;            // number of classes, = 2 in regression/one class svm                                                                                                                                               
026            public int l;                   // total #SV                                                                                                                                                                                        
027        public svm_node[][] SV; // SVs (SV[l])                                                                                                                                                                                              
028        public double[][] sv_coef;      // coefficients for SVs in decision functions (sv_coef[k-1][l])                                                                                                                                     
029        public double[] rho;            // constants in decision functions (rho[k*(k-1)/2])                                                                                                                                                 
030    
031        // for classification only
032        public int[] label;             // label of each class (label[k])                                                                                                                                                                   
033        public int[] nSV;               // number of SVs for each class (nSV[k])                                                                                                                                                            
034                                    // nSV[0] + nSV[1] + ... + nSV[k-1] = l   
035        public int[] start;
036        
037        public MaltLibsvmModel(svm_model model, svm_problem problem) {
038            this.param = model.param;
039            this.nr_class = model.nr_class;
040            this.l = model.l;
041            this.SV = model.SV;
042            this.sv_coef = model.sv_coef;
043            this.rho = model.rho;
044            this.label = model.label;
045            this.nSV = model.nSV;
046                    start = new int[nr_class];
047                    start[0] = 0;
048                    for(int i=1;i<nr_class;i++) {
049                            start[i] = start[i-1]+nSV[i-1];
050                    }
051        }
052        
053        public int[] predict(MaltFeatureNode[] x) { 
054            final double[] dec_values = new double[nr_class*(nr_class-1)/2];
055                    final double[] kvalue = new double[l];
056                    final int[] vote = new int[nr_class];
057                    int i;
058                    for(i=0;i<l;i++) {
059                            kvalue[i] = MaltLibsvmModel.k_function(x,SV[i],param);
060                    }
061                    for(i=0;i<nr_class;i++) {
062                            vote[i] = 0;
063                    }
064                    
065                    int p=0;
066                    for(i=0;i<nr_class;i++) {
067                            for(int j=i+1;j<nr_class;j++) {
068                                    double sum = 0;
069                                    int si = start[i];
070                                    int sj = start[j];
071                                    int ci = nSV[i];
072                                    int cj = nSV[j];
073                            
074                                    int k;
075                                    double[] coef1 = sv_coef[j-1];
076                                    double[] coef2 = sv_coef[i];
077                                    for(k=0;k<ci;k++)
078                                            sum += coef1[si+k] * kvalue[si+k];
079                                    for(k=0;k<cj;k++)
080                                            sum += coef2[sj+k] * kvalue[sj+k];
081                                    sum -= rho[p];
082                                    dec_values[p] = sum;                                    
083    
084                                    if(dec_values[p] > 0)
085                                            ++vote[i];
086                                    else
087                                            ++vote[j];
088                                    p++;
089                            }
090                    }
091                    
092                    final int[] predictionList = Util.copyOf(label, nr_class); 
093                    int tmpObj;
094                    int lagest;
095                    int tmpint;
096                    final int nc =  nr_class-1;
097                    for (i=0; i < nc; i++) {
098                            lagest = i;
099                            for (int j=i; j < nr_class; j++) {
100                                    if (vote[j] > vote[lagest]) {
101                                            lagest = j;
102                                    }
103                            }
104                            tmpint = vote[lagest];
105                            vote[lagest] = vote[i];
106                            vote[i] = tmpint;
107                            tmpObj = predictionList[lagest];
108                            predictionList[lagest] = predictionList[i];
109                            predictionList[i] = tmpObj;
110                    }
111                    return predictionList;
112        }
113        
114            static double dot(MaltFeatureNode[] x, svm_node[] y) {
115                    double sum = 0;
116                    final int xlen = x.length;
117                    final int ylen = y.length;
118                    int i = 0;
119                    int j = 0;
120                    while(i < xlen && j < ylen)
121                    {
122                            if(x[i].index == y[j].index)
123                                    sum += x[i++].value * y[j++].value;
124                            else
125                            {
126                                    if(x[i].index > y[j].index)
127                                            ++j;
128                                    else
129                                            ++i;
130                            }
131                    }
132                    return sum;
133            }
134            
135            static double powi(double base, int times) {
136                    double tmp = base, ret = 1.0;
137    
138                    for(int t=times; t>0; t/=2)
139                    {
140                            if(t%2==1) ret*=tmp;
141                            tmp = tmp * tmp;
142                    }
143                    return ret;
144            }
145            
146            static double k_function(MaltFeatureNode[] x, svm_node[] y, svm_parameter param) {
147                    switch(param.kernel_type)
148                    {
149                            case svm_parameter.LINEAR:
150                                    return dot(x,y);
151                            case svm_parameter.POLY:
152                                    return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
153                            case svm_parameter.RBF:
154                            {
155                                    double sum = 0;
156                                    int xlen = x.length;
157                                    int ylen = y.length;
158                                    int i = 0;
159                                    int j = 0;
160                                    while(i < xlen && j < ylen)
161                                    {
162                                            if(x[i].index == y[j].index)
163                                            {
164                                                    double d = x[i++].value - y[j++].value;
165                                                    sum += d*d;
166                                            }
167                                            else if(x[i].index > y[j].index)
168                                            {
169                                                    sum += y[j].value * y[j].value;
170                                                    ++j;
171                                            }
172                                            else
173                                            {
174                                                    sum += x[i].value * x[i].value;
175                                                    ++i;
176                                            }
177                                    }
178                    
179                                    while(i < xlen)
180                                    {
181                                            sum += x[i].value * x[i].value;
182                                            ++i;
183                                    }
184                    
185                                    while(j < ylen)
186                                    {
187                                            sum += y[j].value * y[j].value;
188                                            ++j;
189                                    }
190                    
191                                    return Math.exp(-param.gamma*sum);
192                            }
193                            case svm_parameter.SIGMOID:
194                                    return Math.tanh(param.gamma*dot(x,y)+param.coef0);
195                            case svm_parameter.PRECOMPUTED:
196                                    return  x[(int)(y[0].value)].value;
197                            default:
198                                    return 0;       // java
199                    }
200            }
201            
202            public int[] getLabels() {
203                    if (label != null) {
204                            final int[] labels = new int[nr_class];
205                            for(int i=0;i<nr_class;i++) {
206                                    labels[i] = label[i];
207                            }
208                            return labels;
209                    }
210                    return null;
211            }
212    }