001    package org.maltparser.ml.lib;
002    
003    import java.io.BufferedOutputStream;
004    import java.io.BufferedReader;
005    import java.io.FileOutputStream;
006    import java.io.IOException;
007    import java.io.InputStream;
008    import java.io.InputStreamReader;
009    import java.io.ObjectInputStream;
010    import java.io.ObjectOutputStream;
011    import java.io.PrintStream;
012    import java.util.LinkedHashMap;
013    
014    import liblinear.FeatureNode;
015    import liblinear.Linear;
016    import liblinear.Model;
017    import liblinear.Parameter;
018    import liblinear.Problem;
019    import liblinear.SolverType;
020    
021    import org.maltparser.core.exception.MaltChainedException;
022    import org.maltparser.core.feature.FeatureVector;
023    import org.maltparser.core.helper.NoPrintStream;
024    import org.maltparser.core.helper.Util;
025    import org.maltparser.parser.guide.instance.InstanceModel;
026    
027    public class LibLinear extends Lib {
028            
029            public LibLinear(InstanceModel owner, Integer learnerMode) throws MaltChainedException {
030                    super(owner, learnerMode, "liblinear");
031                    if (learnerMode == CLASSIFY) {
032                            try {
033                                ObjectInputStream input = new ObjectInputStream(getInputStreamFromConfigFileEntry(".moo"));
034                                try {
035                                    model = (MaltLibModel)input.readObject();
036                                } finally {
037                                    input.close();
038                                }
039                            } catch (ClassNotFoundException e) {
040                                    throw new LibException("Couldn't load the liblinear model", e);
041                            } catch (Exception e) {
042                                    throw new LibException("Couldn't load the liblinear model", e);
043                            }
044                    }
045    
046            }
047            
048            protected void trainInternal(FeatureVector featureVector) throws MaltChainedException {
049                    try {
050                            if (configLogger.isInfoEnabled()) {
051                                    configLogger.info("Creating Liblinear model "+getFile(".moo").getName()+"\n");
052                            }
053                            Problem problem = readProblem(getInstanceInputStreamReader(".ins"));
054                            final PrintStream out = System.out;
055                            final PrintStream err = System.err;
056                            System.setOut(NoPrintStream.NO_PRINTSTREAM);
057                            System.setErr(NoPrintStream.NO_PRINTSTREAM);
058                            Parameter parameter = getLiblinearParameters();
059                            Model model = Linear.train(problem, parameter);
060                            System.setOut(err);
061                            System.setOut(out);
062    //                      System.out.println(" model.getNrFeature():" +  model.getNrFeature());
063    //                      System.out.println(" model.getFeatureWeights().length:" +  model.getFeatureWeights().length);
064                            if (configLogger.isInfoEnabled()) {
065                                    configLogger.info("Optimize memory usage for the Liblinear model "+getFile(".moo").getName()+"\n");
066                            }
067                            double[][] wmatrix = convert(model.getFeatureWeights(), model.getNrClass(), model.getNrFeature());
068                            MaltLiblinearModel xmodel = new MaltLiblinearModel(model.getLabels(), model.getNrClass(), wmatrix.length, wmatrix, parameter.getSolverType());
069                            if (configLogger.isInfoEnabled()) {
070                                    configLogger.info("Save the Liblinear model "+getFile(".moo").getName()+"\n");
071                            }
072                        ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
073                    try{
074                      output.writeObject(xmodel);
075                    } finally {
076                      output.close();
077                    }
078                            if (!saveInstanceFiles) {
079                                    getFile(".ins").delete();
080                            }
081                    } catch (OutOfMemoryError e) {
082                            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
083                    } catch (IllegalArgumentException e) {
084                            throw new LibException("The Liblinear learner was not able to redirect Standard Error stream. ", e);
085                    } catch (SecurityException e) {
086                            throw new LibException("The Liblinear learner cannot remove the instance file. ", e);
087                    } catch (IOException e) {
088                            throw new LibException("The Liblinear learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
089                    }
090            }
091            
092        private double[][] convert(double[] w, int nr_class, int nr_feature) {
093            double[][] wmatrix = new double[nr_feature][];
094            double[] wsignature = new double[nr_feature];
095            boolean reuse = false;
096            int ne = 0;
097    //        int nr = 0;
098    //        int no = 0;
099    //        int n = 0;
100    
101            Long[] reverseMap = featureMap.reverseMap();
102            for (int i = 0; i < nr_feature; i++) {
103                    reuse = false;
104                    int k = nr_class;
105                    for (int t = i * nr_class; (t + (k - 1)) >= t; k--) {
106                            if (w[t + k - 1] != 0.0) {
107                                    break;
108                            }
109                    }
110                    double[] copy = new double[k];
111                System.arraycopy(w, i * nr_class, copy, 0,k);
112                if (eliminate(copy)) {
113                    ne++;
114                    featureMap.removeIndex(reverseMap[i + 1]);
115                    reverseMap[i + 1] = null;
116                    wmatrix[i] = null;
117                } else {
118                    featureMap.setIndex(reverseMap[i + 1], i + 1 - ne);
119                    for (int j=0; j<copy.length; j++) wsignature[i] += copy[j];
120                        for (int j = 0; j < i; j++) {
121                            if (wsignature[j] == wsignature[i]) {
122                                    if (Util.equals(copy, wmatrix[j])) {
123                                            wmatrix[i] = wmatrix[j];
124                                            reuse = true;
125    //                                      nr++;
126                                            break;
127                                    }
128                            }
129                        }
130                        if (reuse == false) {
131    //                      no++;
132                            wmatrix[i] = copy;
133                        }
134                }
135    //            n++;
136            }
137            featureMap.setFeatureCounter(featureMap.getFeatureCounter()- ne);
138            double[][] wmatrix_reduced = new double[nr_feature-ne][];
139            for (int i = 0, j = 0; i < wmatrix.length; i++) {
140                    if (wmatrix[i] != null) {
141                            wmatrix_reduced[j++] = wmatrix[i];
142                    }
143            }
144    //        System.out.println("NE:"+ne);
145    //        System.out.println("NR:"+nr);
146    //        System.out.println("NO:"+no);
147    //        System.out.println("N :"+n);
148    
149            return wmatrix_reduced;
150        }
151        
152        public static boolean eliminate(double[] a) {
153            if (a.length == 0) {
154                    return true;
155            }
156            for (int i = 1; i < a.length; i++) {
157                    if (a[i] != a[i-1]) {
158                            return false;
159                    }
160            }
161            return true;
162        }
163        
164            protected void trainExternal(FeatureVector featureVector) throws MaltChainedException {
165                    try {           
166                            
167                            if (configLogger.isInfoEnabled()) {
168                                    owner.getGuide().getConfiguration().getConfigLogger().info("Creating liblinear model (external) "+getFile(".mod").getName());
169                            }
170                            binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp"));
171                            final String[] params = getLibParamStringArray();
172                            String[] arrayCommands = new String[params.length+3];
173                            int i = 0;
174                            arrayCommands[i++] = pathExternalTrain;
175                            for (; i <= params.length; i++) {
176                                    arrayCommands[i] = params[i-1];
177                            }
178                            arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath();
179                            arrayCommands[i++] = getFile(".mod").getAbsolutePath();
180                            
181                    if (verbosity == Verbostity.ALL) {
182                            owner.getGuide().getConfiguration().getConfigLogger().info('\n');
183                    }
184                            final Process child = Runtime.getRuntime().exec(arrayCommands);
185                    final InputStream in = child.getInputStream();
186                    final InputStream err = child.getErrorStream();
187                    int c;
188                    while ((c = in.read()) != -1){
189                            if (verbosity == Verbostity.ALL) {
190                                    owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
191                            }
192                    }
193                    while ((c = err.read()) != -1){
194                            if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) {
195                                    owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
196                            }
197                    }
198                if (child.waitFor() != 0) {
199                    owner.getGuide().getConfiguration().getConfigLogger().info(" FAILED ("+child.exitValue()+")");
200                }
201                    in.close();
202                    err.close();
203                            if (configLogger.isInfoEnabled()) {
204                                    configLogger.info("\nSaving Liblinear model "+getFile(".moo").getName()+"\n");
205                            }
206                            MaltLiblinearModel xmodel = new MaltLiblinearModel(getFile(".mod"));
207                        ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
208                    try{
209                      output.writeObject(xmodel);
210                    } finally {
211                      output.close();
212                    }
213                    if (!saveInstanceFiles) {
214                                    getFile(".ins").delete();
215                                    getFile(".mod").delete();
216                                    getFile(".ins.tmp").delete();
217                    }
218                    if (configLogger.isInfoEnabled()) {
219                            configLogger.info('\n');
220                    }
221                    } catch (InterruptedException e) {
222                             throw new LibException("Learner is interrupted. ", e);
223                    } catch (IllegalArgumentException e) {
224                            throw new LibException("The learner was not able to redirect Standard Error stream. ", e);
225                    } catch (SecurityException e) {
226                            throw new LibException("The learner cannot remove the instance file. ", e);
227                    } catch (IOException e) {
228                            throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
229                    } catch (OutOfMemoryError e) {
230                            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
231                    }
232            }
233            
234            public void terminate() throws MaltChainedException { 
235                    super.terminate();
236            }
237    
238            public void initLibOptions() {
239                    libOptions = new LinkedHashMap<String, String>();
240                    libOptions.put("s", "4"); // type = SolverType.L2LOSS_SVM_DUAL (default)
241                    libOptions.put("c", "0.1"); // cost = 1 (default)
242                    libOptions.put("e", "0.1"); // epsilon = 0.1 (default)
243                    libOptions.put("B", "-1"); // bias = -1 (default)
244            }
245            
246            public void initAllowedLibOptionFlags() {
247                    allowedLibOptionFlags = "sceB";
248            }
249            
250            private Problem readProblem(InputStreamReader isr) throws MaltChainedException {
251                    Problem problem = new Problem();
252                    final FeatureList featureList = new FeatureList();
253                    
254                    try {
255                            final BufferedReader fp = new BufferedReader(isr);
256                            
257                            problem.bias = -1;
258                            problem.l = getNumberOfInstances();
259                            problem.x = new FeatureNode[problem.l][];
260                            problem.y = new int[problem.l];
261                            int i = 0;
262    
263                            while(true) {
264                                    String line = fp.readLine();
265                                    if(line == null) break;
266                                    int y = binariesInstance(line, featureList);
267                                    if (y == -1) {
268                                            continue;
269                                    }
270                                    try {
271                                            problem.y[i] = y;
272                                            problem.x[i] = new FeatureNode[featureList.size()];
273                                            int p = 0;
274                                    for (int k=0; k < featureList.size(); k++) {
275                                            MaltFeatureNode x = featureList.get(k);
276                                                    problem.x[i][p++] = new FeatureNode(x.getIndex(), x.getValue());
277                                            }
278                                            i++;
279                                    } catch (ArrayIndexOutOfBoundsException e) {
280                                            throw new LibException("Couldn't read liblinear problem from the instance file. ", e);
281                                    }
282    
283                            }
284                            fp.close();
285                            problem.n = featureMap.size();
286                    } catch (IOException e) {
287                            throw new LibException("Cannot read from the instance file. ", e);
288                    }
289                    return problem;
290            }
291            
292            private Parameter getLiblinearParameters() throws MaltChainedException {
293                    Parameter param = new Parameter(SolverType.MCSVM_CS, 0.1, 0.1);
294                    String type = libOptions.get("s");
295                    
296                    if (type.equals("0")) {
297                            param.setSolverType(SolverType.L2R_LR);
298                    } else if (type.equals("1")) {
299                            param.setSolverType(SolverType.L2R_L2LOSS_SVC_DUAL);
300                    } else if (type.equals("2")) {
301                            param.setSolverType(SolverType.L2R_L2LOSS_SVC);
302                    } else if (type.equals("3")) {
303                            param.setSolverType(SolverType.L2R_L1LOSS_SVC_DUAL);
304                    } else if (type.equals("4")) {
305                            param.setSolverType(SolverType.MCSVM_CS);
306                    } else if (type.equals("5")) {
307                            param.setSolverType(SolverType.L1R_L2LOSS_SVC); 
308                    } else if (type.equals("6")) {
309                            param.setSolverType(SolverType.L1R_LR); 
310                    } else if (type.equals("7")) {
311                            param.setSolverType(SolverType.L2R_LR_DUAL);    
312                    } else {
313                            throw new LibException("The liblinear type (-s) is not an integer value between 0 and 4. ");
314                    }
315                    try {
316                            param.setC(Double.valueOf(libOptions.get("c")).doubleValue());
317                    } catch (NumberFormatException e) {
318                            throw new LibException("The liblinear cost (-c) value is not numerical value. ", e);
319                    }
320                    try {
321                            param.setEps(Double.valueOf(libOptions.get("e")).doubleValue());
322                    } catch (NumberFormatException e) {
323                            throw new LibException("The liblinear epsilon (-e) value is not numerical value. ", e);
324                    }
325                    return param;
326            }
327    }