package org.maltparser.ml.lib;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;

import org.apache.log4j.Logger;

import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.feature.FeatureVector;
import org.maltparser.core.feature.function.FeatureFunction;
import org.maltparser.core.feature.value.FeatureValue;
import org.maltparser.core.feature.value.MultipleFeatureValue;
import org.maltparser.core.feature.value.SingleFeatureValue;
import org.maltparser.core.syntaxgraph.DependencyStructure;
import org.maltparser.ml.LearningMethod;
import org.maltparser.parser.DependencyParserConfig;
import org.maltparser.parser.guide.instance.InstanceModel;
import org.maltparser.parser.history.action.SingleDecision;

/**
 * Common base class for learning methods that store their training instances in a
 * tab-separated instance file (suffix {@code .ins}) and translate feature codes into
 * learner feature indices through a {@link FeatureMap} (suffix {@code .map}).
 *
 * <p>Subclasses supply the set of accepted option flags ({@link #initLibOptions()},
 * {@link #initAllowedLibOptionFlags()}) and the actual training routines
 * ({@link #trainInternal(FeatureVector)}, {@link #trainExternal(FeatureVector)}).
 */
public abstract class Lib implements LearningMethod {
    protected Verbostity verbosity;

    /** Verbosity levels accepted by the {@code lib/verbosity} option. */
    public enum Verbostity {
        SILENT, ERROR, ALL
    }

    protected InstanceModel owner;
    protected int learnerMode;
    protected String name;
    protected int numberOfInstances;
    protected boolean saveInstanceFiles;
    protected boolean excludeNullValues;
    protected BufferedWriter instanceOutput = null;
    protected FeatureMap featureMap;
    protected String paramString;
    protected String pathExternalTrain;
    protected LinkedHashMap<String, String> libOptions;
    protected String allowedLibOptionFlags;
    protected Logger configLogger;
    final Pattern tabPattern = Pattern.compile("\t");
    final Pattern pipePattern = Pattern.compile("\\|");
    /** Scratch buffer reused by {@link #addInstance(SingleDecision, FeatureVector)}. */
    private final StringBuilder sb = new StringBuilder();
    protected MaltLibModel model = null;

    /**
     * Constructs a Lib learner.
     *
     * <p>In {@code BATCH} mode a fresh feature map is created and the instance file is
     * opened for writing; in {@code CLASSIFY} mode the feature map is loaded from the
     * configuration file entry with suffix {@code .map}.
     *
     * @param owner the guide model owner
     * @param learnerMode the mode of the learner BATCH or CLASSIFY
     * @param learningMethodName the name of the learning method; it is part of every file name used by this learner
     * @throws MaltChainedException if the options cannot be parsed or the files cannot be opened
     */
    public Lib(InstanceModel owner, Integer learnerMode, String learningMethodName) throws MaltChainedException {
        setOwner(owner);
        setLearnerMode(learnerMode.intValue());
        setNumberOfInstances(0);
        setLearningMethodName(learningMethodName);
        verbosity = Verbostity.SILENT;
        configLogger = owner.getGuide().getConfiguration().getConfigLogger();
        initLibOptions();
        initAllowedLibOptionFlags();
        parseParameters(getConfiguration().getOptionValue("lib", "options").toString());
        initSpecialParameters();

        if (learnerMode == BATCH) {
            featureMap = new FeatureMap();
            instanceOutput = new BufferedWriter(getInstanceOutputStreamWriter(".ins"));
        } else if (learnerMode == CLASSIFY) {
            featureMap = loadFeatureMap(getInputStreamFromConfigFileEntry(".map"));
        }
    }

