001package org.maltparser.ml.lib;
002
003import java.io.BufferedOutputStream;
004import java.io.BufferedReader;
005import java.io.FileOutputStream;
006import java.io.IOException;
007import java.io.InputStream;
008import java.io.InputStreamReader;
009import java.io.ObjectOutputStream;
010import java.io.PrintStream;
011
012import java.util.LinkedHashMap;
013
014import org.maltparser.core.config.Configuration;
015import org.maltparser.core.exception.MaltChainedException;
016import org.maltparser.core.helper.NoPrintStream;
017import org.maltparser.ml.lib.FeatureList;
018import org.maltparser.ml.lib.MaltLibModel;
019import org.maltparser.ml.lib.MaltLibsvmModel;
020import org.maltparser.ml.lib.MaltFeatureNode;
021import org.maltparser.ml.lib.LibException;
022import org.maltparser.parser.guide.instance.InstanceModel;
023
024
025import libsvm.svm;
026import libsvm.svm_model;
027import libsvm.svm_node;
028import libsvm.svm_parameter;
029import libsvm.svm_problem;
030
031public class LibSvm extends Lib {
032
033        public LibSvm(InstanceModel owner, Integer learnerMode) throws MaltChainedException {
034                super(owner, learnerMode, "libsvm");
035                if (learnerMode == CLASSIFY) {
036                        model = (MaltLibModel)getConfigFileEntryObject(".moo");
037                }
038        }
039        
040        protected void trainInternal(LinkedHashMap<String, String> libOptions) throws MaltChainedException {
041                try {
042                        final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"), libOptions);
043                        final svm_parameter param = getLibSvmParameters(libOptions);
044                        if(svm.svm_check_parameter(prob, param) != null) {
045                                throw new LibException(svm.svm_check_parameter(prob, param));
046                        }
047                        Configuration config = getConfiguration();
048                        
049                        if (config.isLoggerInfoEnabled()) {
050                                config.logInfoMessage("Creating LIBSVM model "+getFile(".moo").getName()+"\n");
051                        }
052                        final PrintStream out = System.out;
053                        final PrintStream err = System.err;
054                        System.setOut(NoPrintStream.NO_PRINTSTREAM);
055                        System.setErr(NoPrintStream.NO_PRINTSTREAM);
056                        svm_model model = svm.svm_train(prob, param);
057                        System.setOut(err);
058                        System.setOut(out);
059                    ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
060                try{
061                  output.writeObject(new MaltLibsvmModel(model, prob));
062                } finally {
063                  output.close();
064                }
065                boolean saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();
066                        if (!saveInstanceFiles) {
067                                getFile(".ins").delete();
068                        }
069                } catch (OutOfMemoryError e) {
070                        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
071                } catch (IllegalArgumentException e) {
072                        throw new LibException("The LIBSVM learner was not able to redirect Standard Error stream. ", e);
073                } catch (SecurityException e) {
074                        throw new LibException("The LIBSVM learner cannot remove the instance file. ", e);
075                } catch (IOException e) {
076                        throw new LibException("The LIBSVM learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
077                }
078        }
079        
080        protected void trainExternal(String pathExternalTrain, LinkedHashMap<String, String> libOptions) throws MaltChainedException {
081                try {           
082                        binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp"));
083                        Configuration config = getConfiguration();
084                        
085                        if (config.isLoggerInfoEnabled()) {
086                                config.logInfoMessage("Creating learner model (external) "+getFile(".mod").getName());
087                        }
088                        final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"), libOptions);
089                        final String[] params = getLibParamStringArray(libOptions);
090                        String[] arrayCommands = new String[params.length+3];
091                        int i = 0;
092                        arrayCommands[i++] = pathExternalTrain;
093                        for (; i <= params.length; i++) {
094                                arrayCommands[i] = params[i-1];
095                        }
096                        arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath();
097                        arrayCommands[i++] = getFile(".mod").getAbsolutePath();
098                        
099                if (verbosity == Verbostity.ALL) {
100                        config.logInfoMessage('\n');
101                }
102                        final Process child = Runtime.getRuntime().exec(arrayCommands);
103                final InputStream in = child.getInputStream();
104                final InputStream err = child.getErrorStream();
105                int c;
106                while ((c = in.read()) != -1){
107                        if (verbosity == Verbostity.ALL) {
108                                config.logInfoMessage((char)c);
109                        }
110                }
111                while ((c = err.read()) != -1){
112                        if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) {
113                                config.logInfoMessage((char)c);
114                        }
115                }
116            if (child.waitFor() != 0) {
117                config.logErrorMessage(" FAILED ("+child.exitValue()+")");
118            }
119                in.close();
120                err.close();
121                svm_model model = svm.svm_load_model(getFile(".mod").getAbsolutePath());
122                MaltLibsvmModel xmodel = new MaltLibsvmModel(model, prob);
123                ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
124                try {
125                        output.writeObject(xmodel);
126                    } finally {
127                        output.close();
128                    }
129                boolean saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();
130                if (!saveInstanceFiles) {
131                                getFile(".ins").delete();
132                                getFile(".mod").delete();
133                                getFile(".ins.tmp").delete();
134                }
135                        if (config.isLoggerInfoEnabled()) {
136                                config.logInfoMessage('\n');
137                        }
138                } catch (InterruptedException e) {
139                         throw new LibException("Learner is interrupted. ", e);
140                } catch (IllegalArgumentException e) {
141                        throw new LibException("The learner was not able to redirect Standard Error stream. ", e);
142                } catch (SecurityException e) {
143                        throw new LibException("The learner cannot remove the instance file. ", e);
144                } catch (IOException e) {
145                        throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
146                } catch (OutOfMemoryError e) {
147                        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
148                }
149        }
150        
151        public void terminate() throws MaltChainedException { 
152                super.terminate();
153        }
154        
155        public  LinkedHashMap<String, String> getDefaultLibOptions() {
156                LinkedHashMap<String, String> libOptions = new LinkedHashMap<String, String>();
157                libOptions.put("s", Integer.toString(svm_parameter.C_SVC));
158                libOptions.put("t", Integer.toString(svm_parameter.POLY));
159                libOptions.put("d", Integer.toString(2));
160                libOptions.put("g", Double.toString(0.2));
161                libOptions.put("r", Double.toString(0));
162                libOptions.put("n", Double.toString(0.5));
163                libOptions.put("m", Integer.toString(100));
164                libOptions.put("c", Double.toString(1));
165                libOptions.put("e", Double.toString(1.0));
166                libOptions.put("p", Double.toString(0.1));
167                libOptions.put("h", Integer.toString(1));
168                libOptions.put("b", Integer.toString(0));
169                return libOptions;
170        }
171        
172        public String getAllowedLibOptionFlags() {
173                return "stdgrnmcepb";
174        }
175        
176        private svm_parameter getLibSvmParameters(LinkedHashMap<String, String> libOptions) throws MaltChainedException {
177                svm_parameter param = new svm_parameter();
178        
179                param.svm_type = Integer.parseInt(libOptions.get("s"));
180                param.kernel_type = Integer.parseInt(libOptions.get("t"));
181                param.degree = Integer.parseInt(libOptions.get("d"));
182                param.gamma = Double.valueOf(libOptions.get("g")).doubleValue();
183                param.coef0 = Double.valueOf(libOptions.get("r")).doubleValue();
184                param.nu = Double.valueOf(libOptions.get("n")).doubleValue();
185                param.cache_size = Double.valueOf(libOptions.get("m")).doubleValue();
186                param.C = Double.valueOf(libOptions.get("c")).doubleValue();
187                param.eps = Double.valueOf(libOptions.get("e")).doubleValue();
188                param.p = Double.valueOf(libOptions.get("p")).doubleValue();
189                param.shrinking = Integer.parseInt(libOptions.get("h"));
190                param.probability = Integer.parseInt(libOptions.get("b"));
191                param.nr_weight = 0;
192                param.weight_label = new int[0];
193                param.weight = new double[0];
194                return param;
195        }
196        
197        private svm_problem readProblem(InputStreamReader isr, LinkedHashMap<String, String> libOptions) throws MaltChainedException {
198                final svm_problem problem = new svm_problem();
199                final svm_parameter param = getLibSvmParameters(libOptions);
200                final FeatureList featureList = new FeatureList();
201                try {
202                        final BufferedReader fp = new BufferedReader(isr);
203                        
204                        problem.l = getNumberOfInstances();
205                        problem.x = new svm_node[problem.l][];
206                        problem.y = new double[problem.l];
207                        int i = 0;
208                        
209                        while(true) {
210                                String line = fp.readLine();
211                                if(line == null) break;
212                                int y = binariesInstance(line, featureList);
213                                if (y == -1) {
214                                        continue;
215                                }
216                                try {
217                                        problem.y[i] = y;
218                                        problem.x[i] = new svm_node[featureList.size()];
219                                        int p = 0;
220                                for (int k=0; k < featureList.size(); k++) {
221                                        MaltFeatureNode x = featureList.get(k);
222                                                problem.x[i][p] = new svm_node();
223                                                problem.x[i][p].value = x.getValue();
224                                                problem.x[i][p].index = x.getIndex();          
225                                                p++;
226                                        }
227                                        i++;
228                                } catch (ArrayIndexOutOfBoundsException e) {
229                                        throw new LibException("Couldn't read libsvm problem from the instance file. ", e);
230                                }
231                        }
232                        fp.close();     
233                        if (param.gamma == 0) {
234                                param.gamma = 1.0/featureMap.getFeatureCounter();
235                        }
236                } catch (IOException e) {
237                        throw new LibException("Couldn't read libsvm problem from the instance file. ", e);
238                }
239                return problem;
240        }
241}