001package org.maltparser.ml.lib; 002 003import java.io.BufferedOutputStream; 004import java.io.BufferedReader; 005import java.io.FileOutputStream; 006import java.io.IOException; 007import java.io.InputStream; 008import java.io.InputStreamReader; 009import java.io.ObjectOutputStream; 010import java.io.PrintStream; 011 012import java.util.LinkedHashMap; 013 014import org.maltparser.core.config.Configuration; 015import org.maltparser.core.exception.MaltChainedException; 016import org.maltparser.core.helper.NoPrintStream; 017import org.maltparser.ml.lib.FeatureList; 018import org.maltparser.ml.lib.MaltLibModel; 019import org.maltparser.ml.lib.MaltLibsvmModel; 020import org.maltparser.ml.lib.MaltFeatureNode; 021import org.maltparser.ml.lib.LibException; 022import org.maltparser.parser.guide.instance.InstanceModel; 023 024 025import libsvm.svm; 026import libsvm.svm_model; 027import libsvm.svm_node; 028import libsvm.svm_parameter; 029import libsvm.svm_problem; 030 031public class LibSvm extends Lib { 032 033 public LibSvm(InstanceModel owner, Integer learnerMode) throws MaltChainedException { 034 super(owner, learnerMode, "libsvm"); 035 if (learnerMode == CLASSIFY) { 036 model = (MaltLibModel)getConfigFileEntryObject(".moo"); 037 } 038 } 039 040 protected void trainInternal(LinkedHashMap<String, String> libOptions) throws MaltChainedException { 041 try { 042 final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"), libOptions); 043 final svm_parameter param = getLibSvmParameters(libOptions); 044 if(svm.svm_check_parameter(prob, param) != null) { 045 throw new LibException(svm.svm_check_parameter(prob, param)); 046 } 047 Configuration config = getConfiguration(); 048 049 if (config.isLoggerInfoEnabled()) { 050 config.logInfoMessage("Creating LIBSVM model "+getFile(".moo").getName()+"\n"); 051 } 052 final PrintStream out = System.out; 053 final PrintStream err = System.err; 054 System.setOut(NoPrintStream.NO_PRINTSTREAM); 055 System.setErr(NoPrintStream.NO_PRINTSTREAM); 056 svm_model model = svm.svm_train(prob, param); 057 System.setOut(err); 058 System.setOut(out); 059 ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath()))); 060 try{ 061 output.writeObject(new MaltLibsvmModel(model, prob)); 062 } finally { 063 output.close(); 064 } 065 boolean saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue(); 066 if (!saveInstanceFiles) { 067 getFile(".ins").delete(); 068 } 069 } catch (OutOfMemoryError e) { 070 throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e); 071 } catch (IllegalArgumentException e) { 072 throw new LibException("The LIBSVM learner was not able to redirect Standard Error stream. ", e); 073 } catch (SecurityException e) { 074 throw new LibException("The LIBSVM learner cannot remove the instance file. ", e); 075 } catch (IOException e) { 076 throw new LibException("The LIBSVM learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e); 077 } 078 } 079 080 protected void trainExternal(String pathExternalTrain, LinkedHashMap<String, String> libOptions) throws MaltChainedException { 081 try { 082 binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp")); 083 Configuration config = getConfiguration(); 084 085 if (config.isLoggerInfoEnabled()) { 086 config.logInfoMessage("Creating learner model (external) "+getFile(".mod").getName()); 087 } 088 final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"), libOptions); 089 final String[] params = getLibParamStringArray(libOptions); 090 String[] arrayCommands = new String[params.length+3]; 091 int i = 0; 092 arrayCommands[i++] = pathExternalTrain; 093 for (; i <= params.length; i++) { 094 arrayCommands[i] = params[i-1]; 095 } 096 arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath(); 097 arrayCommands[i++] = getFile(".mod").getAbsolutePath(); 098 099 if (verbosity == Verbostity.ALL) { 100 config.logInfoMessage('\n'); 101 } 102 final Process child = Runtime.getRuntime().exec(arrayCommands); 103 final InputStream in = child.getInputStream(); 104 final InputStream err = child.getErrorStream(); 105 int c; 106 while ((c = in.read()) != -1){ 107 if (verbosity == Verbostity.ALL) { 108 config.logInfoMessage((char)c); 109 } 110 } 111 while ((c = err.read()) != -1){ 112 if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) { 113 config.logInfoMessage((char)c); 114 } 115 } 116 if (child.waitFor() != 0) { 117 config.logErrorMessage(" FAILED ("+child.exitValue()+")"); 118 } 119 in.close(); 120 err.close(); 121 svm_model model = svm.svm_load_model(getFile(".mod").getAbsolutePath()); 122 MaltLibsvmModel xmodel = new MaltLibsvmModel(model, prob); 123 ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath()))); 124 try { 125 output.writeObject(xmodel); 126 } finally { 127 output.close(); 128 } 129 boolean saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue(); 130 if (!saveInstanceFiles) { 131 getFile(".ins").delete(); 132 getFile(".mod").delete(); 133 getFile(".ins.tmp").delete(); 134 } 135 if (config.isLoggerInfoEnabled()) { 136 config.logInfoMessage('\n'); 137 } 138 } catch (InterruptedException e) { 139 throw new LibException("Learner is interrupted. ", e); 140 } catch (IllegalArgumentException e) { 141 throw new LibException("The learner was not able to redirect Standard Error stream. ", e); 142 } catch (SecurityException e) { 143 throw new LibException("The learner cannot remove the instance file. ", e); 144 } catch (IOException e) { 145 throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e); 146 } catch (OutOfMemoryError e) { 147 throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e); 148 } 149 } 150 151 public void terminate() throws MaltChainedException { 152 super.terminate(); 153 } 154 155 public LinkedHashMap<String, String> getDefaultLibOptions() { 156 LinkedHashMap<String, String> libOptions = new LinkedHashMap<String, String>(); 157 libOptions.put("s", Integer.toString(svm_parameter.C_SVC)); 158 libOptions.put("t", Integer.toString(svm_parameter.POLY)); 159 libOptions.put("d", Integer.toString(2)); 160 libOptions.put("g", Double.toString(0.2)); 161 libOptions.put("r", Double.toString(0)); 162 libOptions.put("n", Double.toString(0.5)); 163 libOptions.put("m", Integer.toString(100)); 164 libOptions.put("c", Double.toString(1)); 165 libOptions.put("e", Double.toString(1.0)); 166 libOptions.put("p", Double.toString(0.1)); 167 libOptions.put("h", Integer.toString(1)); 168 libOptions.put("b", Integer.toString(0)); 169 return libOptions; 170 } 171 172 public String getAllowedLibOptionFlags() { 173 return "stdgrnmcepb"; 174 } 175 176 private svm_parameter getLibSvmParameters(LinkedHashMap<String, String> libOptions) throws MaltChainedException { 177 svm_parameter param = new svm_parameter(); 178 179 param.svm_type = Integer.parseInt(libOptions.get("s")); 180 param.kernel_type = Integer.parseInt(libOptions.get("t")); 181 param.degree = Integer.parseInt(libOptions.get("d")); 182 param.gamma = Double.valueOf(libOptions.get("g")).doubleValue(); 183 param.coef0 = Double.valueOf(libOptions.get("r")).doubleValue(); 184 param.nu = Double.valueOf(libOptions.get("n")).doubleValue(); 185 param.cache_size = Double.valueOf(libOptions.get("m")).doubleValue(); 186 param.C = Double.valueOf(libOptions.get("c")).doubleValue(); 187 param.eps = Double.valueOf(libOptions.get("e")).doubleValue(); 188 param.p = Double.valueOf(libOptions.get("p")).doubleValue(); 189 param.shrinking = Integer.parseInt(libOptions.get("h")); 190 param.probability = Integer.parseInt(libOptions.get("b")); 191 param.nr_weight = 0; 192 param.weight_label = new int[0]; 193 param.weight = new double[0]; 194 return param; 195 } 196 197 private svm_problem readProblem(InputStreamReader isr, LinkedHashMap<String, String> libOptions) throws MaltChainedException { 198 final svm_problem problem = new svm_problem(); 199 final svm_parameter param = getLibSvmParameters(libOptions); 200 final FeatureList featureList = new FeatureList(); 201 try { 202 final BufferedReader fp = new BufferedReader(isr); 203 204 problem.l = getNumberOfInstances(); 205 problem.x = new svm_node[problem.l][]; 206 problem.y = new double[problem.l]; 207 int i = 0; 208 209 while(true) { 210 String line = fp.readLine(); 211 if(line == null) break; 212 int y = binariesInstance(line, featureList); 213 if (y == -1) { 214 continue; 215 } 216 try { 217 problem.y[i] = y; 218 problem.x[i] = new svm_node[featureList.size()]; 219 int p = 0; 220 for (int k=0; k < featureList.size(); k++) { 221 MaltFeatureNode x = featureList.get(k); 222 problem.x[i][p] = new svm_node(); 223 problem.x[i][p].value = x.getValue(); 224 problem.x[i][p].index = x.getIndex(); 225 p++; 226 } 227 i++; 228 } catch (ArrayIndexOutOfBoundsException e) { 229 throw new LibException("Couldn't read libsvm problem from the instance file. ", e); 230 } 231 } 232 fp.close(); 233 if (param.gamma == 0) { 234 param.gamma = 1.0/featureMap.getFeatureCounter(); 235 } 236 } catch (IOException e) { 237 throw new LibException("Couldn't read libsvm problem from the instance file. ", e); 238 } 239 return problem; 240 } 241}