001package org.maltparser.ml.lib;
002
003import java.io.BufferedReader;
004import java.io.EOFException;
005import java.io.File;
006import java.io.FileInputStream;
007import java.io.IOException;
008import java.io.InputStreamReader;
009import java.io.ObjectInputStream;
010import java.io.ObjectOutputStream;
011import java.io.Reader;
012import java.io.Serializable;
013import java.nio.charset.Charset;
014import java.util.Arrays;
015import java.util.regex.Pattern;
016
017import org.maltparser.core.helper.Util;
018
019import de.bwaldvogel.liblinear.SolverType;
020
021/**
022 * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package.
023 * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
024 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 
025 * structure directly on the model. </p> 
026 * 
027 * @author Johan Hall
028 *
029 */
030public class MaltLiblinearModel implements Serializable, MaltLibModel {
031        private static final long serialVersionUID = 7526471155622776147L;
032        private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
033        private double bias;
034        /** label of each class */
035        private int[] labels;
036        private int nr_class;
037        private int nr_feature;
038        private SolverType solverType;
039        /** feature weight array */
040        private double[][] w;
041
042    public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) {
043        this.labels = labels;
044        this.nr_class = nr_class;
045        this.nr_feature = nr_feature;
046        this.w = w;
047        this.solverType = solverType;   
048    }
049    
050    public MaltLiblinearModel(Reader inputReader) throws IOException {
051        loadModel(inputReader);
052    }
053    
054    public MaltLiblinearModel(File modelFile) throws IOException {
055        BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
056        loadModel(inputReader);
057    }
058    
059    /**
060    * @return number of classes
061    */
062    public int getNrClass() {
063        return nr_class;
064    }
065
066    /**
067    * @return number of features
068    */
069    public int getNrFeature() {
070        return nr_feature;
071    }
072
073    public int[] getLabels() {
074        return Util.copyOf(labels, nr_class);
075    }
076
077    /**
078    * The nr_feature*nr_class array w gives feature weights. We use one
079    * against the rest for multi-class classification, so each feature
080    * index corresponds to nr_class weight values. Weights are
081    * organized in the following way
082    *
083    * <pre>
084    * +------------------+------------------+------------+
085    * | nr_class weights | nr_class weights | ...
086    * | for 1st feature | for 2nd feature |
087    * +------------------+------------------+------------+
088    * </pre>
089    *
090    * If bias &gt;= 0, x becomes [x; bias]. The number of features is
091    * increased by one, so w is a (nr_feature+1)*nr_class array. The
092    * value of bias is stored in the variable bias.
093    * @see #getBias()
094    * @return a <b>copy of</b> the feature weight array as described
095    */
096//    public double[] getFeatureWeights() {
097//        return Util.copyOf(w, w.length);
098//    }
099
100    /**
101    * @return true for logistic regression solvers
102    */
103    public boolean isProbabilityModel() {
104        return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
105    }
106    
107    public double getBias() {
108        return bias;
109    }
110        
111    public int[] predict(MaltFeatureNode[] x) { 
112                final double[] dec_values = new double[nr_class];
113        final int n = (bias >= 0)?nr_feature + 1:nr_feature;
114        final int xlen = x.length;
115        
116        for (int i=0; i < xlen; i++) {
117            if (x[i].index <= n) {
118                final int t = (x[i].index - 1);
119                if (w[t] != null) {
120                        for (int j = 0; j < w[t].length; j++) {
121                            dec_values[j] += w[t][j] * x[i].value;
122                        }
123                }
124            }
125        }
126
127                
128                double tmpDec;
129                int tmpObj;
130                int iMax;
131        final int[] predictionList = new int[nr_class];
132        System.arraycopy(labels, 0, predictionList, 0, nr_class);
133                final int nc =  nr_class-1;
134                for (int i=0; i < nc; i++) {
135                        iMax = i;
136                        for (int j=i+1; j < nr_class; j++) {
137                                if (dec_values[j] > dec_values[iMax]) {
138                                        iMax = j;
139                                }
140                        }
141                        if (iMax != i) {
142                                tmpDec = dec_values[iMax];
143                                dec_values[iMax] = dec_values[i];
144                                dec_values[i] = tmpDec;
145                                tmpObj = predictionList[iMax];
146                                predictionList[iMax] = predictionList[i];
147                                predictionList[i] = tmpObj;
148                        }
149                }
150                return predictionList;
151        }
152        
153    public int predict_one(MaltFeatureNode[] x) { 
154                final double[] dec_values = new double[nr_class];
155        final int n = (bias >= 0)?nr_feature + 1:nr_feature;
156        final int xlen = x.length;
157        
158        for (int i=0; i < xlen; i++) {
159            if (x[i].index <= n) {
160                final int t = (x[i].index - 1);
161                if (w[t] != null) {
162                        for (int j = 0; j < w[t].length; j++) {
163                            dec_values[j] += w[t][j] * x[i].value;
164                        }
165                }
166            }
167        }
168        
169        double max = dec_values[0];
170        int max_index = 0;
171                for (int i = 1; i < dec_values.length; i++) {
172                        if (dec_values[i] > max) {
173                                max = dec_values[i];
174                                max_index = i;
175                        }
176                }
177
178                return labels[max_index];
179        }
180    
181        private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
182                is.defaultReadObject();
183        }
184
185        private void writeObject(ObjectOutputStream os) throws IOException {
186                os.defaultWriteObject();
187        }
188        
189        private void loadModel(Reader inputReader) throws IOException {
190                labels = null;
191                Pattern whitespace = Pattern.compile("\\s+");
192        BufferedReader reader = null;
193        if (inputReader instanceof BufferedReader) {
194            reader = (BufferedReader)inputReader;
195        } else {
196            reader = new BufferedReader(inputReader);
197        }
198
199        try {
200            String line = null;
201            while ((line = reader.readLine()) != null) {
202                String[] split = whitespace.split(line);
203                if (split[0].equals("solver_type")) {
204                    SolverType solver = SolverType.valueOf(split[1]);
205                    if (solver == null) {
206                        throw new RuntimeException("unknown solver type");
207                    }
208                    solverType = solver;
209                } else if (split[0].equals("nr_class")) {
210                    nr_class = Util.atoi(split[1]);
211                    Integer.parseInt(split[1]);
212                } else if (split[0].equals("nr_feature")) {
213                    nr_feature = Util.atoi(split[1]);
214                } else if (split[0].equals("bias")) {
215                    bias = Util.atof(split[1]);
216                } else if (split[0].equals("w")) {
217                    break;
218                } else if (split[0].equals("label")) {
219                    labels = new int[nr_class];
220                    for (int i = 0; i < nr_class; i++) {
221                        labels[i] = Util.atoi(split[i + 1]);
222                    }
223                } else {
224                    throw new RuntimeException("unknown text in model file: [" + line + "]");
225                }
226            }
227
228            int w_size = nr_feature;
229            if (bias >= 0) w_size++;
230
231            int nr_w = nr_class;
232            if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1;
233            w = new double[w_size][nr_w];
234            int[] buffer = new int[128];
235
236            for (int i = 0; i < w_size; i++) {
237                for (int j = 0; j < nr_w; j++) {
238                    int b = 0;
239                    while (true) {
240                        int ch = reader.read();
241                        if (ch == -1) {
242                            throw new EOFException("unexpected EOF");
243                        }
244                        if (ch == ' ') {
245                                w[i][j] = Util.atof(new String(buffer, 0, b));
246                            break;
247                        } else {
248                            buffer[b++] = ch;
249                        }
250                    }
251                }
252            }
253        }
254        finally {
255            Util.closeQuietly(reader);
256        }
257        }
258
259    public int hashCode() {
260        final int prime = 31;
261        long temp = Double.doubleToLongBits(bias);
262        int result = prime * 1 + (int)(temp ^ (temp >>> 32));
263        result = prime * result + Arrays.hashCode(labels);
264        result = prime * result + nr_class;
265        result = prime * result + nr_feature;
266        result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
267        for (int i = 0; i < w.length; i++) {
268                result = prime * result + Arrays.hashCode(w[i]);
269        }
270        return result;
271    }
272
273    public boolean equals(Object obj) {
274        if (this == obj) return true;
275        if (obj == null) return false;
276        if (getClass() != obj.getClass()) return false;
277        MaltLiblinearModel other = (MaltLiblinearModel)obj;
278        if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
279        if (!Arrays.equals(labels, other.labels)) return false;
280        if (nr_class != other.nr_class) return false;
281        if (nr_feature != other.nr_feature) return false;
282        if (solverType == null) {
283            if (other.solverType != null) return false;
284        } else if (!solverType.equals(other.solverType)) return false;
285        for (int i = 0; i < w.length; i++) {
286                if (other.w.length <= i) return false;
287                if (!Util.equals(w[i], other.w[i])) return false;
288        }    
289        return true;
290    }
291    
292    public String toString() {
293        final StringBuilder sb = new StringBuilder("Model");
294        sb.append(" bias=").append(bias);
295        sb.append(" nr_class=").append(nr_class);
296        sb.append(" nr_feature=").append(nr_feature);
297        sb.append(" solverType=").append(solverType);
298        return sb.toString();
299    }
300}