001 package org.maltparser.parser; 002 003 import org.maltparser.core.exception.MaltChainedException; 004 import org.maltparser.core.syntaxgraph.DependencyStructure; 005 import org.maltparser.parser.guide.ClassifierGuide; 006 import org.maltparser.parser.guide.OracleGuide; 007 import org.maltparser.parser.guide.SingleGuide; 008 import org.maltparser.parser.history.GuideHistory; 009 import org.maltparser.parser.history.action.GuideDecision; 010 import org.maltparser.parser.history.action.GuideUserAction; 011 /** 012 * @author Johan Hall 013 * 014 */ 015 public class BatchTrainer extends Trainer { 016 private final OracleGuide oracleGuide; 017 private int parseCount; 018 019 public BatchTrainer(DependencyParserConfig manager) throws MaltChainedException { 020 super(manager); 021 ((SingleMalt)manager).addRegistry(org.maltparser.parser.Algorithm.class, this); 022 setManager(manager); 023 initParserState(1); 024 setGuide(new SingleGuide(manager, (GuideHistory)parserState.getHistory(), ClassifierGuide.GuideMode.BATCH)); 025 oracleGuide = parserState.getFactory().makeOracleGuide(parserState.getHistory()); 026 } 027 028 public DependencyStructure parse(DependencyStructure goldDependencyGraph, DependencyStructure parseDependencyGraph) throws MaltChainedException { 029 parserState.clear(); 030 parserState.initialize(parseDependencyGraph); 031 currentParserConfiguration = parserState.getConfiguration(); 032 parseCount++; 033 if (diagnostics == true) { 034 writeToDiaFile(parseCount + ""); 035 } 036 TransitionSystem transitionSystem = parserState.getTransitionSystem(); 037 while (!parserState.isTerminalState()) { 038 GuideUserAction action = transitionSystem.getDeterministicAction(parserState.getHistory(), currentParserConfiguration); 039 if (action == null) { 040 action = oracleGuide.predict(goldDependencyGraph, currentParserConfiguration); 041 try { 042 classifierGuide.addInstance((GuideDecision)action); 043 } catch (NullPointerException e) { 044 throw new MaltChainedException("The guide cannot be found. ", e); 045 } 046 } else if (diagnostics == true) { 047 writeToDiaFile(" *"); 048 } 049 if (diagnostics == true) { 050 writeToDiaFile(" " + transitionSystem.getActionString(action)); 051 } 052 parserState.apply(action); 053 } 054 copyEdges(currentParserConfiguration.getDependencyGraph(), parseDependencyGraph); 055 parseDependencyGraph.linkAllTreesToRoot(); 056 oracleGuide.finalizeSentence(parseDependencyGraph); 057 if (diagnostics == true) { 058 writeToDiaFile("\n"); 059 } 060 return parseDependencyGraph; 061 } 062 063 public OracleGuide getOracleGuide() { 064 return oracleGuide; 065 } 066 067 public void train() throws MaltChainedException { } 068 public void terminate() throws MaltChainedException { 069 if (diagnostics == true) { 070 closeDiaWriter(); 071 } 072 } 073 }