001 package org.maltparser.ml.lib;
002
003 import java.io.BufferedOutputStream;
004 import java.io.BufferedReader;
005 import java.io.FileOutputStream;
006 import java.io.IOException;
007 import java.io.InputStream;
008 import java.io.InputStreamReader;
009 import java.io.ObjectInputStream;
010 import java.io.ObjectOutputStream;
011 import java.io.PrintStream;
012
013 import java.util.LinkedHashMap;
014
015 import org.maltparser.core.exception.MaltChainedException;
016 import org.maltparser.core.feature.FeatureVector;
017 import org.maltparser.core.helper.NoPrintStream;
018 import org.maltparser.parser.guide.instance.InstanceModel;
019
020
021 import libsvm.svm;
022 import libsvm.svm_model;
023 import libsvm.svm_node;
024 import libsvm.svm_parameter;
025 import libsvm.svm_problem;
026
027 public class LibSvm extends Lib {
028
029 public LibSvm(InstanceModel owner, Integer learnerMode) throws MaltChainedException {
030 super(owner, learnerMode, "libsvm");
031 if (learnerMode == CLASSIFY) {
032 try {
033 ObjectInputStream input = new ObjectInputStream(getInputStreamFromConfigFileEntry(".moo"));
034 try {
035 model = (MaltLibModel)input.readObject();
036 } finally {
037 input.close();
038 }
039 } catch (ClassNotFoundException e) {
040 throw new LibException("Couldn't load the liblinear model", e);
041 } catch (Exception e) {
042 throw new LibException("Couldn't load the liblinear model", e);
043 }
044 }
045 }
046
047 protected void trainInternal(FeatureVector featureVector) throws MaltChainedException {
048 try {
049 final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"));
050 final svm_parameter param = getLibSvmParameters();
051 if(svm.svm_check_parameter(prob, param) != null) {
052 throw new LibException(svm.svm_check_parameter(prob, param));
053 }
054 owner.getGuide().getConfiguration().getConfigLogger().info("Creating LIBSVM model "+getFile(".moo").getName()+"\n");
055 final PrintStream out = System.out;
056 final PrintStream err = System.err;
057 System.setOut(NoPrintStream.NO_PRINTSTREAM);
058 System.setErr(NoPrintStream.NO_PRINTSTREAM);
059 svm_model model = svm.svm_train(prob, param);
060 System.setOut(err);
061 System.setOut(out);
062 ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
063 try{
064 output.writeObject(new MaltLibsvmModel(model, prob));
065 } finally {
066 output.close();
067 }
068 if (!saveInstanceFiles) {
069 getFile(".ins").delete();
070 }
071 } catch (OutOfMemoryError e) {
072 throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
073 } catch (IllegalArgumentException e) {
074 throw new LibException("The LIBSVM learner was not able to redirect Standard Error stream. ", e);
075 } catch (SecurityException e) {
076 throw new LibException("The LIBSVM learner cannot remove the instance file. ", e);
077 } catch (IOException e) {
078 throw new LibException("The LIBSVM learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
079 }
080 }
081
082 protected void trainExternal(FeatureVector featureVector) throws MaltChainedException {
083 try {
084 binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp"));
085 owner.getGuide().getConfiguration().getConfigLogger().info("Creating learner model (external) "+getFile(".mod").getName());
086 final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"));
087 final String[] params = getLibParamStringArray();
088 String[] arrayCommands = new String[params.length+3];
089 int i = 0;
090 arrayCommands[i++] = pathExternalTrain;
091 for (; i <= params.length; i++) {
092 arrayCommands[i] = params[i-1];
093 }
094 arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath();
095 arrayCommands[i++] = getFile(".mod").getAbsolutePath();
096
097 if (verbosity == Verbostity.ALL) {
098 owner.getGuide().getConfiguration().getConfigLogger().info('\n');
099 }
100 final Process child = Runtime.getRuntime().exec(arrayCommands);
101 final InputStream in = child.getInputStream();
102 final InputStream err = child.getErrorStream();
103 int c;
104 while ((c = in.read()) != -1){
105 if (verbosity == Verbostity.ALL) {
106 owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
107 }
108 }
109 while ((c = err.read()) != -1){
110 if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) {
111 owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
112 }
113 }
114 if (child.waitFor() != 0) {
115 owner.getGuide().getConfiguration().getConfigLogger().info(" FAILED ("+child.exitValue()+")");
116 }
117 in.close();
118 err.close();
119 svm_model model = svm.svm_load_model(getFile(".mod").getAbsolutePath());
120 MaltLibsvmModel xmodel = new MaltLibsvmModel(model, prob);
121 ObjectOutputStream output = new ObjectOutputStream (new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
122 try {
123 output.writeObject(xmodel);
124 } finally {
125 output.close();
126 }
127 if (!saveInstanceFiles) {
128 getFile(".ins").delete();
129 getFile(".mod").delete();
130 getFile(".ins.tmp").delete();
131 }
132 owner.getGuide().getConfiguration().getConfigLogger().info('\n');
133 } catch (InterruptedException e) {
134 throw new LibException("Learner is interrupted. ", e);
135 } catch (IllegalArgumentException e) {
136 throw new LibException("The learner was not able to redirect Standard Error stream. ", e);
137 } catch (SecurityException e) {
138 throw new LibException("The learner cannot remove the instance file. ", e);
139 } catch (IOException e) {
140 throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
141 } catch (OutOfMemoryError e) {
142 throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
143 }
144 }
145
146 public void terminate() throws MaltChainedException {
147 super.terminate();
148 }
149
150 public void initLibOptions() {
151 libOptions = new LinkedHashMap<String, String>();
152 libOptions.put("s", Integer.toString(svm_parameter.C_SVC));
153 libOptions.put("t", Integer.toString(svm_parameter.POLY));
154 libOptions.put("d", Integer.toString(2));
155 libOptions.put("g", Double.toString(0.2));
156 libOptions.put("r", Double.toString(0));
157 libOptions.put("n", Double.toString(0.5));
158 libOptions.put("m", Integer.toString(100));
159 libOptions.put("c", Double.toString(1));
160 libOptions.put("e", Double.toString(1.0));
161 libOptions.put("p", Double.toString(0.1));
162 libOptions.put("h", Integer.toString(1));
163 libOptions.put("b", Integer.toString(0));
164 }
165
166 public void initAllowedLibOptionFlags() {
167 allowedLibOptionFlags = "stdgrnmcepb";
168 }
169
170 private svm_parameter getLibSvmParameters() throws MaltChainedException {
171 svm_parameter param = new svm_parameter();
172
173 param.svm_type = Integer.parseInt(libOptions.get("s"));
174 param.kernel_type = Integer.parseInt(libOptions.get("t"));
175 param.degree = Integer.parseInt(libOptions.get("d"));
176 param.gamma = Double.valueOf(libOptions.get("g")).doubleValue();
177 param.coef0 = Double.valueOf(libOptions.get("r")).doubleValue();
178 param.nu = Double.valueOf(libOptions.get("n")).doubleValue();
179 param.cache_size = Double.valueOf(libOptions.get("m")).doubleValue();
180 param.C = Double.valueOf(libOptions.get("c")).doubleValue();
181 param.eps = Double.valueOf(libOptions.get("e")).doubleValue();
182 param.p = Double.valueOf(libOptions.get("p")).doubleValue();
183 param.shrinking = Integer.parseInt(libOptions.get("h"));
184 param.probability = Integer.parseInt(libOptions.get("b"));
185 param.nr_weight = 0;
186 param.weight_label = new int[0];
187 param.weight = new double[0];
188 return param;
189 }
190
191 private svm_problem readProblem(InputStreamReader isr) throws MaltChainedException {
192 final svm_problem problem = new svm_problem();
193 final svm_parameter param = getLibSvmParameters();
194 final FeatureList featureList = new FeatureList();
195 try {
196 final BufferedReader fp = new BufferedReader(isr);
197
198 problem.l = getNumberOfInstances();
199 problem.x = new svm_node[problem.l][];
200 problem.y = new double[problem.l];
201 int i = 0;
202
203 while(true) {
204 String line = fp.readLine();
205 if(line == null) break;
206 int y = binariesInstance(line, featureList);
207 if (y == -1) {
208 continue;
209 }
210 try {
211 problem.y[i] = y;
212 problem.x[i] = new svm_node[featureList.size()];
213 int p = 0;
214 for (int k=0; k < featureList.size(); k++) {
215 MaltFeatureNode x = featureList.get(k);
216 problem.x[i][p] = new svm_node();
217 problem.x[i][p].value = x.getValue();
218 problem.x[i][p].index = x.getIndex();
219 p++;
220 }
221 i++;
222 } catch (ArrayIndexOutOfBoundsException e) {
223 throw new LibException("Couldn't read libsvm problem from the instance file. ", e);
224 }
225 }
226 fp.close();
227 if (param.gamma == 0) {
228 param.gamma = 1.0/featureMap.getFeatureCounter();
229 }
230 } catch (IOException e) {
231 throw new LibException("Couldn't read libsvm problem from the instance file. ", e);
232 }
233 return problem;
234 }
235 }