001 package org.maltparser.ml.lib; 002 003 import java.io.BufferedOutputStream; 004 import java.io.BufferedReader; 005 import java.io.BufferedWriter; 006 import java.io.File; 007 import java.io.FileNotFoundException; 008 import java.io.FileOutputStream; 009 import java.io.IOException; 010 import java.io.InputStream; 011 import java.io.InputStreamReader; 012 import java.io.ObjectInputStream; 013 import java.io.ObjectOutputStream; 014 import java.io.OutputStream; 015 016 import java.io.OutputStreamWriter; 017 import java.util.ArrayList; 018 019 import org.apache.log4j.Logger; 020 021 import java.util.LinkedHashMap; 022 import java.util.Set; 023 import java.util.jar.JarEntry; 024 import java.util.regex.Pattern; 025 import java.util.regex.PatternSyntaxException; 026 027 028 import org.maltparser.core.exception.MaltChainedException; 029 import org.maltparser.core.feature.FeatureVector; 030 import org.maltparser.core.feature.function.FeatureFunction; 031 import org.maltparser.core.feature.value.FeatureValue; 032 import org.maltparser.core.feature.value.MultipleFeatureValue; 033 import org.maltparser.core.feature.value.SingleFeatureValue; 034 import org.maltparser.core.syntaxgraph.DependencyStructure; 035 import org.maltparser.ml.LearningMethod; 036 import org.maltparser.parser.DependencyParserConfig; 037 import org.maltparser.parser.guide.instance.InstanceModel; 038 import org.maltparser.parser.history.action.SingleDecision; 039 040 public abstract class Lib implements LearningMethod { 041 protected Verbostity verbosity; 042 public enum Verbostity { 043 SILENT, ERROR, ALL 044 } 045 protected InstanceModel owner; 046 protected int learnerMode; 047 protected String name; 048 protected int numberOfInstances; 049 protected boolean saveInstanceFiles; 050 protected boolean excludeNullValues; 051 protected BufferedWriter instanceOutput = null; 052 protected FeatureMap featureMap; 053 protected String paramString; 054 protected String pathExternalTrain; 055 protected LinkedHashMap<String, String> libOptions; 056 protected String allowedLibOptionFlags; 057 protected Logger configLogger; 058 protected final Pattern tabPattern = Pattern.compile("\t"); 059 protected final Pattern pipePattern = Pattern.compile("\\|"); 060 private final StringBuilder sb = new StringBuilder(); 061 protected MaltLibModel model = null; 062 /** 063 * Constructs a Lib learner. 064 * 065 * @param owner the guide model owner 066 * @param learnerMode the mode of the learner BATCH or CLASSIFY 067 */ 068 public Lib(InstanceModel owner, Integer learnerMode, String learningMethodName) throws MaltChainedException { 069 setOwner(owner); 070 setLearnerMode(learnerMode.intValue()); 071 setNumberOfInstances(0); 072 setLearningMethodName(learningMethodName); 073 verbosity = Verbostity.SILENT; 074 configLogger = owner.getGuide().getConfiguration().getConfigLogger(); 075 initLibOptions(); 076 initAllowedLibOptionFlags(); 077 parseParameters(getConfiguration().getOptionValue("lib", "options").toString()); 078 initSpecialParameters(); 079 080 if (learnerMode == BATCH) { 081 featureMap = new FeatureMap(); 082 instanceOutput = new BufferedWriter(getInstanceOutputStreamWriter(".ins")); 083 } else if (learnerMode == CLASSIFY) { 084 featureMap = loadFeatureMap(getInputStreamFromConfigFileEntry(".map")); 085 } 086 } 087 088 089 public void addInstance(SingleDecision decision, FeatureVector featureVector) throws MaltChainedException { 090 if (featureVector == null) { 091 throw new LibException("The feature vector cannot be found"); 092 } else if (decision == null) { 093 throw new LibException("The decision cannot be found"); 094 } 095 096 try { 097 sb.append(decision.getDecisionCode()+"\t"); 098 final int n = featureVector.size(); 099 for (int i = 0; i < n; i++) { 100 FeatureValue featureValue = featureVector.getFeatureValue(i); 101 if (featureValue == null || (excludeNullValues == true && featureValue.isNullValue())) { 102 sb.append("-1"); 103 } else { 104 if (!featureValue.isMultiple()) { 105 SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue; 106 if (singleFeatureValue.getValue() == 1) { 107 sb.append(singleFeatureValue.getIndexCode()); 108 } else if (singleFeatureValue.getValue() == 0) { 109 sb.append("-1"); 110 } else { 111 sb.append(singleFeatureValue.getIndexCode()); 112 sb.append(":"); 113 sb.append(singleFeatureValue.getValue()); 114 } 115 } else { //if (featureValue instanceof MultipleFeatureValue) { 116 Set<Integer> values = ((MultipleFeatureValue)featureValue).getCodes(); 117 int j=0; 118 for (Integer value : values) { 119 sb.append(value.toString()); 120 if (j != values.size()-1) { 121 sb.append("|"); 122 } 123 j++; 124 } 125 } 126 // else { 127 // throw new LibException("Don't recognize the type of feature value: "+featureValue.getClass()); 128 // } 129 } 130 sb.append('\t'); 131 } 132 sb.append('\n'); 133 instanceOutput.write(sb.toString()); 134 instanceOutput.flush(); 135 increaseNumberOfInstances(); 136 sb.setLength(0); 137 } catch (IOException e) { 138 throw new LibException("The learner cannot write to the instance file. ", e); 139 } 140 } 141 142 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { } 143 144 public void moveAllInstances(LearningMethod method, 145 FeatureFunction divideFeature, 146 ArrayList<Integer> divideFeatureIndexVector) 147 throws MaltChainedException { 148 if (method == null) { 149 throw new LibException("The learning method cannot be found. "); 150 } else if (divideFeature == null) { 151 throw new LibException("The divide feature cannot be found. "); 152 } 153 154 try { 155 final BufferedReader in = new BufferedReader(getInstanceInputStreamReader(".ins")); 156 final BufferedWriter out = method.getInstanceWriter(); 157 final StringBuilder sb = new StringBuilder(6); 158 int l = in.read(); 159 char c; 160 int j = 0; 161 162 while(true) { 163 if (l == -1) { 164 sb.setLength(0); 165 break; 166 } 167 c = (char)l; 168 l = in.read(); 169 if (c == '\t') { 170 if (divideFeatureIndexVector.contains(j-1)) { 171 out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode())); 172 out.write('\t'); 173 } 174 out.write(sb.toString()); 175 j++; 176 out.write('\t'); 177 sb.setLength(0); 178 } else if (c == '\n') { 179 out.write(sb.toString()); 180 if (divideFeatureIndexVector.contains(j-1)) { 181 out.write('\t'); 182 out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode())); 183 } 184 out.write('\n'); 185 sb.setLength(0); 186 method.increaseNumberOfInstances(); 187 this.decreaseNumberOfInstances(); 188 j = 0; 189 } else { 190 sb.append(c); 191 } 192 } 193 in.close(); 194 getFile(".ins").delete(); 195 out.flush(); 196 } catch (SecurityException e) { 197 throw new LibException("The learner cannot remove the instance file. ", e); 198 } catch (NullPointerException e) { 199 throw new LibException("The instance file cannot be found. ", e); 200 } catch (FileNotFoundException e) { 201 throw new LibException("The instance file cannot be found. ", e); 202 } catch (IOException e) { 203 throw new LibException("The learner read from the instance file. ", e); 204 } 205 } 206 207 public void noMoreInstances() throws MaltChainedException { 208 closeInstanceWriter(); 209 } 210 211 public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException { 212 // if (featureVector == null) { 213 // throw new LibException("The learner cannot predict the next class, because the feature vector cannot be found. "); 214 // } 215 final FeatureList featureList = new FeatureList(); 216 final int size = featureVector.size(); 217 for (int i = 1; i <= size; i++) { 218 final FeatureValue featureValue = featureVector.getFeatureValue(i-1); 219 if (featureValue != null && !(excludeNullValues == true && featureValue.isNullValue())) { 220 if (!featureValue.isMultiple()) { 221 SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue; 222 final int index = featureMap.getIndex(i, singleFeatureValue.getIndexCode()); 223 if (index != -1 && singleFeatureValue.getValue() != 0) { 224 featureList.add(index,singleFeatureValue.getValue()); 225 } 226 } 227 else { //if (featureValue instanceof MultipleFeatureValue) { 228 for (Integer value : ((MultipleFeatureValue)featureValue).getCodes()) { 229 final int v = featureMap.getIndex(i, value); 230 if (v != -1) { 231 featureList.add(v,1); 232 } 233 } 234 } 235 } 236 } 237 try { 238 decision.getKBestList().addList(model.predict(featureList.toArray())); 239 } catch (OutOfMemoryError e) { 240 throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e); 241 } 242 return true; 243 } 244 245 // protected abstract int[] prediction(FeatureList featureList) throws MaltChainedException; 246 247 public void train(FeatureVector featureVector) throws MaltChainedException { 248 if (featureVector == null) { 249 throw new LibException("The feature vector cannot be found. "); 250 } else if (owner == null) { 251 throw new LibException("The parent guide model cannot be found. "); 252 } 253 long startTime = System.currentTimeMillis(); 254 255 // if (configLogger.isInfoEnabled()) { 256 // configLogger.info("\nStart training\n"); 257 // } 258 if (pathExternalTrain != null) { 259 trainExternal(featureVector); 260 } else { 261 trainInternal(featureVector); 262 } 263 // long elapsed = System.currentTimeMillis() - startTime; 264 // if (configLogger.isInfoEnabled()) { 265 // configLogger.info("Time 1: " +new Formatter().format("%02d:%02d:%02d", elapsed/3600000, elapsed%3600000/60000, elapsed%60000/1000)+" ("+elapsed+" ms)\n"); 266 // } 267 try { 268 // if (configLogger.isInfoEnabled()) { 269 // configLogger.info("\nSaving feature map "+getFile(".map").getName()+"\n"); 270 // } 271 saveFeatureMap(new BufferedOutputStream(new FileOutputStream(getFile(".map").getAbsolutePath())), featureMap); 272 } catch (FileNotFoundException e) { 273 throw new LibException("The learner cannot save the feature map file '"+getFile(".map").getAbsolutePath()+"'. ", e); 274 } 275 // elapsed = System.currentTimeMillis() - startTime; 276 // if (configLogger.isInfoEnabled()) { 277 // configLogger.info("Time 2: " +new Formatter().format("%02d:%02d:%02d", elapsed/3600000, elapsed%3600000/60000, elapsed%60000/1000)+" ("+elapsed+" ms)\n"); 278 // } 279 } 280 protected abstract void trainExternal(FeatureVector featureVector) throws MaltChainedException; 281 protected abstract void trainInternal(FeatureVector featureVector) throws MaltChainedException; 282 283 public void terminate() throws MaltChainedException { 284 closeInstanceWriter(); 285 owner = null; 286 model = null; 287 } 288 289 public BufferedWriter getInstanceWriter() { 290 return instanceOutput; 291 } 292 293 protected void closeInstanceWriter() throws MaltChainedException { 294 try { 295 if (instanceOutput != null) { 296 instanceOutput.flush(); 297 instanceOutput.close(); 298 instanceOutput = null; 299 } 300 } catch (IOException e) { 301 throw new LibException("The learner cannot close the instance file. ", e); 302 } 303 } 304 305 306 /** 307 * Returns the parameter string used for configure the learner 308 * 309 * @return the parameter string used for configure the learner 310 */ 311 public String getParamString() { 312 return paramString; 313 } 314 315 public InstanceModel getOwner() { 316 return owner; 317 } 318 319 protected void setOwner(InstanceModel owner) { 320 this.owner = owner; 321 } 322 323 public int getLearnerMode() { 324 return learnerMode; 325 } 326 327 public void setLearnerMode(int learnerMode) throws MaltChainedException { 328 this.learnerMode = learnerMode; 329 } 330 331 public String getLearningMethodName() { 332 return name; 333 } 334 335 /** 336 * Returns the current configuration 337 * 338 * @return the current configuration 339 * @throws MaltChainedException 340 */ 341 public DependencyParserConfig getConfiguration() throws MaltChainedException { 342 return owner.getGuide().getConfiguration(); 343 } 344 345 public int getNumberOfInstances() throws MaltChainedException { 346 if(numberOfInstances!=0) 347 return numberOfInstances; 348 else{ 349 BufferedReader reader = new BufferedReader( getInstanceInputStreamReader(".ins")); 350 try { 351 while(reader.readLine()!=null){ 352 numberOfInstances++; 353 owner.increaseFrequency(); 354 } 355 reader.close(); 356 } catch (IOException e) { 357 throw new MaltChainedException("No instances found in file",e); 358 } 359 return numberOfInstances; 360 } 361 } 362 363 public void increaseNumberOfInstances() { 364 numberOfInstances++; 365 owner.increaseFrequency(); 366 } 367 368 public void decreaseNumberOfInstances() { 369 numberOfInstances--; 370 owner.decreaseFrequency(); 371 } 372 373 protected void setNumberOfInstances(int numberOfInstances) { 374 this.numberOfInstances = 0; 375 } 376 377 protected void setLearningMethodName(String name) { 378 this.name = name; 379 } 380 381 public String getPathExternalTrain() { 382 return pathExternalTrain; 383 } 384 385 386 public void setPathExternalTrain(String pathExternalTrain) { 387 this.pathExternalTrain = pathExternalTrain; 388 } 389 390 protected OutputStreamWriter getInstanceOutputStreamWriter(String suffix) throws MaltChainedException { 391 return getConfiguration().getConfigurationDir().getAppendOutputStreamWriter(owner.getModelName()+getLearningMethodName()+suffix); 392 } 393 394 protected InputStreamReader getInstanceInputStreamReader(String suffix) throws MaltChainedException { 395 return getConfiguration().getConfigurationDir().getInputStreamReader(owner.getModelName()+getLearningMethodName()+suffix); 396 } 397 398 protected InputStreamReader getInstanceInputStreamReaderFromConfigFile(String suffix) throws MaltChainedException { 399 return getConfiguration().getConfigurationDir().getInputStreamReaderFromConfigFile(owner.getModelName()+getLearningMethodName()+suffix); 400 } 401 402 protected InputStream getInputStreamFromConfigFileEntry(String suffix) throws MaltChainedException { 403 return getConfiguration().getConfigurationDir().getInputStreamFromConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix); 404 } 405 406 407 protected File getFile(String suffix) throws MaltChainedException { 408 return getConfiguration().getConfigurationDir().getFile(owner.getModelName()+getLearningMethodName()+suffix); 409 } 410 411 protected JarEntry getConfigFileEntry(String suffix) throws MaltChainedException { 412 return getConfiguration().getConfigurationDir().getConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix); 413 } 414 415 protected void initSpecialParameters() throws MaltChainedException { 416 if (getConfiguration().getOptionValue("singlemalt", "null_value") != null && getConfiguration().getOptionValue("singlemalt", "null_value").toString().equalsIgnoreCase("none")) { 417 excludeNullValues = true; 418 } else { 419 excludeNullValues = false; 420 } 421 saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue(); 422 if (!getConfiguration().getOptionValue("lib", "external").toString().equals("")) { 423 String path = getConfiguration().getOptionValue("lib", "external").toString(); 424 try { 425 if (!new File(path).exists()) { 426 throw new LibException("The path to the external trainer 'svm-train' is wrong."); 427 } 428 if (new File(path).isDirectory()) { 429 throw new LibException("The option --lib-external points to a directory, the path should point at the 'train' file or the 'train.exe' file in the libsvm or the liblinear package"); 430 } 431 if (!(path.endsWith("train") ||path.endsWith("train.exe"))) { 432 throw new LibException("The option --lib-external does not specify the path to 'train' file or the 'train.exe' file in the libsvm or the liblinear package. "); 433 } 434 setPathExternalTrain(path); 435 } catch (SecurityException e) { 436 throw new LibException("Access denied to the file specified by the option --lib-external. ", e); 437 } 438 } 439 if (getConfiguration().getOptionValue("lib", "verbosity") != null) { 440 verbosity = Verbostity.valueOf(getConfiguration().getOptionValue("lib", "verbosity").toString().toUpperCase()); 441 } 442 } 443 444 public String getLibOptions() { 445 final StringBuilder sb = new StringBuilder(); 446 for (String key : libOptions.keySet()) { 447 sb.append('-'); 448 sb.append(key); 449 sb.append(' '); 450 sb.append(libOptions.get(key)); 451 sb.append(' '); 452 } 453 return sb.toString(); 454 } 455 456 public String[] getLibParamStringArray() { 457 final ArrayList<String> params = new ArrayList<String>(); 458 459 for (String key : libOptions.keySet()) { 460 params.add("-"+key); params.add(libOptions.get(key)); 461 } 462 return params.toArray(new String[params.size()]); 463 } 464 465 public abstract void initLibOptions(); 466 public abstract void initAllowedLibOptionFlags(); 467 468 public void parseParameters(String paramstring) throws MaltChainedException { 469 if (paramstring == null) { 470 return; 471 } 472 final String[] argv; 473 try { 474 argv = paramstring.split("[_\\p{Blank}]"); 475 } catch (PatternSyntaxException e) { 476 throw new LibException("Could not split the parameter string '"+paramstring+"'. ", e); 477 } 478 for (int i=0; i < argv.length-1; i++) { 479 if(argv[i].charAt(0) != '-') { 480 throw new LibException("The argument flag should start with the following character '-', not with "+argv[i].charAt(0)); 481 } 482 if(++i>=argv.length) { 483 throw new LibException("The last argument does not have any value. "); 484 } 485 try { 486 int index = allowedLibOptionFlags.indexOf(argv[i-1].charAt(1)); 487 if (index != -1) { 488 libOptions.put(Character.toString(argv[i-1].charAt(1)), argv[i]); 489 } else { 490 throw new LibException("Unknown learner parameter: '"+argv[i-1]+"' with value '"+argv[i]+"'. "); 491 } 492 } catch (ArrayIndexOutOfBoundsException e) { 493 throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e); 494 } catch (NumberFormatException e) { 495 throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e); 496 } catch (NullPointerException e) { 497 throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e); 498 } 499 } 500 } 501 502 protected void finalize() throws Throwable { 503 try { 504 closeInstanceWriter(); 505 } finally { 506 super.finalize(); 507 } 508 } 509 510 public String toString() { 511 final StringBuffer sb = new StringBuffer(); 512 sb.append("\n"+getLearningMethodName()+" INTERFACE\n"); 513 sb.append(getLibOptions()); 514 return sb.toString(); 515 } 516 517 protected int binariesInstance(String line, FeatureList featureList) throws MaltChainedException { 518 int y = -1; 519 featureList.clear(); 520 try { 521 String[] columns = tabPattern.split(line); 522 523 if (columns.length == 0) { 524 return -1; 525 } 526 try { 527 y = Integer.parseInt(columns[0]); 528 } catch (NumberFormatException e) { 529 throw new LibException("The instance file contain a non-integer value '"+columns[0]+"'", e); 530 } 531 for(int j = 1; j < columns.length; j++) { 532 final String[] items = pipePattern.split(columns[j]); 533 for (int k = 0; k < items.length; k++) { 534 try { 535 int colon = items[k].indexOf(':'); 536 if (colon == -1) { 537 if (Integer.parseInt(items[k]) != -1) { 538 int v = featureMap.addIndex(j, Integer.parseInt(items[k])); 539 if (v != -1) { 540 featureList.add(v,1); 541 } 542 } 543 } else { 544 int index = featureMap.addIndex(j, Integer.parseInt(items[k].substring(0,colon))); 545 double value; 546 if (items[k].substring(colon+1).indexOf('.') != -1) { 547 value = Double.parseDouble(items[k].substring(colon+1)); 548 } else { 549 value = Integer.parseInt(items[k].substring(colon+1)); 550 } 551 featureList.add(index,value); 552 } 553 } catch (NumberFormatException e) { 554 throw new LibException("The instance file contain a non-numeric value '"+items[k]+"'", e); 555 } 556 } 557 } 558 } catch (ArrayIndexOutOfBoundsException e) { 559 throw new LibException("Couln't read from the instance file. ", e); 560 } 561 return y; 562 } 563 564 protected void binariesInstances2SVMFileFormat(InputStreamReader isr, OutputStreamWriter osw) throws MaltChainedException { 565 try { 566 final BufferedReader in = new BufferedReader(isr); 567 final BufferedWriter out = new BufferedWriter(osw); 568 final FeatureList featureSet = new FeatureList(); 569 while(true) { 570 String line = in.readLine(); 571 if(line == null) break; 572 int y = binariesInstance(line, featureSet); 573 if (y == -1) { 574 continue; 575 } 576 out.write(Integer.toString(y)); 577 578 for (int k=0; k < featureSet.size(); k++) { 579 MaltFeatureNode x = featureSet.get(k); 580 out.write(' '); 581 out.write(Integer.toString(x.getIndex())); 582 out.write(':'); 583 out.write(Double.toString(x.getValue())); 584 } 585 out.write('\n'); 586 } 587 in.close(); 588 out.close(); 589 } catch (NumberFormatException e) { 590 throw new LibException("The instance file contain a non-numeric value", e); 591 } catch (IOException e) { 592 throw new LibException("Couln't read from the instance file, when converting the Malt instances into LIBSV/LIBLINEAR format. ", e); 593 } 594 } 595 596 protected void saveFeatureMap(OutputStream os, FeatureMap map) throws MaltChainedException { 597 try { 598 ObjectOutputStream output = new ObjectOutputStream(os); 599 try{ 600 output.writeObject(map); 601 } 602 finally{ 603 output.close(); 604 } 605 } catch (IOException e) { 606 throw new LibException("Save feature map error", e); 607 } 608 } 609 610 protected FeatureMap loadFeatureMap(InputStream is) throws MaltChainedException { 611 FeatureMap map = new FeatureMap(); 612 try { 613 ObjectInputStream input = new ObjectInputStream(is); 614 try { 615 map = (FeatureMap)input.readObject(); 616 } finally { 617 input.close(); 618 } 619 } catch (ClassNotFoundException e) { 620 throw new LibException("Load feature map error", e); 621 } catch (IOException e) { 622 throw new LibException("Load feature map error", e); 623 } 624 return map; 625 } 626 }