001package org.maltparser.ml.lib;
002
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;

import org.maltparser.core.config.Configuration;
import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.helper.NoPrintStream;
import org.maltparser.core.helper.Util;
import org.maltparser.ml.lib.FeatureList;
import org.maltparser.ml.lib.MaltLiblinearModel;
import org.maltparser.ml.lib.MaltFeatureNode;
import org.maltparser.ml.lib.LibException;
import org.maltparser.parser.guide.instance.InstanceModel;
029
030public class LibLinear extends Lib {
031        
032        public LibLinear(InstanceModel owner, Integer learnerMode) throws MaltChainedException {
033                super(owner, learnerMode, "liblinear");
034                if (learnerMode == CLASSIFY) {
035                        model = (MaltLibModel)getConfigFileEntryObject(".moo");
036                }
037        }
038        
039        protected void trainInternal( LinkedHashMap<String, String> libOptions) throws MaltChainedException {
040                Configuration config = getConfiguration();
041                
042                if (config.isLoggerInfoEnabled()) {
043                        config.logInfoMessage("Creating Liblinear model "+getFile(".moo").getName()+"\n");
044                }
045                double[] wmodel = null;
046                int[] labels = null;
047                int nr_class = 0;
048                int nr_feature = 0;
049                Parameter parameter = getLiblinearParameters(libOptions);
050                try {   
051                        Problem problem = readProblem(getInstanceInputStreamReader(".ins"));
052                        boolean res = checkProblem(problem);
053                        if (res == false) {
054                                throw new LibException("Abort (The number of training instances * the number of classes) > "+Integer.MAX_VALUE+" and this is not supported by LibLinear. ");
055                        }
056                        if (config.isLoggerInfoEnabled()) {
057                                config.logInfoMessage("- Train a parser model using LibLinear.\n");
058                        }
059                        final PrintStream out = System.out;
060                        final PrintStream err = System.err;
061                        System.setOut(NoPrintStream.NO_PRINTSTREAM);
062                        System.setErr(NoPrintStream.NO_PRINTSTREAM);
063                        Model model = Linear.train(problem, parameter);
064                        System.setOut(err);
065                        System.setOut(out);
066                        problem = null;
067                        wmodel = model.getFeatureWeights();
068                        labels = model.getLabels();
069                        nr_class = model.getNrClass();
070                        nr_feature = model.getNrFeature();
071                        boolean saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();
072                        if (!saveInstanceFiles) {
073                                getFile(".ins").delete();
074                        }
075                } catch (OutOfMemoryError e) {
076                        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
077                } catch (IllegalArgumentException e) {
078                        throw new LibException("The Liblinear learner was not able to redirect Standard Error stream. ", e);
079                } catch (SecurityException e) {
080                        throw new LibException("The Liblinear learner cannot remove the instance file. ", e);
081                } catch (NegativeArraySizeException e) {
082                        throw new LibException("(The number of training instances * the number of classes) > "+Integer.MAX_VALUE+" and this is not supported by LibLinear.", e);
083                }
084                
085                if (config.isLoggerInfoEnabled()) {
086                        config.logInfoMessage("- Optimize the memory usage\n");
087                }
088                MaltLiblinearModel xmodel = null;
089                try {
090//                      System.out.println("Nr Features:" +  nr_feature);
091//                      System.out.println("nr_class:" + nr_class);
092//                      System.out.println("wmodel.length:" + wmodel.length);           
093                        double[][] wmatrix = convert2(wmodel, nr_class, nr_feature);
094                        xmodel = new MaltLiblinearModel(labels, nr_class, wmatrix.length, wmatrix, parameter.getSolverType());
095                        if (config.isLoggerInfoEnabled()) {
096                                config.logInfoMessage("- Save the Liblinear model "+getFile(".moo").getName()+"\n");
097                        }
098                } catch (OutOfMemoryError e) {
099                        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
100                }                       
101                try {
102                        if (xmodel != null) {
103                            ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
104                        try{
105                          output.writeObject(xmodel);
106                        } finally {
107                          output.close();
108                        }
109                        }
110                } catch (OutOfMemoryError e) {
111                        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
112                } catch (IllegalArgumentException e) {
113                        throw new LibException("The Liblinear learner was not able to redirect Standard Error stream. ", e);
114                } catch (SecurityException e) {
115                        throw new LibException("The Liblinear learner cannot remove the instance file. ", e);
116                } catch (IOException e) {
117                        throw new LibException("The Liblinear learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
118                }
119        }
120        
121    private double[][] convert2(double[] w, int nr_class, int nr_feature) {
122        int[] wlength = new int[nr_feature];
123        int nr_nfeature = 0;
124//        int ne = 0;
125//        int nr = 0;
126//        int no = 0;
127//        int n = 0;
128        
129        // Identify length of new weight array for each feature
130        for (int i = 0; i < nr_feature; i++) {
131                int k = nr_class;               
132                for (int t = i * nr_class; (t + (k - 1)) >= t; k--) {
133                        if (w[t + k - 1] != 0.0) {
134                                break;
135                        }
136                }
137                int b = k;
138                if (b != 0) {
139                        for (int t = i * nr_class; (t + (b - 1)) >= t; b--) {
140                                if (b != k) {
141                                        if (w[t + b - 1] != w[t + b]) {
142                                                break;
143                                        }
144                                }
145                        }
146                }
147                if (k == 0 || b == 0) {
148                        wlength[i] = 0;
149                } else {
150                        wlength[i] = k;
151                        nr_nfeature++;
152                }               
153        }
154        // Allocate the weight matrix with the new number of features and
155        // an array wsignature that efficient compare if weight vector can be reused by another feature. 
156        double[][] wmatrix = new double[nr_nfeature][];
157        double[] wsignature = new double[nr_nfeature];
158        Long[] reverseMap = featureMap.reverseMap();
159        int in = 0;
160        for (int i = 0; i < nr_feature; i++) {
161            if (wlength[i] == 0) {
162                // if the length of the weight vector is zero than eliminate the feature from the feature map.
163//              ne++;
164                featureMap.removeIndex(reverseMap[i + 1]);
165                reverseMap[i + 1] = null;
166            } else {            
167                boolean reuse = false;
168                double[] copy = new double[wlength[i]];
169                System.arraycopy(w, i * nr_class, copy, 0, wlength[i]);
170                featureMap.setIndex(reverseMap[i + 1], in + 1);
171                for (int j=0; j<copy.length; j++) wsignature[in] += copy[j];
172                    for (int j = 0; j < in; j++) {
173                        if (wsignature[j] == wsignature[in]) {
174                                // if the signatures is equal then do more narrow comparison  
175                                if (Util.equals(copy, wmatrix[j])) {
176                                        // if equal then reuse the weight vector
177                                        wmatrix[in] = wmatrix[j];
178                                        reuse = true;
179//                                      nr++;
180                                        break;
181                                }
182                        }
183                    }
184                    if (reuse == false) {
185                        // if no reuse has done use the new weight vector in the weight matrix 
186//                      no++;
187                        wmatrix[in] = copy;
188                    }
189                    in++;
190            }
191//            n++;
192        }
193        featureMap.setFeatureCounter(nr_nfeature);
194//        System.out.println("NE:"+ne);
195//        System.out.println("NR:"+nr);
196//        System.out.println("NO:"+no);
197//        System.out.println("N :"+n);
198        return wmatrix;
199    }
200    
201    public static boolean eliminate(double[] a) {
202        if (a.length == 0) {
203                return true;
204        }
205        for (int i = 1; i < a.length; i++) {
206                if (a[i] != a[i-1]) {
207                        return false;
208                }
209        }
210        return true;
211    }
212    
213        protected void trainExternal(String pathExternalTrain, LinkedHashMap<String, String> libOptions) throws MaltChainedException {
214                try {           
215                        Configuration config = getConfiguration();
216                        if (config.isLoggerInfoEnabled()) {
217                                config.logInfoMessage("Creating liblinear model (external) "+getFile(".mod").getName());
218                        }
219                        binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp"));
220                        final String[] params = getLibParamStringArray(libOptions);
221                        String[] arrayCommands = new String[params.length+3];
222                        int i = 0;
223                        arrayCommands[i++] = pathExternalTrain;
224                        for (; i <= params.length; i++) {
225                                arrayCommands[i] = params[i-1];
226                        }
227                        arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath();
228                        arrayCommands[i++] = getFile(".mod").getAbsolutePath();
229                        
230                if (verbosity == Verbostity.ALL) {
231                        config.logInfoMessage('\n');
232                }
233                        final Process child = Runtime.getRuntime().exec(arrayCommands);
234                final InputStream in = child.getInputStream();
235                final InputStream err = child.getErrorStream();
236                int c;
237                while ((c = in.read()) != -1){
238                        if (verbosity == Verbostity.ALL) {
239                                config.logInfoMessage((char)c);
240                        }
241                }
242                while ((c = err.read()) != -1){
243                        if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) {
244                                config.logInfoMessage((char)c);
245                        }
246                }
247            if (child.waitFor() != 0) {
248                config.logErrorMessage(" FAILED ("+child.exitValue()+")");
249            }
250                in.close();
251                err.close();
252                        if (config.isLoggerInfoEnabled()) {
253                                config.logInfoMessage("\nSaving Liblinear model "+getFile(".moo").getName()+"\n");
254                        }
255                        MaltLiblinearModel xmodel = new MaltLiblinearModel(getFile(".mod"));
256                    ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
257                try{
258                  output.writeObject(xmodel);
259                } finally {
260                  output.close();
261                }
262                boolean saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();
263                if (!saveInstanceFiles) {
264                                getFile(".ins").delete();
265                                getFile(".mod").delete();
266                                getFile(".ins.tmp").delete();
267                }
268                if (config.isLoggerInfoEnabled()) {
269                        config.logInfoMessage('\n');
270                }
271                } catch (InterruptedException e) {
272                         throw new LibException("Learner is interrupted. ", e);
273                } catch (IllegalArgumentException e) {
274                        throw new LibException("The learner was not able to redirect Standard Error stream. ", e);
275                } catch (SecurityException e) {
276                        throw new LibException("The learner cannot remove the instance file. ", e);
277                } catch (IOException e) {
278                        throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
279                } catch (OutOfMemoryError e) {
280                        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
281                }
282        }
283        
284        public void terminate() throws MaltChainedException { 
285                super.terminate();
286        }
287
288        public LinkedHashMap<String, String> getDefaultLibOptions() {
289                LinkedHashMap<String, String> libOptions = new LinkedHashMap<String, String>();
290                libOptions.put("s", "4"); // type = SolverType.MCSVM_CS (default)
291                libOptions.put("c", "0.1"); // cost = 1 (default)
292                libOptions.put("e", "0.1"); // epsilon = 0.1 (default)
293                libOptions.put("B", "-1"); // bias = -1 (default)
294                return libOptions;
295        }
296        
297        public String getAllowedLibOptionFlags() {
298                return "sceB";
299        }
300        
301        private Problem readProblem(InputStreamReader isr) throws MaltChainedException {
302                Problem problem = new Problem();
303                final FeatureList featureList = new FeatureList();
304                if (getConfiguration().isLoggerInfoEnabled()) {
305                        getConfiguration().logInfoMessage("- Read all training instances.\n");
306                }
307                try {
308                        final BufferedReader fp = new BufferedReader(isr);
309                        
310                        problem.bias = -1;
311                        problem.l = getNumberOfInstances();
312                        problem.x = new FeatureNode[problem.l][];
313                        problem.y = new int[problem.l];
314                        int i = 0;
315                        
316                        while(true) {
317                                String line = fp.readLine();
318                                if(line == null) break;
319                                int y = binariesInstance(line, featureList);
320                                if (y == -1) {
321                                        continue;
322                                }
323                                try {
324                                        problem.y[i] = y;
325                                        problem.x[i] = new FeatureNode[featureList.size()];
326                                        int p = 0;
327                                for (int k=0; k < featureList.size(); k++) {
328                                        MaltFeatureNode x = featureList.get(k);
329                                                problem.x[i][p++] = new FeatureNode(x.getIndex(), x.getValue());
330                                        }
331                                        i++;
332                                } catch (ArrayIndexOutOfBoundsException e) {
333                                        throw new LibException("Couldn't read liblinear problem from the instance file. ", e);
334                                }
335
336                        }
337                        fp.close();
338                        problem.n = featureMap.size();
339                } catch (IOException e) {
340                        throw new LibException("Cannot read from the instance file. ", e);
341                }
342                
343                return problem;
344        }
345        
346        private boolean checkProblem(Problem problem) throws MaltChainedException {
347                int max_y = problem.y[0];
348                for (int i = 1; i < problem.y.length; i++) {
349                        if (problem.y[i] > max_y) {
350                                max_y = problem.y[i];
351                        }
352                }
353                if (max_y * problem.l < 0) { // max_y * problem.l > Integer.MAX_VALUE
354                        if (getConfiguration().isLoggerInfoEnabled()) {
355                                getConfiguration().logInfoMessage("*** Abort (The number of training instances * the number of classes) > Max array size: ("+problem.l+" * "+max_y+") > "+Integer.MAX_VALUE+" and this is not supported by LibLinear.\n");
356                        }
357                        return false;
358                }
359                return true;
360        }
361        
362        private Parameter getLiblinearParameters(LinkedHashMap<String, String> libOptions) throws MaltChainedException {
363                Parameter param = new Parameter(SolverType.MCSVM_CS, 0.1, 0.1);
364                String type = libOptions.get("s");
365                
366                if (type.equals("0")) {
367                        param.setSolverType(SolverType.L2R_LR);
368                } else if (type.equals("1")) {
369                        param.setSolverType(SolverType.L2R_L2LOSS_SVC_DUAL);
370                } else if (type.equals("2")) {
371                        param.setSolverType(SolverType.L2R_L2LOSS_SVC);
372                } else if (type.equals("3")) {
373                        param.setSolverType(SolverType.L2R_L1LOSS_SVC_DUAL);
374                } else if (type.equals("4")) {
375                        param.setSolverType(SolverType.MCSVM_CS);
376                } else if (type.equals("5")) {
377                        param.setSolverType(SolverType.L1R_L2LOSS_SVC); 
378                } else if (type.equals("6")) {
379                        param.setSolverType(SolverType.L1R_LR); 
380                } else if (type.equals("7")) {
381                        param.setSolverType(SolverType.L2R_LR_DUAL);    
382                } else {
383                        throw new LibException("The liblinear type (-s) is not an integer value between 0 and 4. ");
384                }
385                try {
386                        param.setC(Double.valueOf(libOptions.get("c")).doubleValue());
387                } catch (NumberFormatException e) {
388                        throw new LibException("The liblinear cost (-c) value is not numerical value. ", e);
389                }
390                try {
391                        param.setEps(Double.valueOf(libOptions.get("e")).doubleValue());
392                } catch (NumberFormatException e) {
393                        throw new LibException("The liblinear epsilon (-e) value is not numerical value. ", e);
394                }
395                return param;
396        }
397}