001    package org.maltparser.ml.lib;
002    
003    import java.io.BufferedOutputStream;
004    import java.io.BufferedReader;
005    import java.io.FileOutputStream;
006    import java.io.IOException;
007    import java.io.InputStream;
008    import java.io.InputStreamReader;
009    import java.io.ObjectInputStream;
010    import java.io.ObjectOutputStream;
011    import java.io.PrintStream;
012    import java.util.LinkedHashMap;
013    
014    import liblinear.FeatureNode;
015    import liblinear.Linear;
016    import liblinear.Model;
017    import liblinear.Parameter;
018    import liblinear.Problem;
019    import liblinear.SolverType;
020    
021    import org.maltparser.core.exception.MaltChainedException;
022    import org.maltparser.core.feature.FeatureVector;
023    import org.maltparser.core.helper.NoPrintStream;
024    import org.maltparser.core.helper.Util;
025    import org.maltparser.parser.guide.instance.InstanceModel;
026    
027    public class LibLinear extends Lib {
028            
029            public LibLinear(InstanceModel owner, Integer learnerMode) throws MaltChainedException {
030                    super(owner, learnerMode, "liblinear");
031                    if (learnerMode == CLASSIFY) {
032                            try {
033                                ObjectInputStream input = new ObjectInputStream(getInputStreamFromConfigFileEntry(".moo"));
034                                try {
035                                    model = (MaltLibModel)input.readObject();
036                                } finally {
037                                    input.close();
038                                }
039                            } catch (ClassNotFoundException e) {
040                                    throw new LibException("Couldn't load the liblinear model", e);
041                            } catch (Exception e) {
042                                    throw new LibException("Couldn't load the liblinear model", e);
043                            }
044                    }
045    
046            }
047            
048            protected void trainInternal(FeatureVector featureVector) throws MaltChainedException {
049                    try {
050                            if (configLogger.isInfoEnabled()) {
051                                    configLogger.info("Creating Liblinear model "+getFile(".moo").getName()+"\n");
052                            }
053                            Problem problem = readProblem(getInstanceInputStreamReader(".ins"));
054                            final PrintStream out = System.out;
055                            final PrintStream err = System.err;
056                            System.setOut(NoPrintStream.NO_PRINTSTREAM);
057                            System.setErr(NoPrintStream.NO_PRINTSTREAM);
058                            Parameter parameter = getLiblinearParameters();
059                            Model model = Linear.train(problem, parameter);
060                            System.setOut(err);
061                            System.setOut(out);
062    //                      System.out.println(" model.getNrFeature():" +  model.getNrFeature());
063    //                      System.out.println(" model.getFeatureWeights().length:" +  model.getFeatureWeights().length);
064                            double[][] wmatrix = convert(model.getFeatureWeights(), model.getNrClass(), model.getNrFeature());
065                            MaltLiblinearModel xmodel = new MaltLiblinearModel(model.getLabels(), model.getNrClass(), wmatrix.length, wmatrix, parameter.getSolverType());
066                        ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
067                    try{
068                      output.writeObject(xmodel);
069                    } finally {
070                      output.close();
071                    }
072                            if (!saveInstanceFiles) {
073                                    getFile(".ins").delete();
074                            }
075                    } catch (OutOfMemoryError e) {
076                            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
077                    } catch (IllegalArgumentException e) {
078                            throw new LibException("The Liblinear learner was not able to redirect Standard Error stream. ", e);
079                    } catch (SecurityException e) {
080                            throw new LibException("The Liblinear learner cannot remove the instance file. ", e);
081                    } catch (IOException e) {
082                            throw new LibException("The Liblinear learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
083                    }
084            }
085            
086        private double[][] convert(double[] w, int nr_class, int nr_feature) {
087            double[][] wmatrix = new double[nr_feature][];
088            boolean reuse = false;
089            int ne = 0;
090            int nr = 0;
091            int no = 0;
092            int n = 0;
093    
094            Long[] reverseMap = featureMap.reverseMap();
095            for (int i = 0; i < nr_feature; i++) {
096                    reuse = false;
097                    int k = nr_class;
098                    for (int t = i * nr_class; (t + (k - 1)) >= t; k--) {
099                            if (w[t + k - 1] != 0.0) {
100                                    break;
101                            }
102                    }
103                    double[] copy = new double[k];
104                System.arraycopy(w, i * nr_class, copy, 0,k);
105                if (eliminate(copy)) {
106                    ne++;
107                    featureMap.removeIndex(reverseMap[i + 1]);
108                    featureMap.decrementfeatureCounter();
109                    reverseMap[i + 1] = null;
110                    wmatrix[i] = null;
111                } else {
112                    featureMap.setIndex(reverseMap[i + 1], i + 1 - ne);
113                        for (int j = 0; j < i; j++) {
114                            if (Util.equals(copy, wmatrix[j])) {
115                                    wmatrix[i] = wmatrix[j];
116                                    reuse = true;
117                                    nr++;
118                                    break;
119                            }
120                        }
121                        if (reuse == false) {
122                            no++;
123                            wmatrix[i] = copy;
124                        }
125                }
126                n++;
127            }
128            double[][] wmatrix_reduced = new double[nr_feature-ne][];
129            for (int i = 0, j = 0; i < wmatrix.length; i++) {
130                    if (wmatrix[i] != null) {
131                            wmatrix_reduced[j++] = wmatrix[i];
132                    }
133            }
134    //        System.out.println("NE:"+ne);
135    //        System.out.println("NR:"+nr);
136    //        System.out.println("NO:"+no);
137    //        System.out.println("N :"+n);
138            return wmatrix_reduced;
139        }
140        
141        public static boolean eliminate(double[] a) {
142            if (a.length == 0) {
143                    return true;
144            }
145            for (int i = 1; i < a.length; i++) {
146                    if (a[i] != a[i-1]) {
147                            return false;
148                    }
149            }
150            return true;
151        }
152        
153            protected void trainExternal(FeatureVector featureVector) throws MaltChainedException {
154                    try {           
155                            
156                            if (configLogger.isInfoEnabled()) {
157                                    owner.getGuide().getConfiguration().getConfigLogger().info("Creating liblinear model (external) "+getFile(".mod").getName());
158                            }
159                            binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp"));
160                            final String[] params = getLibParamStringArray();
161                            String[] arrayCommands = new String[params.length+3];
162                            int i = 0;
163                            arrayCommands[i++] = pathExternalTrain;
164                            for (; i <= params.length; i++) {
165                                    arrayCommands[i] = params[i-1];
166                            }
167                            arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath();
168                            arrayCommands[i++] = getFile(".mod").getAbsolutePath();
169                            
170                    if (verbosity == Verbostity.ALL) {
171                            owner.getGuide().getConfiguration().getConfigLogger().info('\n');
172                    }
173                            final Process child = Runtime.getRuntime().exec(arrayCommands);
174                    final InputStream in = child.getInputStream();
175                    final InputStream err = child.getErrorStream();
176                    int c;
177                    while ((c = in.read()) != -1){
178                            if (verbosity == Verbostity.ALL) {
179                                    owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
180                            }
181                    }
182                    while ((c = err.read()) != -1){
183                            if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) {
184                                    owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
185                            }
186                    }
187                if (child.waitFor() != 0) {
188                    owner.getGuide().getConfiguration().getConfigLogger().info(" FAILED ("+child.exitValue()+")");
189                }
190                    in.close();
191                    err.close();
192                            if (configLogger.isInfoEnabled()) {
193                                    configLogger.info("\nSaving Liblinear model "+getFile(".moo").getName()+"\n");
194                            }
195                            MaltLiblinearModel xmodel = new MaltLiblinearModel(getFile(".mod"));
196                        ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
197                    try{
198                      output.writeObject(xmodel);
199                    } finally {
200                      output.close();
201                    }
202                    if (!saveInstanceFiles) {
203                                    getFile(".ins").delete();
204                                    getFile(".mod").delete();
205                                    getFile(".ins.tmp").delete();
206                    }
207                    if (configLogger.isInfoEnabled()) {
208                            configLogger.info('\n');
209                    }
210                    } catch (InterruptedException e) {
211                             throw new LibException("Learner is interrupted. ", e);
212                    } catch (IllegalArgumentException e) {
213                            throw new LibException("The learner was not able to redirect Standard Error stream. ", e);
214                    } catch (SecurityException e) {
215                            throw new LibException("The learner cannot remove the instance file. ", e);
216                    } catch (IOException e) {
217                            throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
218                    } catch (OutOfMemoryError e) {
219                            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
220                    }
221            }
222            
223            public void terminate() throws MaltChainedException { 
224                    super.terminate();
225            }
226    
227            public void initLibOptions() {
228                    libOptions = new LinkedHashMap<String, String>();
229                    libOptions.put("s", "4"); // type = SolverType.L2LOSS_SVM_DUAL (default)
230                    libOptions.put("c", "0.1"); // cost = 1 (default)
231                    libOptions.put("e", "0.1"); // epsilon = 0.1 (default)
232                    libOptions.put("B", "-1"); // bias = -1 (default)
233            }
234            
235            public void initAllowedLibOptionFlags() {
236                    allowedLibOptionFlags = "sceB";
237            }
238            
239            private Problem readProblem(InputStreamReader isr) throws MaltChainedException {
240                    Problem problem = new Problem();
241                    final FeatureList featureList = new FeatureList();
242                    
243                    try {
244                            final BufferedReader fp = new BufferedReader(isr);
245                            
246                            problem.bias = -1;
247                            problem.l = getNumberOfInstances();
248                            problem.x = new FeatureNode[problem.l][];
249                            problem.y = new int[problem.l];
250                            int i = 0;
251    
252                            while(true) {
253                                    String line = fp.readLine();
254                                    if(line == null) break;
255                                    int y = binariesInstance(line, featureList);
256                                    if (y == -1) {
257                                            continue;
258                                    }
259                                    try {
260                                            problem.y[i] = y;
261                                            problem.x[i] = new FeatureNode[featureList.size()];
262                                            int p = 0;
263                                    for (int k=0; k < featureList.size(); k++) {
264                                            MaltFeatureNode x = featureList.get(k);
265                                                    problem.x[i][p++] = new FeatureNode(x.getIndex(), x.getValue());
266                                            }
267                                            i++;
268                                    } catch (ArrayIndexOutOfBoundsException e) {
269                                            throw new LibException("Couldn't read liblinear problem from the instance file. ", e);
270                                    }
271    
272                            }
273                            fp.close();
274                            problem.n = featureMap.size();
275                    } catch (IOException e) {
276                            throw new LibException("Cannot read from the instance file. ", e);
277                    }
278                    return problem;
279            }
280            
281            private Parameter getLiblinearParameters() throws MaltChainedException {
282                    Parameter param = new Parameter(SolverType.MCSVM_CS, 0.1, 0.1);
283                    String type = libOptions.get("s");
284                    
285                    if (type.equals("0")) {
286                            param.setSolverType(SolverType.L2R_LR);
287                    } else if (type.equals("1")) {
288                            param.setSolverType(SolverType.L2R_L2LOSS_SVC_DUAL);
289                    } else if (type.equals("2")) {
290                            param.setSolverType(SolverType.L2R_L2LOSS_SVC);
291                    } else if (type.equals("3")) {
292                            param.setSolverType(SolverType.L2R_L1LOSS_SVC_DUAL);
293                    } else if (type.equals("4")) {
294                            param.setSolverType(SolverType.MCSVM_CS);
295                    } else if (type.equals("5")) {
296                            param.setSolverType(SolverType.L1R_L2LOSS_SVC); 
297                    } else if (type.equals("6")) {
298                            param.setSolverType(SolverType.L1R_LR); 
299                    } else if (type.equals("7")) {
300                            param.setSolverType(SolverType.L2R_LR_DUAL);    
301                    } else {
302                            throw new LibException("The liblinear type (-s) is not an integer value between 0 and 4. ");
303                    }
304                    try {
305                            param.setC(Double.valueOf(libOptions.get("c")).doubleValue());
306                    } catch (NumberFormatException e) {
307                            throw new LibException("The liblinear cost (-c) value is not numerical value. ", e);
308                    }
309                    try {
310                            param.setEps(Double.valueOf(libOptions.get("e")).doubleValue());
311                    } catch (NumberFormatException e) {
312                            throw new LibException("The liblinear epsilon (-e) value is not numerical value. ", e);
313                    }
314                    return param;
315            }
316    }