001    package org.maltparser.ml.lib;
002    
003    import java.io.BufferedOutputStream;
004    import java.io.BufferedReader;
005    import java.io.FileOutputStream;
006    import java.io.IOException;
007    import java.io.InputStream;
008    import java.io.InputStreamReader;
009    import java.io.ObjectInputStream;
010    import java.io.ObjectOutputStream;
011    import java.io.PrintStream;
012    
013    import java.util.LinkedHashMap;
014    
015    import org.maltparser.core.exception.MaltChainedException;
016    import org.maltparser.core.feature.FeatureVector;
017    import org.maltparser.core.helper.NoPrintStream;
018    import org.maltparser.parser.guide.instance.InstanceModel;
019    
020    
021    import libsvm.svm;
022    import libsvm.svm_model;
023    import libsvm.svm_node;
024    import libsvm.svm_parameter;
025    import libsvm.svm_problem;
026    
027    public class LibSvm extends Lib {
028    
029            public LibSvm(InstanceModel owner, Integer learnerMode) throws MaltChainedException {
030                    super(owner, learnerMode, "libsvm");
031                    if (learnerMode == CLASSIFY) {
032                            try {
033                                ObjectInputStream input = new ObjectInputStream(getInputStreamFromConfigFileEntry(".moo"));
034                                try {
035                                    model = (MaltLibModel)input.readObject();
036                                } finally {
037                                    input.close();
038                                }
039                            } catch (ClassNotFoundException e) {
040                                    throw new LibException("Couldn't load the liblinear model", e);
041                            } catch (Exception e) {
042                                    throw new LibException("Couldn't load the liblinear model", e);
043                            }
044                    }
045            }
046            
047            protected void trainInternal(FeatureVector featureVector) throws MaltChainedException {
048                    try {
049                            final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"));
050                            final svm_parameter param = getLibSvmParameters();
051                            if(svm.svm_check_parameter(prob, param) != null) {
052                                    throw new LibException(svm.svm_check_parameter(prob, param));
053                            }
054                            owner.getGuide().getConfiguration().getConfigLogger().info("Creating LIBSVM model "+getFile(".moo").getName()+"\n");
055                            final PrintStream out = System.out;
056                            final PrintStream err = System.err;
057                            System.setOut(NoPrintStream.NO_PRINTSTREAM);
058                            System.setErr(NoPrintStream.NO_PRINTSTREAM);
059                            svm_model model = svm.svm_train(prob, param);
060                            System.setOut(err);
061                            System.setOut(out);
062                        ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
063                    try{
064                      output.writeObject(new MaltLibsvmModel(model, prob));
065                    } finally {
066                      output.close();
067                    }
068                            if (!saveInstanceFiles) {
069                                    getFile(".ins").delete();
070                            }
071                    } catch (OutOfMemoryError e) {
072                            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
073                    } catch (IllegalArgumentException e) {
074                            throw new LibException("The LIBSVM learner was not able to redirect Standard Error stream. ", e);
075                    } catch (SecurityException e) {
076                            throw new LibException("The LIBSVM learner cannot remove the instance file. ", e);
077                    } catch (IOException e) {
078                            throw new LibException("The LIBSVM learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
079                    }
080            }
081            
082            protected void trainExternal(FeatureVector featureVector) throws MaltChainedException {
083                    try {           
084                            binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp"));
085                            owner.getGuide().getConfiguration().getConfigLogger().info("Creating learner model (external) "+getFile(".mod").getName());
086                            final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"));
087                            final String[] params = getLibParamStringArray();
088                            String[] arrayCommands = new String[params.length+3];
089                            int i = 0;
090                            arrayCommands[i++] = pathExternalTrain;
091                            for (; i <= params.length; i++) {
092                                    arrayCommands[i] = params[i-1];
093                            }
094                            arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath();
095                            arrayCommands[i++] = getFile(".mod").getAbsolutePath();
096                            
097                    if (verbosity == Verbostity.ALL) {
098                            owner.getGuide().getConfiguration().getConfigLogger().info('\n');
099                    }
100                            final Process child = Runtime.getRuntime().exec(arrayCommands);
101                    final InputStream in = child.getInputStream();
102                    final InputStream err = child.getErrorStream();
103                    int c;
104                    while ((c = in.read()) != -1){
105                            if (verbosity == Verbostity.ALL) {
106                                    owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
107                            }
108                    }
109                    while ((c = err.read()) != -1){
110                            if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) {
111                                    owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
112                            }
113                    }
114                if (child.waitFor() != 0) {
115                    owner.getGuide().getConfiguration().getConfigLogger().info(" FAILED ("+child.exitValue()+")");
116                }
117                    in.close();
118                    err.close();
119                    svm_model model = svm.svm_load_model(getFile(".mod").getAbsolutePath());
120                    MaltLibsvmModel xmodel = new MaltLibsvmModel(model, prob);
121                    ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
122                    try {
123                            output.writeObject(xmodel);
124                        } finally {
125                            output.close();
126                        }
127                    if (!saveInstanceFiles) {
128                                    getFile(".ins").delete();
129                                    getFile(".mod").delete();
130                                    getFile(".ins.tmp").delete();
131                    }
132                    owner.getGuide().getConfiguration().getConfigLogger().info('\n');
133                    } catch (InterruptedException e) {
134                             throw new LibException("Learner is interrupted. ", e);
135                    } catch (IllegalArgumentException e) {
136                            throw new LibException("The learner was not able to redirect Standard Error stream. ", e);
137                    } catch (SecurityException e) {
138                            throw new LibException("The learner cannot remove the instance file. ", e);
139                    } catch (IOException e) {
140                            throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
141                    } catch (OutOfMemoryError e) {
142                            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
143                    }
144            }
145            
146            public void terminate() throws MaltChainedException { 
147                    super.terminate();
148            }
149            
150            public void initLibOptions() {
151                    libOptions = new LinkedHashMap<String, String>();
152                    libOptions.put("s", Integer.toString(svm_parameter.C_SVC));
153                    libOptions.put("t", Integer.toString(svm_parameter.POLY));
154                    libOptions.put("d", Integer.toString(2));
155                    libOptions.put("g", Double.toString(0.2));
156                    libOptions.put("r", Double.toString(0));
157                    libOptions.put("n", Double.toString(0.5));
158                    libOptions.put("m", Integer.toString(100));
159                    libOptions.put("c", Double.toString(1));
160                    libOptions.put("e", Double.toString(1.0));
161                    libOptions.put("p", Double.toString(0.1));
162                    libOptions.put("h", Integer.toString(1));
163                    libOptions.put("b", Integer.toString(0));
164            }
165            
166            public void initAllowedLibOptionFlags() {
167                    allowedLibOptionFlags = "stdgrnmcepb";
168            }
169            
170            private svm_parameter getLibSvmParameters() throws MaltChainedException {
171                    svm_parameter param = new svm_parameter();
172            
173                    param.svm_type = Integer.parseInt(libOptions.get("s"));
174                    param.kernel_type = Integer.parseInt(libOptions.get("t"));
175                    param.degree = Integer.parseInt(libOptions.get("d"));
176                    param.gamma = Double.valueOf(libOptions.get("g")).doubleValue();
177                    param.coef0 = Double.valueOf(libOptions.get("r")).doubleValue();
178                    param.nu = Double.valueOf(libOptions.get("n")).doubleValue();
179                    param.cache_size = Double.valueOf(libOptions.get("m")).doubleValue();
180                    param.C = Double.valueOf(libOptions.get("c")).doubleValue();
181                    param.eps = Double.valueOf(libOptions.get("e")).doubleValue();
182                    param.p = Double.valueOf(libOptions.get("p")).doubleValue();
183                    param.shrinking = Integer.parseInt(libOptions.get("h"));
184                    param.probability = Integer.parseInt(libOptions.get("b"));
185                    param.nr_weight = 0;
186                    param.weight_label = new int[0];
187                    param.weight = new double[0];
188                    return param;
189            }
190            
191            private svm_problem readProblem(InputStreamReader isr) throws MaltChainedException {
192                    final svm_problem problem = new svm_problem();
193                    final svm_parameter param = getLibSvmParameters();
194                    final FeatureList featureList = new FeatureList();
195                    try {
196                            final BufferedReader fp = new BufferedReader(isr);
197                            
198                            problem.l = getNumberOfInstances();
199                            problem.x = new svm_node[problem.l][];
200                            problem.y = new double[problem.l];
201                            int i = 0;
202                            
203                            while(true) {
204                                    String line = fp.readLine();
205                                    if(line == null) break;
206                                    int y = binariesInstance(line, featureList);
207                                    if (y == -1) {
208                                            continue;
209                                    }
210                                    try {
211                                            problem.y[i] = y;
212                                            problem.x[i] = new svm_node[featureList.size()];
213                                            int p = 0;
214                                    for (int k=0; k < featureList.size(); k++) {
215                                            MaltFeatureNode x = featureList.get(k);
216                                                    problem.x[i][p] = new svm_node();
217                                                    problem.x[i][p].value = x.getValue();
218                                                    problem.x[i][p].index = x.getIndex();          
219                                                    p++;
220                                            }
221                                            i++;
222                                    } catch (ArrayIndexOutOfBoundsException e) {
223                                            throw new LibException("Couldn't read libsvm problem from the instance file. ", e);
224                                    }
225                            }
226                            fp.close();     
227                            if (param.gamma == 0) {
228                                    param.gamma = 1.0/featureMap.getFeatureCounter();
229                            }
230                    } catch (IOException e) {
231                            throw new LibException("Couldn't read libsvm problem from the instance file. ", e);
232                    }
233                    return problem;
234            }
235    }