001    package org.maltparser.ml.lib;
002    
003    import java.io.BufferedReader;
004    import java.io.EOFException;
005    import java.io.File;
006    import java.io.FileInputStream;
007    import java.io.IOException;
008    import java.io.InputStreamReader;
009    import java.io.ObjectInputStream;
010    import java.io.ObjectOutputStream;
011    import java.io.Reader;
012    import java.io.Serializable;
013    import java.nio.charset.Charset;
014    import java.util.Arrays;
015    import java.util.regex.Pattern;
016    
017    import org.maltparser.core.helper.Util;
018    
019    import liblinear.SolverType;
020    
021    /**
022     * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package.
023     * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
024     * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 
025     * structure directly on the model. </p> 
026     * 
027     * @author Johan Hall
028     *
029     */
030    public class MaltLiblinearModel implements Serializable, MaltLibModel {
031            private static final long serialVersionUID = 7526471155622776147L;
032            private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
033            private double bias;
034            /** label of each class */
035            private int[] labels;
036            private int nr_class;
037            private int nr_feature;
038            private SolverType solverType;
039            /** feature weight array */
040            private double[][] w;
041    
042        public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) {
043            this.labels = labels;
044            this.nr_class = nr_class;
045            this.nr_feature = nr_feature;
046            this.w = w;
047            this.solverType = solverType;   
048        }
049        
050        public MaltLiblinearModel(Reader inputReader) throws IOException {
051            loadModel(inputReader);
052        }
053        
054        public MaltLiblinearModel(File modelFile) throws IOException {
055            BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
056            loadModel(inputReader);
057        }
058        
059        /**
060        * @return number of classes
061        */
062        public int getNrClass() {
063            return nr_class;
064        }
065    
066        /**
067        * @return number of features
068        */
069        public int getNrFeature() {
070            return nr_feature;
071        }
072    
073        public int[] getLabels() {
074            return Util.copyOf(labels, nr_class);
075        }
076    
077        /**
078        * The nr_feature*nr_class array w gives feature weights. We use one
079        * against the rest for multi-class classification, so each feature
080        * index corresponds to nr_class weight values. Weights are
081        * organized in the following way
082        *
083        * <pre>
084        * +------------------+------------------+------------+
085        * | nr_class weights | nr_class weights | ...
086        * | for 1st feature | for 2nd feature |
087        * +------------------+------------------+------------+
088        * </pre>
089        *
090        * If bias &gt;= 0, x becomes [x; bias]. The number of features is
091        * increased by one, so w is a (nr_feature+1)*nr_class array. The
092        * value of bias is stored in the variable bias.
093        * @see #getBias()
094        * @return a <b>copy of</b> the feature weight array as described
095        */
096    //    public double[] getFeatureWeights() {
097    //        return Util.copyOf(w, w.length);
098    //    }
099    
100        /**
101        * @return true for logistic regression solvers
102        */
103        public boolean isProbabilityModel() {
104            return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
105        }
106        
107        public double getBias() {
108            return bias;
109        }
110            
111        public int[] predict(MaltFeatureNode[] x) { 
112                    final double[] dec_values = new double[nr_class];
113            final int n = (bias >= 0)?nr_feature + 1:nr_feature;
114            final int nr_w = (nr_class == 2 && solverType != SolverType.MCSVM_CS)?1:nr_class;
115            final int xlen = x.length;
116            int i;
117            for (i = 0; i < nr_w; i++) {
118                dec_values[i] = 0;   
119            }
120            
121            for (i=0; i < xlen; i++) {
122                if (x[i].index <= n) {
123                    int t = (x[i].index - 1);
124                    if (w[t] != null) {
125                            for (int j = 0; j < w[t].length; j++) {
126                                dec_values[j] += w[t][j] * x[i].value;
127                            }
128                    }
129                }
130            }
131    
132                    final int[] predictionList = Util.copyOf(labels, nr_class); 
133                    double tmpDec;
134                    int tmpObj;
135                    int lagest;
136                    final int nc =  nr_class-1;
137                    for (i=0; i < nc; i++) {
138                            lagest = i;
139                            for (int j=i; j < nr_class; j++) {
140                                    if (dec_values[j] > dec_values[lagest]) {
141                                            lagest = j;
142                                    }
143                            }
144                            tmpDec = dec_values[lagest];
145                            dec_values[lagest] = dec_values[i];
146                            dec_values[i] = tmpDec;
147                            tmpObj = predictionList[lagest];
148                            predictionList[lagest] = predictionList[i];
149                            predictionList[i] = tmpObj;
150                    }
151                    return predictionList;
152            }
153            
154            private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
155                    is.defaultReadObject();
156            }
157    
158            private void writeObject(ObjectOutputStream os) throws IOException {
159                    os.defaultWriteObject();
160            }
161            
162            private void loadModel(Reader inputReader) throws IOException {
163                    labels = null;
164                    Pattern whitespace = Pattern.compile("\\s+");
165            BufferedReader reader = null;
166            if (inputReader instanceof BufferedReader) {
167                reader = (BufferedReader)inputReader;
168            } else {
169                reader = new BufferedReader(inputReader);
170            }
171    
172            try {
173                String line = null;
174                while ((line = reader.readLine()) != null) {
175                    String[] split = whitespace.split(line);
176                    if (split[0].equals("solver_type")) {
177                        SolverType solver = SolverType.valueOf(split[1]);
178                        if (solver == null) {
179                            throw new RuntimeException("unknown solver type");
180                        }
181                        solverType = solver;
182                    } else if (split[0].equals("nr_class")) {
183                        nr_class = Util.atoi(split[1]);
184                        Integer.parseInt(split[1]);
185                    } else if (split[0].equals("nr_feature")) {
186                        nr_feature = Util.atoi(split[1]);
187                    } else if (split[0].equals("bias")) {
188                        bias = Util.atof(split[1]);
189                    } else if (split[0].equals("w")) {
190                        break;
191                    } else if (split[0].equals("label")) {
192                        labels = new int[nr_class];
193                        for (int i = 0; i < nr_class; i++) {
194                            labels[i] = Util.atoi(split[i + 1]);
195                        }
196                    } else {
197                        throw new RuntimeException("unknown text in model file: [" + line + "]");
198                    }
199                }
200    
201                int w_size = nr_feature;
202                if (bias >= 0) w_size++;
203    
204                int nr_w = nr_class;
205                if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1;
206                w = new double[w_size][nr_w];
207                int[] buffer = new int[128];
208    
209                for (int i = 0; i < w_size; i++) {
210                    for (int j = 0; j < nr_w; j++) {
211                        int b = 0;
212                        while (true) {
213                            int ch = reader.read();
214                            if (ch == -1) {
215                                throw new EOFException("unexpected EOF");
216                            }
217                            if (ch == ' ') {
218                                    w[i][j] = Util.atof(new String(buffer, 0, b));
219                                break;
220                            } else {
221                                buffer[b++] = ch;
222                            }
223                        }
224                    }
225                }
226            }
227            finally {
228                Util.closeQuietly(reader);
229            }
230            }
231    
232        public int hashCode() {
233            final int prime = 31;
234            long temp = Double.doubleToLongBits(bias);
235            int result = prime * 1 + (int)(temp ^ (temp >>> 32));
236            result = prime * result + Arrays.hashCode(labels);
237            result = prime * result + nr_class;
238            result = prime * result + nr_feature;
239            result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
240            for (int i = 0; i < w.length; i++) {
241                    result = prime * result + Arrays.hashCode(w[i]);
242            }
243            return result;
244        }
245    
246        public boolean equals(Object obj) {
247            if (this == obj) return true;
248            if (obj == null) return false;
249            if (getClass() != obj.getClass()) return false;
250            MaltLiblinearModel other = (MaltLiblinearModel)obj;
251            if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
252            if (!Arrays.equals(labels, other.labels)) return false;
253            if (nr_class != other.nr_class) return false;
254            if (nr_feature != other.nr_feature) return false;
255            if (solverType == null) {
256                if (other.solverType != null) return false;
257            } else if (!solverType.equals(other.solverType)) return false;
258            for (int i = 0; i < w.length; i++) {
259                    if (other.w.length <= i) return false;
260                    if (!Util.equals(w[i], other.w[i])) return false;
261            }    
262            return true;
263        }
264        
265        public String toString() {
266            final StringBuilder sb = new StringBuilder("Model");
267            sb.append(" bias=").append(bias);
268            sb.append(" nr_class=").append(nr_class);
269            sb.append(" nr_feature=").append(nr_feature);
270            sb.append(" solverType=").append(solverType);
271            return sb.toString();
272        }
273    }