001 package org.maltparser.ml.lib; 002 003 import java.io.BufferedOutputStream; 004 import java.io.BufferedReader; 005 import java.io.FileOutputStream; 006 import java.io.IOException; 007 import java.io.InputStream; 008 import java.io.InputStreamReader; 009 import java.io.ObjectInputStream; 010 import java.io.ObjectOutputStream; 011 import java.io.PrintStream; 012 013 import java.util.LinkedHashMap; 014 015 import org.maltparser.core.exception.MaltChainedException; 016 import org.maltparser.core.feature.FeatureVector; 017 import org.maltparser.core.helper.NoPrintStream; 018 import org.maltparser.parser.guide.instance.InstanceModel; 019 020 021 import libsvm.svm; 022 import libsvm.svm_model; 023 import libsvm.svm_node; 024 import libsvm.svm_parameter; 025 import libsvm.svm_problem; 026 027 public class LibSvm extends Lib { 028 029 public LibSvm(InstanceModel owner, Integer learnerMode) throws MaltChainedException { 030 super(owner, learnerMode, "libsvm"); 031 if (learnerMode == CLASSIFY) { 032 try { 033 ObjectInputStream input = new ObjectInputStream(getInputStreamFromConfigFileEntry(".moo")); 034 try { 035 model = (MaltLibModel)input.readObject(); 036 } finally { 037 input.close(); 038 } 039 } catch (ClassNotFoundException e) { 040 throw new LibException("Couldn't load the liblinear model", e); 041 } catch (Exception e) { 042 throw new LibException("Couldn't load the liblinear model", e); 043 } 044 } 045 } 046 047 protected void trainInternal(FeatureVector featureVector) throws MaltChainedException { 048 try { 049 final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins")); 050 final svm_parameter param = getLibSvmParameters(); 051 if(svm.svm_check_parameter(prob, param) != null) { 052 throw new LibException(svm.svm_check_parameter(prob, param)); 053 } 054 owner.getGuide().getConfiguration().getConfigLogger().info("Creating LIBSVM model "+getFile(".moo").getName()+"\n"); 055 final PrintStream out = System.out; 056 final PrintStream err = System.err; 057 System.setOut(NoPrintStream.NO_PRINTSTREAM); 058 System.setErr(NoPrintStream.NO_PRINTSTREAM); 059 svm_model model = svm.svm_train(prob, param); 060 System.setOut(err); 061 System.setOut(out); 062 ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath()))); 063 try{ 064 output.writeObject(new MaltLibsvmModel(model, prob)); 065 } finally { 066 output.close(); 067 } 068 if (!saveInstanceFiles) { 069 getFile(".ins").delete(); 070 } 071 } catch (OutOfMemoryError e) { 072 throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e); 073 } catch (IllegalArgumentException e) { 074 throw new LibException("The LIBSVM learner was not able to redirect Standard Error stream. ", e); 075 } catch (SecurityException e) { 076 throw new LibException("The LIBSVM learner cannot remove the instance file. ", e); 077 } catch (IOException e) { 078 throw new LibException("The LIBSVM learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e); 079 } 080 } 081 082 protected void trainExternal(FeatureVector featureVector) throws MaltChainedException { 083 try { 084 binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp")); 085 owner.getGuide().getConfiguration().getConfigLogger().info("Creating learner model (external) "+getFile(".mod").getName()); 086 final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins")); 087 final String[] params = getLibParamStringArray(); 088 String[] arrayCommands = new String[params.length+3]; 089 int i = 0; 090 arrayCommands[i++] = pathExternalTrain; 091 for (; i <= params.length; i++) { 092 arrayCommands[i] = params[i-1]; 093 } 094 arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath(); 095 arrayCommands[i++] = getFile(".mod").getAbsolutePath(); 096 097 if (verbosity == Verbostity.ALL) { 098 owner.getGuide().getConfiguration().getConfigLogger().info('\n'); 099 } 100 final Process child = Runtime.getRuntime().exec(arrayCommands); 101 final InputStream in = child.getInputStream(); 102 final InputStream err = child.getErrorStream(); 103 int c; 104 while ((c = in.read()) != -1){ 105 if (verbosity == Verbostity.ALL) { 106 owner.getGuide().getConfiguration().getConfigLogger().info((char)c); 107 } 108 } 109 while ((c = err.read()) != -1){ 110 if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) { 111 owner.getGuide().getConfiguration().getConfigLogger().info((char)c); 112 } 113 } 114 if (child.waitFor() != 0) { 115 owner.getGuide().getConfiguration().getConfigLogger().info(" FAILED ("+child.exitValue()+")"); 116 } 117 in.close(); 118 err.close(); 119 svm_model model = svm.svm_load_model(getFile(".mod").getAbsolutePath()); 120 MaltLibsvmModel xmodel = new MaltLibsvmModel(model, prob); 121 ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath()))); 122 try { 123 output.writeObject(xmodel); 124 } finally { 125 output.close(); 126 } 127 if (!saveInstanceFiles) { 128 getFile(".ins").delete(); 129 getFile(".mod").delete(); 130 getFile(".ins.tmp").delete(); 131 } 132 owner.getGuide().getConfiguration().getConfigLogger().info('\n'); 133 } catch (InterruptedException e) { 134 throw new LibException("Learner is interrupted. ", e); 135 } catch (IllegalArgumentException e) { 136 throw new LibException("The learner was not able to redirect Standard Error stream. ", e); 137 } catch (SecurityException e) { 138 throw new LibException("The learner cannot remove the instance file. ", e); 139 } catch (IOException e) { 140 throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e); 141 } catch (OutOfMemoryError e) { 142 throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e); 143 } 144 } 145 146 public void terminate() throws MaltChainedException { 147 super.terminate(); 148 } 149 150 public void initLibOptions() { 151 libOptions = new LinkedHashMap<String, String>(); 152 libOptions.put("s", Integer.toString(svm_parameter.C_SVC)); 153 libOptions.put("t", Integer.toString(svm_parameter.POLY)); 154 libOptions.put("d", Integer.toString(2)); 155 libOptions.put("g", Double.toString(0.2)); 156 libOptions.put("r", Double.toString(0)); 157 libOptions.put("n", Double.toString(0.5)); 158 libOptions.put("m", Integer.toString(100)); 159 libOptions.put("c", Double.toString(1)); 160 libOptions.put("e", Double.toString(1.0)); 161 libOptions.put("p", Double.toString(0.1)); 162 libOptions.put("h", Integer.toString(1)); 163 libOptions.put("b", Integer.toString(0)); 164 } 165 166 public void initAllowedLibOptionFlags() { 167 allowedLibOptionFlags = "stdgrnmcepb"; 168 } 169 170 private svm_parameter getLibSvmParameters() throws MaltChainedException { 171 svm_parameter param = new svm_parameter(); 172 173 param.svm_type = Integer.parseInt(libOptions.get("s")); 174 param.kernel_type = Integer.parseInt(libOptions.get("t")); 175 param.degree = Integer.parseInt(libOptions.get("d")); 176 param.gamma = Double.valueOf(libOptions.get("g")).doubleValue(); 177 param.coef0 = Double.valueOf(libOptions.get("r")).doubleValue(); 178 param.nu = Double.valueOf(libOptions.get("n")).doubleValue(); 179 param.cache_size = Double.valueOf(libOptions.get("m")).doubleValue(); 180 param.C = Double.valueOf(libOptions.get("c")).doubleValue(); 181 param.eps = Double.valueOf(libOptions.get("e")).doubleValue(); 182 param.p = Double.valueOf(libOptions.get("p")).doubleValue(); 183 param.shrinking = Integer.parseInt(libOptions.get("h")); 184 param.probability = Integer.parseInt(libOptions.get("b")); 185 param.nr_weight = 0; 186 param.weight_label = new int[0]; 187 param.weight = new double[0]; 188 return param; 189 } 190 191 private svm_problem readProblem(InputStreamReader isr) throws MaltChainedException { 192 final svm_problem problem = new svm_problem(); 193 final svm_parameter param = getLibSvmParameters(); 194 final FeatureList featureList = new FeatureList(); 195 try { 196 final BufferedReader fp = new BufferedReader(isr); 197 198 problem.l = getNumberOfInstances(); 199 problem.x = new svm_node[problem.l][]; 200 problem.y = new double[problem.l]; 201 int i = 0; 202 203 while(true) { 204 String line = fp.readLine(); 205 if(line == null) break; 206 int y = binariesInstance(line, featureList); 207 if (y == -1) { 208 continue; 209 } 210 try { 211 problem.y[i] = y; 212 problem.x[i] = new svm_node[featureList.size()]; 213 int p = 0; 214 for (int k=0; k < featureList.size(); k++) { 215 MaltFeatureNode x = featureList.get(k); 216 problem.x[i][p] = new svm_node(); 217 problem.x[i][p].value = x.getValue(); 218 problem.x[i][p].index = x.getIndex(); 219 p++; 220 } 221 i++; 222 } catch (ArrayIndexOutOfBoundsException e) { 223 throw new LibException("Couldn't read libsvm problem from the instance file. ", e); 224 } 225 } 226 fp.close(); 227 if (param.gamma == 0) { 228 param.gamma = 1.0/featureMap.getFeatureCounter(); 229 } 230 } catch (IOException e) { 231 throw new LibException("Couldn't read libsvm problem from the instance file. ", e); 232 } 233 return problem; 234 } 235 }