    /**
     * Appends one training instance to the instance file.
     *
     * <p>The line starts with the decision code followed by a tab. Each feature is then
     * written followed by a tab: {@code -1} for a missing, excluded-null or zero-valued
     * feature, the index code alone for a value of 1, {@code code:value} for any other
     * single value, and a {@code |}-separated code list for a multiple value.
     *
     * @param decision the decision whose code labels the instance
     * @param featureVector the features describing the instance
     * @throws MaltChainedException if an argument is {@code null}, a feature value has an
     *         unknown type, or the instance file cannot be written
     */
    public void addInstance(SingleDecision decision, FeatureVector featureVector) throws MaltChainedException {
        if (featureVector == null) {
            throw new LibException("The feature vector cannot be found");
        } else if (decision == null) {
            throw new LibException("The decision cannot be found");
        }

        // A previous call may have failed half-way through; never let its partial line
        // leak into this instance.
        sb.setLength(0);
        try {
            sb.append(decision.getDecisionCode()).append('\t');
            final int n = featureVector.size();
            for (int i = 0; i < n; i++) {
                final FeatureValue featureValue = featureVector.getFeatureValue(i);
                if (featureValue == null || (excludeNullValues && featureValue.isNullValue())) {
                    sb.append("-1");
                } else if (featureValue instanceof SingleFeatureValue) {
                    final SingleFeatureValue single = (SingleFeatureValue) featureValue;
                    if (single.getValue() == 1) {
                        sb.append(single.getIndexCode());
                    } else if (single.getValue() == 0) {
                        sb.append("-1");
                    } else {
                        sb.append(single.getIndexCode());
                        sb.append(":");
                        sb.append(single.getValue());
                    }
                } else if (featureValue instanceof MultipleFeatureValue) {
                    final Set<Integer> codes = ((MultipleFeatureValue) featureValue).getCodes();
                    boolean first = true;
                    for (Integer code : codes) {
                        if (!first) {
                            sb.append("|");
                        }
                        sb.append(code.toString());
                        first = false;
                    }
                } else {
                    throw new LibException("Don't recognize the type of feature value: "+featureValue.getClass());
                }
                sb.append('\t');
            }
            sb.append('\n');
            instanceOutput.write(sb.toString());
            instanceOutput.flush();
            increaseNumberOfInstances();
            sb.setLength(0);
        } catch (IOException e) {
            throw new LibException("The learner cannot write to the instance file. ", e);
        }
    }

    /** Does nothing; this learner keeps no per-sentence state. */
    public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { }

    /**
     * Moves every complete (newline-terminated) instance of this learner's instance file to
     * the instance writer of {@code method}, inserting the index code of the divide feature
     * at the column positions listed in {@code divideFeatureIndexVector}, and then deletes
     * this learner's instance file. Instance counters of both learners are updated per line.
     *
     * @param method the learning method receiving the instances
     * @param divideFeature the feature whose current index code is inserted
     * @param divideFeatureIndexVector column positions at which the divide feature is inserted
     * @throws MaltChainedException if an argument is missing or the instance file cannot be
     *         read, written or removed
     */
    public void moveAllInstances(LearningMethod method,
            FeatureFunction divideFeature,
            ArrayList<Integer> divideFeatureIndexVector)
            throws MaltChainedException {
        if (method == null) {
            throw new LibException("The learning method cannot be found. ");
        } else if (divideFeature == null) {
            throw new LibException("The divide feature cannot be found. ");
        }

        try {
            final BufferedReader in = new BufferedReader(getInstanceInputStreamReader(".ins"));
            try {
                final BufferedWriter out = method.getInstanceWriter();
                final StringBuilder token = new StringBuilder(6);
                int column = 0;
                int ch;

                // Text after the last newline (an incomplete line) is intentionally dropped.
                while ((ch = in.read()) != -1) {
                    final char c = (char) ch;
                    if (c == '\t') {
                        if (divideFeatureIndexVector.contains(column-1)) {
                            out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
                            out.write('\t');
                        }
                        out.write(token.toString());
                        column++;
                        out.write('\t');
                        token.setLength(0);
                    } else if (c == '\n') {
                        out.write(token.toString());
                        if (divideFeatureIndexVector.contains(column-1)) {
                            out.write('\t');
                            out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
                        }
                        out.write('\n');
                        token.setLength(0);
                        method.increaseNumberOfInstances();
                        this.decreaseNumberOfInstances();
                        column = 0;
                    } else {
                        token.append(c);
                    }
                }
                out.flush();
            } finally {
                // Close the reader on every path so a failed move does not leak the file handle.
                in.close();
            }
            // Only reached when every instance was copied successfully.
            getFile(".ins").delete();
        } catch (SecurityException e) {
            throw new LibException("The learner cannot remove the instance file. ", e);
        } catch (NullPointerException e) {
            throw new LibException("The instance file cannot be found. ", e);
        } catch (FileNotFoundException e) {
            throw new LibException("The instance file cannot be found. ", e);
        } catch (IOException e) {
            throw new LibException("The learner cannot read from the instance file. ", e);
        }
    }

