001 package org.maltparser.ml.lib;
002
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;

import org.apache.log4j.Logger;

import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.feature.FeatureVector;
import org.maltparser.core.feature.function.FeatureFunction;
import org.maltparser.core.feature.value.FeatureValue;
import org.maltparser.core.feature.value.MultipleFeatureValue;
import org.maltparser.core.feature.value.SingleFeatureValue;
import org.maltparser.core.syntaxgraph.DependencyStructure;
import org.maltparser.ml.LearningMethod;
import org.maltparser.parser.DependencyParserConfig;
import org.maltparser.parser.guide.instance.InstanceModel;
import org.maltparser.parser.history.action.SingleDecision;
040
041 public abstract class Lib implements LearningMethod {
042 protected Verbostity verbosity;
043 public enum Verbostity {
044 SILENT, ERROR, ALL
045 }
046 protected InstanceModel owner;
047 protected int learnerMode;
048 protected String name;
049 protected int numberOfInstances;
050 protected boolean saveInstanceFiles;
051 protected boolean excludeNullValues;
052 protected BufferedWriter instanceOutput = null;
053 protected FeatureMap featureMap;
054 protected String paramString;
055 protected String pathExternalTrain;
056 protected LinkedHashMap<String, String> libOptions;
057 protected String allowedLibOptionFlags;
058 protected Logger configLogger;
059 final Pattern tabPattern = Pattern.compile("\t");
060 final Pattern pipePattern = Pattern.compile("\\|");
061 private StringBuilder sb = new StringBuilder();
062 protected MaltLibModel model = null;
063 /**
064 * Constructs a Lib learner.
065 *
066 * @param owner the guide model owner
067 * @param learnerMode the mode of the learner BATCH or CLASSIFY
068 */
069 public Lib(InstanceModel owner, Integer learnerMode, String learningMethodName) throws MaltChainedException {
070 setOwner(owner);
071 setLearnerMode(learnerMode.intValue());
072 setNumberOfInstances(0);
073 setLearningMethodName(learningMethodName);
074 verbosity = Verbostity.SILENT;
075 configLogger = owner.getGuide().getConfiguration().getConfigLogger();
076 initLibOptions();
077 initAllowedLibOptionFlags();
078 parseParameters(getConfiguration().getOptionValue("lib", "options").toString());
079 initSpecialParameters();
080
081 if (learnerMode == BATCH) {
082 featureMap = new FeatureMap();
083 instanceOutput = new BufferedWriter(getInstanceOutputStreamWriter(".ins"));
084 } else if (learnerMode == CLASSIFY) {
085 featureMap = loadFeatureMap(getInputStreamFromConfigFileEntry(".map"));
086 }
087 }
088
089
090 public void addInstance(SingleDecision decision, FeatureVector featureVector) throws MaltChainedException {
091 if (featureVector == null) {
092 throw new LibException("The feature vector cannot be found");
093 } else if (decision == null) {
094 throw new LibException("The decision cannot be found");
095 }
096
097 try {
098 sb.append(decision.getDecisionCode()+"\t");
099 final int n = featureVector.size();
100 for (int i = 0; i < n; i++) {
101 FeatureValue featureValue = featureVector.getFeatureValue(i);
102 if (featureValue == null || (excludeNullValues == true && featureValue.isNullValue())) {
103 sb.append("-1");
104 } else {
105 if (featureValue instanceof SingleFeatureValue) {
106 SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue;
107 if (singleFeatureValue.getValue() == 1) {
108 sb.append(singleFeatureValue.getIndexCode());
109 } else if (singleFeatureValue.getValue() == 0) {
110 sb.append("-1");
111 } else {
112 sb.append(singleFeatureValue.getIndexCode());
113 sb.append(":");
114 sb.append(singleFeatureValue.getValue());
115 }
116 } else if (featureValue instanceof MultipleFeatureValue) {
117 Set<Integer> values = ((MultipleFeatureValue)featureValue).getCodes();
118 int j=0;
119 for (Integer value : values) {
120 sb.append(value.toString());
121 if (j != values.size()-1) {
122 sb.append("|");
123 }
124 j++;
125 }
126 } else {
127 throw new LibException("Don't recognize the type of feature value: "+featureValue.getClass());
128 }
129 }
130 sb.append('\t');
131 }
132 sb.append('\n');
133 instanceOutput.write(sb.toString());
134 instanceOutput.flush();
135 increaseNumberOfInstances();
136 sb.setLength(0);
137 } catch (IOException e) {
138 throw new LibException("The learner cannot write to the instance file. ", e);
139 }
140 }
141
142 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { }
143
144 public void moveAllInstances(LearningMethod method,
145 FeatureFunction divideFeature,
146 ArrayList<Integer> divideFeatureIndexVector)
147 throws MaltChainedException {
148 if (method == null) {
149 throw new LibException("The learning method cannot be found. ");
150 } else if (divideFeature == null) {
151 throw new LibException("The divide feature cannot be found. ");
152 }
153
154 try {
155 final BufferedReader in = new BufferedReader(getInstanceInputStreamReader(".ins"));
156 final BufferedWriter out = method.getInstanceWriter();
157 final StringBuilder sb = new StringBuilder(6);
158 int l = in.read();
159 char c;
160 int j = 0;
161
162 while(true) {
163 if (l == -1) {
164 sb.setLength(0);
165 break;
166 }
167 c = (char)l;
168 l = in.read();
169 if (c == '\t') {
170 if (divideFeatureIndexVector.contains(j-1)) {
171 out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
172 out.write('\t');
173 }
174 out.write(sb.toString());
175 j++;
176 out.write('\t');
177 sb.setLength(0);
178 } else if (c == '\n') {
179 out.write(sb.toString());
180 if (divideFeatureIndexVector.contains(j-1)) {
181 out.write('\t');
182 out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
183 }
184 out.write('\n');
185 sb.setLength(0);
186 method.increaseNumberOfInstances();
187 this.decreaseNumberOfInstances();
188 j = 0;
189 } else {
190 sb.append(c);
191 }
192 }
193 in.close();
194 getFile(".ins").delete();
195 out.flush();
196 } catch (SecurityException e) {
197 throw new LibException("The learner cannot remove the instance file. ", e);
198 } catch (NullPointerException e) {
199 throw new LibException("The instance file cannot be found. ", e);
200 } catch (FileNotFoundException e) {
201 throw new LibException("The instance file cannot be found. ", e);
202 } catch (IOException e) {
203 throw new LibException("The learner read from the instance file. ", e);
204 }
205 }
206
207 public void noMoreInstances() throws MaltChainedException {
208 closeInstanceWriter();
209 }
210
211 public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
212 if (featureVector == null) {
213 throw new LibException("The learner cannot predict the next class, because the feature vector cannot be found. ");
214 }
215
216 final FeatureList featureList = new FeatureList();
217 final int size = featureVector.size();
218 for (int i = 1; i <= size; i++) {
219 final FeatureValue featureValue = featureVector.getFeatureValue(i-1);
220 if (featureValue != null && !(excludeNullValues == true && featureValue.isNullValue())) {
221 if (featureValue instanceof SingleFeatureValue) {
222 SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue;
223 int index = featureMap.getIndex(i, singleFeatureValue.getIndexCode());
224 if (index != -1 && singleFeatureValue.getValue() != 0) {
225 featureList.add(index,singleFeatureValue.getValue());
226 }
227 } else if (featureValue instanceof MultipleFeatureValue) {
228 for (Integer value : ((MultipleFeatureValue)featureValue).getCodes()) {
229 int v = featureMap.getIndex(i, value);
230 if (v != -1) {
231 featureList.add(v,1);
232 }
233 }
234 }
235 }
236 }
237 try {
238 decision.getKBestList().addList(model.predict(featureList.toArray()));
239 // decision.getKBestList().addList(prediction(featureList));
240 } catch (OutOfMemoryError e) {
241 throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
242 }
243 return true;
244 }
245
246 // protected abstract int[] prediction(FeatureList featureList) throws MaltChainedException;
247
248 public void train(FeatureVector featureVector) throws MaltChainedException {
249 if (featureVector == null) {
250 throw new LibException("The feature vector cannot be found. ");
251 } else if (owner == null) {
252 throw new LibException("The parent guide model cannot be found. ");
253 }
254 long startTime = System.currentTimeMillis();
255
256 // if (configLogger.isInfoEnabled()) {
257 // configLogger.info("\nStart training\n");
258 // }
259 if (pathExternalTrain != null) {
260 trainExternal(featureVector);
261 } else {
262 trainInternal(featureVector);
263 }
264 // long elapsed = System.currentTimeMillis() - startTime;
265 // if (configLogger.isInfoEnabled()) {
266 // configLogger.info("Time 1: " +new Formatter().format("%02d:%02d:%02d", elapsed/3600000, elapsed%3600000/60000, elapsed%60000/1000)+" ("+elapsed+" ms)\n");
267 // }
268 try {
269 // if (configLogger.isInfoEnabled()) {
270 // configLogger.info("\nSaving feature map "+getFile(".map").getName()+"\n");
271 // }
272 saveFeatureMap(new BufferedOutputStream(new FileOutputStream(getFile(".map").getAbsolutePath())), featureMap);
273 } catch (FileNotFoundException e) {
274 throw new LibException("The learner cannot save the feature map file '"+getFile(".map").getAbsolutePath()+"'. ", e);
275 }
276 // elapsed = System.currentTimeMillis() - startTime;
277 // if (configLogger.isInfoEnabled()) {
278 // configLogger.info("Time 2: " +new Formatter().format("%02d:%02d:%02d", elapsed/3600000, elapsed%3600000/60000, elapsed%60000/1000)+" ("+elapsed+" ms)\n");
279 // }
280 }
281 protected abstract void trainExternal(FeatureVector featureVector) throws MaltChainedException;
282 protected abstract void trainInternal(FeatureVector featureVector) throws MaltChainedException;
283
284 public void terminate() throws MaltChainedException {
285 closeInstanceWriter();
286 owner = null;
287 model = null;
288 }
289
290 public BufferedWriter getInstanceWriter() {
291 return instanceOutput;
292 }
293
294 protected void closeInstanceWriter() throws MaltChainedException {
295 try {
296 if (instanceOutput != null) {
297 instanceOutput.flush();
298 instanceOutput.close();
299 instanceOutput = null;
300 }
301 } catch (IOException e) {
302 throw new LibException("The learner cannot close the instance file. ", e);
303 }
304 }
305
306
307 /**
308 * Returns the parameter string used for configure the learner
309 *
310 * @return the parameter string used for configure the learner
311 */
312 public String getParamString() {
313 return paramString;
314 }
315
316 public InstanceModel getOwner() {
317 return owner;
318 }
319
320 protected void setOwner(InstanceModel owner) {
321 this.owner = owner;
322 }
323
324 public int getLearnerMode() {
325 return learnerMode;
326 }
327
328 public void setLearnerMode(int learnerMode) throws MaltChainedException {
329 this.learnerMode = learnerMode;
330 }
331
332 public String getLearningMethodName() {
333 return name;
334 }
335
336 /**
337 * Returns the current configuration
338 *
339 * @return the current configuration
340 * @throws MaltChainedException
341 */
342 public DependencyParserConfig getConfiguration() throws MaltChainedException {
343 return owner.getGuide().getConfiguration();
344 }
345
346 public int getNumberOfInstances() throws MaltChainedException {
347 if(numberOfInstances!=0)
348 return numberOfInstances;
349 else{
350 BufferedReader reader = new BufferedReader( getInstanceInputStreamReader(".ins"));
351 try {
352 while(reader.readLine()!=null){
353 numberOfInstances++;
354 owner.increaseFrequency();
355 }
356 reader.close();
357 } catch (IOException e) {
358 throw new MaltChainedException("No instances found in file",e);
359 }
360 return numberOfInstances;
361 }
362 }
363
364 public void increaseNumberOfInstances() {
365 numberOfInstances++;
366 owner.increaseFrequency();
367 }
368
369 public void decreaseNumberOfInstances() {
370 numberOfInstances--;
371 owner.decreaseFrequency();
372 }
373
374 protected void setNumberOfInstances(int numberOfInstances) {
375 this.numberOfInstances = 0;
376 }
377
378 protected void setLearningMethodName(String name) {
379 this.name = name;
380 }
381
382 public String getPathExternalTrain() {
383 return pathExternalTrain;
384 }
385
386
387 public void setPathExternalTrain(String pathExternalTrain) {
388 this.pathExternalTrain = pathExternalTrain;
389 }
390
391 protected OutputStreamWriter getInstanceOutputStreamWriter(String suffix) throws MaltChainedException {
392 return getConfiguration().getConfigurationDir().getAppendOutputStreamWriter(owner.getModelName()+getLearningMethodName()+suffix);
393 }
394
395 protected InputStreamReader getInstanceInputStreamReader(String suffix) throws MaltChainedException {
396 return getConfiguration().getConfigurationDir().getInputStreamReader(owner.getModelName()+getLearningMethodName()+suffix);
397 }
398
399 protected InputStreamReader getInstanceInputStreamReaderFromConfigFile(String suffix) throws MaltChainedException {
400 return getConfiguration().getConfigurationDir().getInputStreamReaderFromConfigFile(owner.getModelName()+getLearningMethodName()+suffix);
401 }
402
403 protected InputStream getInputStreamFromConfigFileEntry(String suffix) throws MaltChainedException {
404 return getConfiguration().getConfigurationDir().getInputStreamFromConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix);
405 }
406
407
408 protected File getFile(String suffix) throws MaltChainedException {
409 return getConfiguration().getConfigurationDir().getFile(owner.getModelName()+getLearningMethodName()+suffix);
410 }
411
412 protected JarEntry getConfigFileEntry(String suffix) throws MaltChainedException {
413 return getConfiguration().getConfigurationDir().getConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix);
414 }
415
416 protected void initSpecialParameters() throws MaltChainedException {
417 if (getConfiguration().getOptionValue("singlemalt", "null_value") != null && getConfiguration().getOptionValue("singlemalt", "null_value").toString().equalsIgnoreCase("none")) {
418 excludeNullValues = true;
419 } else {
420 excludeNullValues = false;
421 }
422 saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();
423 if (!getConfiguration().getOptionValue("lib", "external").toString().equals("")) {
424 String path = getConfiguration().getOptionValue("lib", "external").toString();
425 try {
426 if (!new File(path).exists()) {
427 throw new LibException("The path to the external trainer 'svm-train' is wrong.");
428 }
429 if (new File(path).isDirectory()) {
430 throw new LibException("The option --lib-external points to a directory, the path should point at the 'train' file or the 'train.exe' file in the libsvm or the liblinear package");
431 }
432 if (!(path.endsWith("train") ||path.endsWith("train.exe"))) {
433 throw new LibException("The option --lib-external does not specify the path to 'train' file or the 'train.exe' file in the libsvm or the liblinear package. ");
434 }
435 setPathExternalTrain(path);
436 } catch (SecurityException e) {
437 throw new LibException("Access denied to the file specified by the option --lib-external. ", e);
438 }
439 }
440 if (getConfiguration().getOptionValue("lib", "verbosity") != null) {
441 verbosity = Verbostity.valueOf(getConfiguration().getOptionValue("lib", "verbosity").toString().toUpperCase());
442 }
443 }
444
445 public String getLibOptions() {
446 StringBuilder sb = new StringBuilder();
447 for (String key : libOptions.keySet()) {
448 sb.append('-');
449 sb.append(key);
450 sb.append(' ');
451 sb.append(libOptions.get(key));
452 sb.append(' ');
453 }
454 return sb.toString();
455 }
456
457 public String[] getLibParamStringArray() {
458 final ArrayList<String> params = new ArrayList<String>();
459
460 for (String key : libOptions.keySet()) {
461 params.add("-"+key); params.add(libOptions.get(key));
462 }
463 return params.toArray(new String[params.size()]);
464 }
465
466 public abstract void initLibOptions();
467 public abstract void initAllowedLibOptionFlags();
468
469 public void parseParameters(String paramstring) throws MaltChainedException {
470 if (paramstring == null) {
471 return;
472 }
473 final String[] argv;
474 try {
475 argv = paramstring.split("[_\\p{Blank}]");
476 } catch (PatternSyntaxException e) {
477 throw new LibException("Could not split the parameter string '"+paramstring+"'. ", e);
478 }
479 for (int i=0; i < argv.length-1; i++) {
480 if(argv[i].charAt(0) != '-') {
481 throw new LibException("The argument flag should start with the following character '-', not with "+argv[i].charAt(0));
482 }
483 if(++i>=argv.length) {
484 throw new LibException("The last argument does not have any value. ");
485 }
486 try {
487 int index = allowedLibOptionFlags.indexOf(argv[i-1].charAt(1));
488 if (index != -1) {
489 libOptions.put(Character.toString(argv[i-1].charAt(1)), argv[i]);
490 } else {
491 throw new LibException("Unknown learner parameter: '"+argv[i-1]+"' with value '"+argv[i]+"'. ");
492 }
493 } catch (ArrayIndexOutOfBoundsException e) {
494 throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e);
495 } catch (NumberFormatException e) {
496 throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e);
497 } catch (NullPointerException e) {
498 throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e);
499 }
500 }
501 }
502
503 protected void finalize() throws Throwable {
504 try {
505 closeInstanceWriter();
506 } finally {
507 super.finalize();
508 }
509 }
510
511 public String toString() {
512 final StringBuffer sb = new StringBuffer();
513 sb.append("\n"+getLearningMethodName()+" INTERFACE\n");
514 sb.append(getLibOptions());
515 return sb.toString();
516 }
517
518 protected int binariesInstance(String line, FeatureList featureList) throws MaltChainedException {
519 int y = -1;
520 featureList.clear();
521 try {
522 String[] columns = tabPattern.split(line);
523
524 if (columns.length == 0) {
525 return -1;
526 }
527 try {
528 y = Integer.parseInt(columns[0]);
529 } catch (NumberFormatException e) {
530 throw new LibException("The instance file contain a non-integer value '"+columns[0]+"'", e);
531 }
532 for(int j = 1; j < columns.length; j++) {
533 final String[] items = pipePattern.split(columns[j]);
534 for (int k = 0; k < items.length; k++) {
535 try {
536 int colon = items[k].indexOf(':');
537 if (colon == -1) {
538 if (Integer.parseInt(items[k]) != -1) {
539 int v = featureMap.addIndex(j, Integer.parseInt(items[k]));
540 if (v != -1) {
541 featureList.add(v,1);
542 }
543 }
544 } else {
545 int index = featureMap.addIndex(j, Integer.parseInt(items[k].substring(0,colon)));
546 double value;
547 if (items[k].substring(colon+1).indexOf('.') != -1) {
548 value = Double.parseDouble(items[k].substring(colon+1));
549 } else {
550 value = Integer.parseInt(items[k].substring(colon+1));
551 }
552 featureList.add(index,value);
553 }
554 } catch (NumberFormatException e) {
555 throw new LibException("The instance file contain a non-numeric value '"+items[k]+"'", e);
556 }
557 }
558 }
559 } catch (ArrayIndexOutOfBoundsException e) {
560 throw new LibException("Couln't read from the instance file. ", e);
561 }
562 return y;
563 }
564
565 protected void binariesInstances2SVMFileFormat(InputStreamReader isr, OutputStreamWriter osw) throws MaltChainedException {
566 try {
567 final BufferedReader in = new BufferedReader(isr);
568 final BufferedWriter out = new BufferedWriter(osw);
569 final FeatureList featureSet = new FeatureList();
570 while(true) {
571 String line = in.readLine();
572 if(line == null) break;
573 int y = binariesInstance(line, featureSet);
574 if (y == -1) {
575 continue;
576 }
577 out.write(Integer.toString(y));
578
579 for (int k=0; k < featureSet.size(); k++) {
580 MaltFeatureNode x = featureSet.get(k);
581 out.write(' ');
582 out.write(Integer.toString(x.getIndex()));
583 out.write(':');
584 out.write(Double.toString(x.getValue()));
585 }
586 out.write('\n');
587 }
588 in.close();
589 out.close();
590 } catch (NumberFormatException e) {
591 throw new LibException("The instance file contain a non-numeric value", e);
592 } catch (IOException e) {
593 throw new LibException("Couln't read from the instance file, when converting the Malt instances into LIBSV/LIBLINEAR format. ", e);
594 }
595 }
596
597 protected void saveFeatureMap(OutputStream os, FeatureMap map) throws MaltChainedException {
598 try {
599 ObjectOutputStream output = new ObjectOutputStream(os);
600 try{
601 output.writeObject(map);
602 }
603 finally{
604 output.close();
605 }
606 } catch (IOException e) {
607 throw new LibException("Save feature map error", e);
608 }
609 }
610
611 protected FeatureMap loadFeatureMap(InputStream is) throws MaltChainedException {
612 FeatureMap map = new FeatureMap();
613 try {
614 ObjectInputStream input = new ObjectInputStream(is);
615 try {
616 map = (FeatureMap)input.readObject();
617 } finally {
618 input.close();
619 }
620 } catch (ClassNotFoundException e) {
621 throw new LibException("Load feature map error", e);
622 } catch (IOException e) {
623 throw new LibException("Load feature map error", e);
624 }
625 return map;
626 }
627 }