001    package org.maltparser.ml.lib;
002    
003    import java.io.BufferedReader;
004    import java.io.EOFException;
005    import java.io.File;
006    import java.io.FileInputStream;
007    import java.io.IOException;
008    import java.io.InputStreamReader;
009    import java.io.ObjectInputStream;
010    import java.io.ObjectOutputStream;
011    import java.io.Reader;
012    import java.io.Serializable;
013    import java.nio.charset.Charset;
014    import java.util.Arrays;
015    import java.util.regex.Pattern;
016    
017    import org.maltparser.core.helper.Util;
018    
019    import de.bwaldvogel.liblinear.SolverType;
020    
021    /**
022     * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package.
023     * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
024     * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 
025     * structure directly on the model. </p> 
026     * 
027     * @author Johan Hall
028     *
029     */
030    public class MaltLiblinearModel implements Serializable, MaltLibModel {
031            private static final long serialVersionUID = 7526471155622776147L;
032            private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
033            private double bias;
034            /** label of each class */
035            private int[] labels;
036            private int nr_class;
037            private int nr_feature;
038            private SolverType solverType;
039            /** feature weight array */
040            private double[][] w;
041    
042        public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) {
043            this.labels = labels;
044            this.nr_class = nr_class;
045            this.nr_feature = nr_feature;
046            this.w = w;
047            this.solverType = solverType;   
048        }
049        
050        public MaltLiblinearModel(Reader inputReader) throws IOException {
051            loadModel(inputReader);
052        }
053        
054        public MaltLiblinearModel(File modelFile) throws IOException {
055            BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
056            loadModel(inputReader);
057        }
058        
059        /**
060        * @return number of classes
061        */
062        public int getNrClass() {
063            return nr_class;
064        }
065    
066        /**
067        * @return number of features
068        */
069        public int getNrFeature() {
070            return nr_feature;
071        }
072    
073        public int[] getLabels() {
074            return Util.copyOf(labels, nr_class);
075        }
076    
077        /**
078        * The nr_feature*nr_class array w gives feature weights. We use one
079        * against the rest for multi-class classification, so each feature
080        * index corresponds to nr_class weight values. Weights are
081        * organized in the following way
082        *
083        * <pre>
084        * +------------------+------------------+------------+
085        * | nr_class weights | nr_class weights | ...
086        * | for 1st feature | for 2nd feature |
087        * +------------------+------------------+------------+
088        * </pre>
089        *
090        * If bias &gt;= 0, x becomes [x; bias]. The number of features is
091        * increased by one, so w is a (nr_feature+1)*nr_class array. The
092        * value of bias is stored in the variable bias.
093        * @see #getBias()
094        * @return a <b>copy of</b> the feature weight array as described
095        */
096    //    public double[] getFeatureWeights() {
097    //        return Util.copyOf(w, w.length);
098    //    }
099    
100        /**
101        * @return true for logistic regression solvers
102        */
103        public boolean isProbabilityModel() {
104            return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
105        }
106        
107        public double getBias() {
108            return bias;
109        }
110            
111        public int[] predict(MaltFeatureNode[] x) { 
112                    final double[] dec_values = new double[nr_class];
113                    final int[] predictionList = Util.copyOf(labels, nr_class); 
114            final int n = (bias >= 0)?nr_feature + 1:nr_feature;
115    //        final int nr_w = (nr_class == 2 && solverType != SolverType.MCSVM_CS)?1:nr_class;
116            final int xlen = x.length;
117    //        int i;
118    //        for (i = 0; i < nr_w; i++) {
119    //            dec_values[i] = 0;   
120    //        }
121            
122            for (int i=0; i < xlen; i++) {
123                if (x[i].index <= n) {
124                    final int t = (x[i].index - 1);
125                    if (w[t] != null) {
126                            for (int j = 0; j < w[t].length; j++) {
127                                dec_values[j] += w[t][j] * x[i].value;
128                            }
129                    }
130                }
131            }
132    
133                    
134                    double tmpDec;
135                    int tmpObj;
136                    int lagest;
137                    final int nc =  nr_class-1;
138                    for (int i=0; i < nc; i++) {
139                            lagest = i;
140                            for (int j=i; j < nr_class; j++) {
141                                    if (dec_values[j] > dec_values[lagest]) {
142                                            lagest = j;
143                                    }
144                            }
145                            tmpDec = dec_values[lagest];
146                            dec_values[lagest] = dec_values[i];
147                            dec_values[i] = tmpDec;
148                            tmpObj = predictionList[lagest];
149                            predictionList[lagest] = predictionList[i];
150                            predictionList[i] = tmpObj;
151                    }
152                    return predictionList;
153            }
154            
155            private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
156                    is.defaultReadObject();
157            }
158    
159            private void writeObject(ObjectOutputStream os) throws IOException {
160                    os.defaultWriteObject();
161            }
162            
163            private void loadModel(Reader inputReader) throws IOException {
164                    labels = null;
165                    Pattern whitespace = Pattern.compile("\\s+");
166            BufferedReader reader = null;
167            if (inputReader instanceof BufferedReader) {
168                reader = (BufferedReader)inputReader;
169            } else {
170                reader = new BufferedReader(inputReader);
171            }
172    
173            try {
174                String line = null;
175                while ((line = reader.readLine()) != null) {
176                    String[] split = whitespace.split(line);
177                    if (split[0].equals("solver_type")) {
178                        SolverType solver = SolverType.valueOf(split[1]);
179                        if (solver == null) {
180                            throw new RuntimeException("unknown solver type");
181                        }
182                        solverType = solver;
183                    } else if (split[0].equals("nr_class")) {
184                        nr_class = Util.atoi(split[1]);
185                        Integer.parseInt(split[1]);
186                    } else if (split[0].equals("nr_feature")) {
187                        nr_feature = Util.atoi(split[1]);
188                    } else if (split[0].equals("bias")) {
189                        bias = Util.atof(split[1]);
190                    } else if (split[0].equals("w")) {
191                        break;
192                    } else if (split[0].equals("label")) {
193                        labels = new int[nr_class];
194                        for (int i = 0; i < nr_class; i++) {
195                            labels[i] = Util.atoi(split[i + 1]);
196                        }
197                    } else {
198                        throw new RuntimeException("unknown text in model file: [" + line + "]");
199                    }
200                }
201    
202                int w_size = nr_feature;
203                if (bias >= 0) w_size++;
204    
205                int nr_w = nr_class;
206                if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1;
207                w = new double[w_size][nr_w];
208                int[] buffer = new int[128];
209    
210                for (int i = 0; i < w_size; i++) {
211                    for (int j = 0; j < nr_w; j++) {
212                        int b = 0;
213                        while (true) {
214                            int ch = reader.read();
215                            if (ch == -1) {
216                                throw new EOFException("unexpected EOF");
217                            }
218                            if (ch == ' ') {
219                                    w[i][j] = Util.atof(new String(buffer, 0, b));
220                                break;
221                            } else {
222                                buffer[b++] = ch;
223                            }
224                        }
225                    }
226                }
227            }
228            finally {
229                Util.closeQuietly(reader);
230            }
231            }
232    
233        public int hashCode() {
234            final int prime = 31;
235            long temp = Double.doubleToLongBits(bias);
236            int result = prime * 1 + (int)(temp ^ (temp >>> 32));
237            result = prime * result + Arrays.hashCode(labels);
238            result = prime * result + nr_class;
239            result = prime * result + nr_feature;
240            result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
241            for (int i = 0; i < w.length; i++) {
242                    result = prime * result + Arrays.hashCode(w[i]);
243            }
244            return result;
245        }
246    
247        public boolean equals(Object obj) {
248            if (this == obj) return true;
249            if (obj == null) return false;
250            if (getClass() != obj.getClass()) return false;
251            MaltLiblinearModel other = (MaltLiblinearModel)obj;
252            if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
253            if (!Arrays.equals(labels, other.labels)) return false;
254            if (nr_class != other.nr_class) return false;
255            if (nr_feature != other.nr_feature) return false;
256            if (solverType == null) {
257                if (other.solverType != null) return false;
258            } else if (!solverType.equals(other.solverType)) return false;
259            for (int i = 0; i < w.length; i++) {
260                    if (other.w.length <= i) return false;
261                    if (!Util.equals(w[i], other.w[i])) return false;
262            }    
263            return true;
264        }
265        
266        public String toString() {
267            final StringBuilder sb = new StringBuilder("Model");
268            sb.append(" bias=").append(bias);
269            sb.append(" nr_class=").append(nr_class);
270            sb.append(" nr_feature=").append(nr_feature);
271            sb.append(" solverType=").append(solverType);
272            return sb.toString();
273        }
274    }