    /** Signals that no further instances will be added; closes the instance writer. */
    public void noMoreInstances() throws MaltChainedException {
        closeInstanceWriter();
    }

    /**
     * Predicts the next decision for the given features and adds the model's result to the
     * k-best list of {@code decision}.
     *
     * <p>Feature positions are 1-based when looked up in the feature map. Features that are
     * missing, excluded null values, unknown to the map, or (for single values) zero are skipped.
     *
     * @param featureVector the features of the current parser state
     * @param decision the decision that receives the prediction
     * @return always {@code true}
     * @throws MaltChainedException if the feature vector is {@code null} or the JVM runs out of memory
     */
    public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
        if (featureVector == null) {
            throw new LibException("The learner cannot predict the next class, because the feature vector cannot be found. ");
        }

        final FeatureList featureList = new FeatureList();
        final int size = featureVector.size();
        for (int i = 1; i <= size; i++) {
            final FeatureValue featureValue = featureVector.getFeatureValue(i-1);
            if (featureValue == null || (excludeNullValues && featureValue.isNullValue())) {
                continue;
            }
            if (featureValue instanceof SingleFeatureValue) {
                final SingleFeatureValue single = (SingleFeatureValue) featureValue;
                final int index = featureMap.getIndex(i, single.getIndexCode());
                if (index != -1 && single.getValue() != 0) {
                    featureList.add(index, single.getValue());
                }
            } else if (featureValue instanceof MultipleFeatureValue) {
                for (Integer code : ((MultipleFeatureValue) featureValue).getCodes()) {
                    final int index = featureMap.getIndex(i, code);
                    if (index != -1) {
                        featureList.add(index, 1);
                    }
                }
            }
        }
        try {
            decision.getKBestList().addList(model.predict(featureList.toArray()));
        } catch (OutOfMemoryError e) {
            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
        }
        return true;
    }

    /**
     * Trains the model, externally when a path to an external trainer is configured and
     * internally otherwise, and then saves the feature map to the {@code .map} file.
     *
     * @param featureVector the feature vector describing the feature model
     * @throws MaltChainedException if an argument or the owner is missing, training fails,
     *         or the feature map cannot be saved
     */
    public void train(FeatureVector featureVector) throws MaltChainedException {
        if (featureVector == null) {
            throw new LibException("The feature vector cannot be found. ");
        } else if (owner == null) {
            throw new LibException("The parent guide model cannot be found. ");
        }

        if (pathExternalTrain != null) {
            trainExternal(featureVector);
        } else {
            trainInternal(featureVector);
        }

        final String mapPath = getFile(".map").getAbsolutePath();
        try {
            saveFeatureMap(new BufferedOutputStream(new FileOutputStream(mapPath)), featureMap);
        } catch (FileNotFoundException e) {
            throw new LibException("The learner cannot save the feature map file '"+mapPath+"'. ", e);
        }
    }

    /** Trains the model by invoking the external trainer at {@link #getPathExternalTrain()}. */
    protected abstract void trainExternal(FeatureVector featureVector) throws MaltChainedException;

    /** Trains the model inside this JVM. */
    protected abstract void trainInternal(FeatureVector featureVector) throws MaltChainedException;

    /** Closes the instance writer and releases the references to the owner and the model. */
    public void terminate() throws MaltChainedException {
        closeInstanceWriter();
        owner = null;
        model = null;
    }

    /**
     * Returns the writer used for the instance file.
     *
     * @return the instance writer, or {@code null} if it was never opened or is already closed
     */
    public BufferedWriter getInstanceWriter() {
        return instanceOutput;
    }

    /**
     * Flushes and closes the instance writer if it is open; safe to call repeatedly.
     *
     * @throws MaltChainedException if the writer cannot be flushed or closed
     */
    protected void closeInstanceWriter() throws MaltChainedException {
        try {
            if (instanceOutput != null) {
                instanceOutput.flush();
                instanceOutput.close();
                instanceOutput = null;
            }
        } catch (IOException e) {
            throw new LibException("The learner cannot close the instance file. ", e);
        }
    }

    /**
     * Returns the parameter string used for configure the learner
     *
     * @return the parameter string used for configure the learner
     */
    public String getParamString() {
        return paramString;
    }

    public InstanceModel getOwner() {
        return owner;
    }

    protected void setOwner(InstanceModel owner) {
        this.owner = owner;
    }

    public int getLearnerMode() {
        return learnerMode;
    }

    public void setLearnerMode(int learnerMode) throws MaltChainedException {
        this.learnerMode = learnerMode;
    }

    public String getLearningMethodName() {
        return name;
    }

    /**
     * Returns the current configuration
     *
     * @return the current configuration
     * @throws MaltChainedException
     */
    public DependencyParserConfig getConfiguration() throws MaltChainedException {
        return owner.getGuide().getConfiguration();
    }

    /**
     * Returns the number of instances. When the counter is zero the instance file is read
     * and its lines are counted; each counted line also increases the owner's frequency.
     *
     * @return the number of instances
     * @throws MaltChainedException if the instance file cannot be read
     */
    public int getNumberOfInstances() throws MaltChainedException {
        if (numberOfInstances != 0) {
            return numberOfInstances;
        }
        final BufferedReader reader = new BufferedReader(getInstanceInputStreamReader(".ins"));
        try {
            try {
                while (reader.readLine() != null) {
                    numberOfInstances++;
                    owner.increaseFrequency();
                }
            } finally {
                // Release the file handle even when reading fails.
                reader.close();
            }
        } catch (IOException e) {
            throw new MaltChainedException("No instances found in file",e);
        }
        return numberOfInstances;
    }

    public void increaseNumberOfInstances() {
        numberOfInstances++;
        owner.increaseFrequency();
    }

    public void decreaseNumberOfInstances() {
        numberOfInstances--;
        owner.decreaseFrequency();
    }

    /**
     * Sets the instance counter.
     *
     * @param numberOfInstances the new number of instances
     */
    protected void setNumberOfInstances(int numberOfInstances) {
        // Previously the argument was ignored and the counter was always reset to 0.
        this.numberOfInstances = numberOfInstances;
    }

    protected void setLearningMethodName(String name) {
        this.name = name;
    }

    public String getPathExternalTrain() {
        return pathExternalTrain;
    }

    public void setPathExternalTrain(String pathExternalTrain) {
        this.pathExternalTrain = pathExternalTrain;
    }

    /** Builds the file name shared by all files of this learner: model name + method name + suffix. */
    private String learnerFileName(String suffix) throws MaltChainedException {
        return owner.getModelName()+getLearningMethodName()+suffix;
    }

    protected OutputStreamWriter getInstanceOutputStreamWriter(String suffix) throws MaltChainedException {
        return getConfiguration().getConfigurationDir().getAppendOutputStreamWriter(learnerFileName(suffix));
    }

    protected InputStreamReader getInstanceInputStreamReader(String suffix) throws MaltChainedException {
        return getConfiguration().getConfigurationDir().getInputStreamReader(learnerFileName(suffix));
    }

    protected InputStreamReader getInstanceInputStreamReaderFromConfigFile(String suffix) throws MaltChainedException {
        return getConfiguration().getConfigurationDir().getInputStreamReaderFromConfigFile(learnerFileName(suffix));
    }

    protected InputStream getInputStreamFromConfigFileEntry(String suffix) throws MaltChainedException {
        return getConfiguration().getConfigurationDir().getInputStreamFromConfigFileEntry(learnerFileName(suffix));
    }

    protected File getFile(String suffix) throws MaltChainedException {
        return getConfiguration().getConfigurationDir().getFile(learnerFileName(suffix));
    }

    protected JarEntry getConfigFileEntry(String suffix) throws MaltChainedException {
        return getConfiguration().getConfigurationDir().getConfigFileEntry(learnerFileName(suffix));
    }

    /**
     * Reads the options that are not learner flags: null-value handling
     * ({@code singlemalt/null_value}), whether instance files are kept
     * ({@code lib/save_instance_files}), the external trainer path ({@code lib/external})
     * and the verbosity ({@code lib/verbosity}).
     *
     * @throws MaltChainedException if the external trainer path is invalid or inaccessible
     */
    protected void initSpecialParameters() throws MaltChainedException {
        final Object nullValue = getConfiguration().getOptionValue("singlemalt", "null_value");
        excludeNullValues = nullValue != null && nullValue.toString().equalsIgnoreCase("none");

        saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();

        final String path = getConfiguration().getOptionValue("lib", "external").toString();
        if (!path.equals("")) {
            try {
                final File trainer = new File(path);
                if (!trainer.exists()) {
                    throw new LibException("The path to the external trainer 'svm-train' is wrong.");
                }
                if (trainer.isDirectory()) {
                    throw new LibException("The option --lib-external points to a directory, the path should point at the 'train' file or the 'train.exe' file in the libsvm or the liblinear package");
                }
                if (!(path.endsWith("train") ||path.endsWith("train.exe"))) {
                    throw new LibException("The option --lib-external does not specify the path to 'train' file or the 'train.exe' file in the libsvm or the liblinear package. ");
                }
                setPathExternalTrain(path);
            } catch (SecurityException e) {
                throw new LibException("Access denied to the file specified by the option --lib-external. ", e);
            }
        }

        final Object verbosityValue = getConfiguration().getOptionValue("lib", "verbosity");
        if (verbosityValue != null) {
            // Locale.ROOT keeps the mapping stable; the default locale can alter ASCII letters
            // (e.g. Turkish maps 'i' to a dotted capital), which would make valueOf fail.
            verbosity = Verbostity.valueOf(verbosityValue.toString().toUpperCase(Locale.ROOT));
        }
    }

    /**
     * Returns the learner options as one string of the form {@code "-flag value "} repeated
     * for every option, in insertion order.
     *
     * @return the learner options as a single string
     */
    public String getLibOptions() {
        final StringBuilder options = new StringBuilder();
        for (Map.Entry<String, String> option : libOptions.entrySet()) {
            options.append('-');
            options.append(option.getKey());
            options.append(' ');
            options.append(option.getValue());
            options.append(' ');
        }
        return options.toString();
    }

    /**
     * Returns the learner options as an argument array: {@code "-flag"} followed by its
     * value for every option, in insertion order.
     *
     * @return the learner options as an argument array
     */
    public String[] getLibParamStringArray() {
        final ArrayList<String> params = new ArrayList<String>(2 * libOptions.size());
        for (Map.Entry<String, String> option : libOptions.entrySet()) {
            params.add("-"+option.getKey());
            params.add(option.getValue());
        }
        return params.toArray(new String[params.size()]);
    }

    /** Initializes {@link #libOptions} with the default options of the concrete learner. */
    public abstract void initLibOptions();

    /** Initializes {@link #allowedLibOptionFlags} with the flag characters the concrete learner accepts. */
    public abstract void initAllowedLibOptionFlags();

    /**
     * Parses a learner option string made of flag/value pairs separated by blanks or
     * underscores (for example {@code "-s_4_-c_0.1"}) and stores every pair in
     * {@link #libOptions}, keyed by the single flag character.
     *
     * <p>NOTE(review): a trailing flag that has no value is ignored, as before; confirm
     * whether it should be reported instead.
     *
     * @param paramstring the option string; {@code null} is ignored
     * @throws MaltChainedException if a flag is malformed or not allowed
     */
    public void parseParameters(String paramstring) throws MaltChainedException {
        if (paramstring == null) {
            return;
        }
        final String[] argv;
        try {
            argv = paramstring.split("[_\\p{Blank}]");
        } catch (PatternSyntaxException e) {
            throw new LibException("Could not split the parameter string '"+paramstring+"'. ", e);
        }
        for (int i = 0; i + 1 < argv.length; i += 2) {
            final String flag = argv[i];
            final String value = argv[i+1];
            if (flag.length() == 0) {
                // Happens with doubled or leading separators; used to crash with an unchecked exception.
                throw new LibException("The argument flag should start with the following character '-', but an empty argument was found in '"+paramstring+"'. ");
            }
            if (flag.charAt(0) != '-') {
                throw new LibException("The argument flag should start with the following character '-', not with "+flag.charAt(0));
            }
            if (flag.length() < 2) {
                throw new LibException("Unknown learner parameter: '"+flag+"' with value '"+value+"'. ");
            }
            try {
                final int index = allowedLibOptionFlags.indexOf(flag.charAt(1));
                if (index != -1) {
                    libOptions.put(Character.toString(flag.charAt(1)), value);
                } else {
                    throw new LibException("Unknown learner parameter: '"+flag+"' with value '"+value+"'. ");
                }
            } catch (ArrayIndexOutOfBoundsException e) {
                throw new LibException("The learner parameter '"+flag+"' could not convert the string value '"+value+"' into a correct numeric value. ", e);
            } catch (NumberFormatException e) {
                throw new LibException("The learner parameter '"+flag+"' could not convert the string value '"+value+"' into a correct numeric value. ", e);
            } catch (NullPointerException e) {
                throw new LibException("The learner parameter '"+flag+"' could not convert the string value '"+value+"' into a correct numeric value. ", e);
            }
        }
    }

    /** Last-resort clean-up: closes the instance writer if it is still open. */
    @Override
    protected void finalize() throws Throwable {
        try {
            closeInstanceWriter();
        } finally {
            super.finalize();
        }
    }

    /**
     * Returns the learning method name followed by the learner options.
     */
    public String toString() {
        final StringBuilder description = new StringBuilder();
        description.append("\n"+getLearningMethodName()+" INTERFACE\n");
        description.append(getLibOptions());
        return description.toString();
    }

    /**
     * Converts one line of the instance file into a feature list, registering every
     * feature code in the feature map.
     *
     * <p>The first tab-separated column is the class; every following column holds one or
     * more {@code |}-separated items, each either a plain code ({@code -1} means "no
     * feature") or {@code code:value}.
     *
     * @param line one line of the instance file
     * @param featureList the list that is cleared and then filled with the features of the line
     * @return the class of the instance, or {@code -1} if the line has no columns
     * @throws MaltChainedException if the line contains a non-numeric item
     */
    protected int binariesInstance(String line, FeatureList featureList) throws MaltChainedException {
        int y = -1;
        featureList.clear();
        try {
            final String[] columns = tabPattern.split(line);

            if (columns.length == 0) {
                return -1;
            }
            try {
                y = Integer.parseInt(columns[0]);
            } catch (NumberFormatException e) {
                throw new LibException("The instance file contain a non-integer value '"+columns[0]+"'", e);
            }
            for (int j = 1; j < columns.length; j++) {
                final String[] items = pipePattern.split(columns[j]);
                for (int k = 0; k < items.length; k++) {
                    final String item = items[k];
                    try {
                        final int colon = item.indexOf(':');
                        if (colon == -1) {
                            final int code = Integer.parseInt(item);
                            if (code != -1) {
                                final int v = featureMap.addIndex(j, code);
                                if (v != -1) {
                                    featureList.add(v,1);
                                }
                            }
                        } else {
                            final int index = featureMap.addIndex(j, Integer.parseInt(item.substring(0,colon)));
                            final String rawValue = item.substring(colon+1);
                            final double value;
                            if (rawValue.indexOf('.') != -1) {
                                value = Double.parseDouble(rawValue);
                            } else {
                                value = Integer.parseInt(rawValue);
                            }
                            featureList.add(index,value);
                        }
                    } catch (NumberFormatException e) {
                        throw new LibException("The instance file contain a non-numeric value '"+item+"'", e);
                    }
                }
            }
        } catch (ArrayIndexOutOfBoundsException e) {
            throw new LibException("Couln't read from the instance file. ", e);
        }
        return y;
    }

    /**
     * Converts an instance file into the sparse {@code class index:value ...} text format,
     * one instance per line. Lines whose class is {@code -1} are skipped. Both streams are
     * closed when the method returns, whether it succeeds or fails.
     *
     * @param isr reader of the instance file
     * @param osw writer receiving the converted instances
     * @throws MaltChainedException if the instance file is malformed or an I/O error occurs
     */
    protected void binariesInstances2SVMFileFormat(InputStreamReader isr, OutputStreamWriter osw) throws MaltChainedException {
        try {
            final BufferedReader in = new BufferedReader(isr);
            final BufferedWriter out = new BufferedWriter(osw);
            try {
                final FeatureList featureSet = new FeatureList();
                String line;
                while ((line = in.readLine()) != null) {
                    final int y = binariesInstance(line, featureSet);
                    if (y == -1) {
                        continue;
                    }
                    out.write(Integer.toString(y));

                    for (int k = 0; k < featureSet.size(); k++) {
                        final MaltFeatureNode x = featureSet.get(k);
                        out.write(' ');
                        out.write(Integer.toString(x.getIndex()));
                        out.write(':');
                        out.write(Double.toString(x.getValue()));
                    }
                    out.write('\n');
                }
            } finally {
                // Close both ends on every path; the writer is closed even if closing the reader fails.
                try {
                    in.close();
                } finally {
                    out.close();
                }
            }
        } catch (NumberFormatException e) {
            throw new LibException("The instance file contain a non-numeric value", e);
        } catch (IOException e) {
            throw new LibException("Couln't read from the instance file, when converting the Malt instances into LIBSV/LIBLINEAR format. ", e);
        }
    }

    /**
     * Serializes the feature map to the given stream and closes the stream.
     *
     * @param os the stream to write to; it is always closed by this method
     * @param map the feature map to save
     * @throws MaltChainedException if the map cannot be written
     */
    protected void saveFeatureMap(OutputStream os, FeatureMap map) throws MaltChainedException {
        try {
            try {
                final ObjectOutputStream output = new ObjectOutputStream(os);
                output.writeObject(map);
                output.flush();
            } finally {
                // Closing the underlying stream also covers a failure while creating the object stream.
                os.close();
            }
        } catch (IOException e) {
            throw new LibException("Save feature map error", e);
        }
    }

    /**
     * Deserializes a feature map from the given stream and closes the stream.
     *
     * <p>NOTE(review): this uses Java native deserialization; the stream must come from a
     * trusted model file.
     *
     * @param is the stream to read from; it is always closed by this method
     * @return the loaded feature map
     * @throws MaltChainedException if the map cannot be read
     */
    protected FeatureMap loadFeatureMap(InputStream is) throws MaltChainedException {
        try {
            try {
                final ObjectInputStream input = new ObjectInputStream(is);
                return (FeatureMap)input.readObject();
            } finally {
                // Closing the underlying stream also covers a failure while creating the object stream.
                is.close();
            }
        } catch (ClassNotFoundException e) {
            throw new LibException("Load feature map error", e);
        } catch (IOException e) {
            throw new LibException("Load feature map error", e);
        }
    }
}