001    package org.maltparser.ml.lib;
002    
003    import java.io.BufferedReader;
004    import java.io.EOFException;
005    import java.io.File;
006    import java.io.FileInputStream;
007    import java.io.IOException;
008    import java.io.InputStreamReader;
009    import java.io.ObjectInputStream;
010    import java.io.ObjectOutputStream;
011    import java.io.Reader;
012    import java.io.Serializable;
013    import java.nio.charset.Charset;
014    import java.util.Arrays;
015    import java.util.regex.Pattern;
016    
017    import org.maltparser.core.helper.Util;
018    
019    import liblinear.Model;
020    import liblinear.SolverType;
021    
022    /**
023     * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package.
024     * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
025     * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 
026     * structure directly on the model. </p> 
027     * 
028     * @author Johan Hall
029     *
030     */
031    public class MaltLiblinearModel implements Serializable, MaltLibModel {
032            private static final long serialVersionUID = 7526471155622776147L;
033            private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
034            private double bias;
035            /** label of each class */
036            private int[] labels;
037            private int nr_class;
038            private int nr_feature;
039            private SolverType solverType;
040            /** feature weight array */
041            private double[] w;
042    
043        public MaltLiblinearModel(Model model, SolverType solverType) {
044            labels = model.getLabels();
045            nr_class = model.getNrClass();
046            nr_feature = model.getNrFeature();
047            this.solverType = solverType;
048            w = model.getFeatureWeights();
049        }
050        
051        public MaltLiblinearModel(Reader inputReader) throws IOException {
052            loadModel(inputReader);
053        }
054        
055        public MaltLiblinearModel(File modelFile) throws IOException {
056            BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
057            loadModel(inputReader);
058        }
059        
060        /**
061        * @return number of classes
062        */
063        public int getNrClass() {
064            return nr_class;
065        }
066    
067        /**
068        * @return number of features
069        */
070        public int getNrFeature() {
071            return nr_feature;
072        }
073    
074        public int[] getLabels() {
075            return Util.copyOf(labels, nr_class);
076        }
077    
078        /**
079        * The nr_feature*nr_class array w gives feature weights. We use one
080        * against the rest for multi-class classification, so each feature
081        * index corresponds to nr_class weight values. Weights are
082        * organized in the following way
083        *
084        * <pre>
085        * +------------------+------------------+------------+
086        * | nr_class weights | nr_class weights | ...
087        * | for 1st feature | for 2nd feature |
088        * +------------------+------------------+------------+
089        * </pre>
090        *
091        * If bias &gt;= 0, x becomes [x; bias]. The number of features is
092        * increased by one, so w is a (nr_feature+1)*nr_class array. The
093        * value of bias is stored in the variable bias.
094        * @see #getBias()
095        * @return a <b>copy of</b> the feature weight array as described
096        */
097        public double[] getFeatureWeights() {
098            return Util.copyOf(w, w.length);
099        }
100    
101        /**
102        * @return true for logistic regression solvers
103        */
104        public boolean isProbabilityModel() {
105            return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
106        }
107        
108        public double getBias() {
109            return bias;
110        }
111            
112        public int[] predict(MaltFeatureNode[] x) { 
113                    final double[] dec_values = new double[nr_class];
114            final int n = (bias >= 0)?nr_feature + 1:nr_feature;
115            final int nr_w = (nr_class == 2 && solverType != SolverType.MCSVM_CS)?1:nr_class;
116            final int xlen = x.length;
117            int i;
118            for (i = 0; i < nr_w; i++) {
119                dec_values[i] = 0;   
120            }
121    
122            for (i=0; i < xlen; i++) {
123                if (x[i].index <= n) {
124                    int t = (x[i].index - 1) * nr_w;
125                    for (int j = 0; j < nr_w; j++) {
126                        dec_values[j] += w[t + j] * x[i].value;
127                    }
128                }
129            }
130    
131                    final int[] predictionList = Util.copyOf(labels, nr_class); 
132                    double tmpDec;
133                    int tmpObj;
134                    int lagest;
135                    final int nc =  nr_class-1;
136                    for (i=0; i < nc; i++) {
137                            lagest = i;
138                            for (int j=i; j < nr_class; j++) {
139                                    if (dec_values[j] > dec_values[lagest]) {
140                                            lagest = j;
141                                    }
142                            }
143                            tmpDec = dec_values[lagest];
144                            dec_values[lagest] = dec_values[i];
145                            dec_values[i] = tmpDec;
146                            tmpObj = predictionList[lagest];
147                            predictionList[lagest] = predictionList[i];
148                            predictionList[i] = tmpObj;
149                    }
150                    return predictionList;
151            }
152            
153            private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
154                    is.defaultReadObject();
155            }
156    
157            private void writeObject(ObjectOutputStream os) throws IOException {
158                    os.defaultWriteObject();
159            }
160            
161            private void loadModel(Reader inputReader) throws IOException {
162                    labels = null;
163                    Pattern whitespace = Pattern.compile("\\s+");
164            BufferedReader reader = null;
165            if (inputReader instanceof BufferedReader) {
166                reader = (BufferedReader)inputReader;
167            } else {
168                reader = new BufferedReader(inputReader);
169            }
170    
171            try {
172                String line = null;
173                while ((line = reader.readLine()) != null) {
174                    String[] split = whitespace.split(line);
175                    if (split[0].equals("solver_type")) {
176                        SolverType solver = SolverType.valueOf(split[1]);
177                        if (solver == null) {
178                            throw new RuntimeException("unknown solver type");
179                        }
180                        solverType = solver;
181                    } else if (split[0].equals("nr_class")) {
182                        nr_class = Util.atoi(split[1]);
183                        Integer.parseInt(split[1]);
184                    } else if (split[0].equals("nr_feature")) {
185                        nr_feature = Util.atoi(split[1]);
186                    } else if (split[0].equals("bias")) {
187                        bias = Util.atof(split[1]);
188                    } else if (split[0].equals("w")) {
189                        break;
190                    } else if (split[0].equals("label")) {
191                        labels = new int[nr_class];
192                        for (int i = 0; i < nr_class; i++) {
193                            labels[i] = Util.atoi(split[i + 1]);
194                        }
195                    } else {
196                        throw new RuntimeException("unknown text in model file: [" + line + "]");
197                    }
198                }
199    
200                int w_size = nr_feature;
201                if (bias >= 0) w_size++;
202    
203                int nr_w = nr_class;
204                if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1;
205    
206                w = new double[w_size * nr_w];
207                int[] buffer = new int[128];
208    
209                for (int i = 0; i < w_size; i++) {
210                    for (int j = 0; j < nr_w; j++) {
211                        int b = 0;
212                        while (true) {
213                            int ch = reader.read();
214                            if (ch == -1) {
215                                throw new EOFException("unexpected EOF");
216                            }
217                            if (ch == ' ') {
218                                w[i * nr_w + j] = Util.atof(new String(buffer, 0, b));
219                                break;
220                            } else {
221                                buffer[b++] = ch;
222                            }
223                        }
224                    }
225                }
226            }
227            finally {
228                Util.closeQuietly(reader);
229            }
230            }
231    
232        public int hashCode() {
233            final int prime = 31;
234            long temp = Double.doubleToLongBits(bias);
235            int result = prime * 1 + (int)(temp ^ (temp >>> 32));
236            result = prime * result + Arrays.hashCode(labels);
237            result = prime * result + nr_class;
238            result = prime * result + nr_feature;
239            result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
240            result = prime * result + Arrays.hashCode(w);
241            return result;
242        }
243    
244        public boolean equals(Object obj) {
245            if (this == obj) return true;
246            if (obj == null) return false;
247            if (getClass() != obj.getClass()) return false;
248            MaltLiblinearModel other = (MaltLiblinearModel)obj;
249            if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
250            if (!Arrays.equals(labels, other.labels)) return false;
251            if (nr_class != other.nr_class) return false;
252            if (nr_feature != other.nr_feature) return false;
253            if (solverType == null) {
254                if (other.solverType != null) return false;
255            } else if (!solverType.equals(other.solverType)) return false;
256            if (!Util.equals(w, other.w)) return false;
257            return true;
258        }
259        
260        public String toString() {
261            final StringBuilder sb = new StringBuilder("Model");
262            sb.append(" bias=").append(bias);
263            sb.append(" nr_class=").append(nr_class);
264            sb.append(" nr_feature=").append(nr_feature);
265            sb.append(" solverType=").append(solverType);
266            return sb.toString();
267        }
268    }