001package org.maltparser.parser;
002
003import org.maltparser.core.exception.MaltChainedException;
004import org.maltparser.core.feature.FeatureModel;
005import org.maltparser.core.symbol.SymbolTableHandler;
006import org.maltparser.core.syntaxgraph.DependencyStructure;
007import org.maltparser.parser.guide.ClassifierGuide;
008import org.maltparser.parser.guide.OracleGuide;
009import org.maltparser.parser.guide.SingleGuide;
010import org.maltparser.parser.history.action.GuideDecision;
011import org.maltparser.parser.history.action.GuideUserAction;
012
013public class BatchTrainerWithDiagnostics extends Trainer {
014        private final Diagnostics diagnostics;
015        private final OracleGuide oracleGuide;
016        private int parseCount;
017        private final FeatureModel featureModel;
018        
019        public BatchTrainerWithDiagnostics(DependencyParserConfig manager, SymbolTableHandler symbolTableHandler) throws MaltChainedException {
020                super(manager,symbolTableHandler);
021                this.diagnostics = new Diagnostics(manager.getOptionValue("singlemalt", "diafile").toString());
022                registry.setAlgorithm(this);
023                setGuide(new SingleGuide(this,  ClassifierGuide.GuideMode.BATCH));
024                String featureModelFileName = manager.getOptionValue("guide", "features").toString().trim();
025                if (manager.isLoggerInfoEnabled()) {
026                        manager.logDebugMessage("  Feature model        : " + featureModelFileName+"\n");
027                        manager.logDebugMessage("  Learner              : " + manager.getOptionValueString("guide", "learner").toString()+"\n");
028                }
029                String dataSplitColumn = manager.getOptionValue("guide", "data_split_column").toString().trim();
030                String dataSplitStructure = manager.getOptionValue("guide", "data_split_structure").toString().trim();
031                this.featureModel = manager.getFeatureModelManager().getFeatureModel(SingleGuide.findURL(featureModelFileName, manager), 0, getParserRegistry(), dataSplitColumn, dataSplitStructure);
032
033                manager.writeInfoToConfigFile("\nFEATURE MODEL\n");
034                manager.writeInfoToConfigFile(featureModel.toString());
035                oracleGuide = parserState.getFactory().makeOracleGuide(parserState.getHistory());
036        }
037        
038        public DependencyStructure parse(DependencyStructure goldDependencyGraph, DependencyStructure parseDependencyGraph) throws MaltChainedException {
039                parserState.clear();
040                parserState.initialize(parseDependencyGraph);
041                currentParserConfiguration = parserState.getConfiguration();
042                parseCount++;
043
044                diagnostics.writeToDiaFile(parseCount + "");
045
046                TransitionSystem transitionSystem = parserState.getTransitionSystem();
047                while (!parserState.isTerminalState()) {
048                        GuideUserAction action = transitionSystem.getDeterministicAction(parserState.getHistory(), currentParserConfiguration);
049                        if (action == null) {
050                                action = oracleGuide.predict(goldDependencyGraph, currentParserConfiguration);
051                                try {
052                                        classifierGuide.addInstance(featureModel,(GuideDecision)action);
053                                } catch (NullPointerException e) {
054                                        throw new MaltChainedException("The guide cannot be found. ", e);
055                                }
056                        } else {
057                                diagnostics.writeToDiaFile(" *");
058                        }
059
060                        diagnostics.writeToDiaFile(" " + transitionSystem.getActionString(action));
061
062                        parserState.apply(action);
063                }
064                copyEdges(currentParserConfiguration.getDependencyGraph(), parseDependencyGraph);
065                parseDependencyGraph.linkAllTreesToRoot();
066                oracleGuide.finalizeSentence(parseDependencyGraph);
067
068                diagnostics.writeToDiaFile("\n");
069
070                return parseDependencyGraph;
071        }
072        
073        public OracleGuide getOracleGuide() {
074                return oracleGuide;
075        }
076        
077        public void train() throws MaltChainedException { }
078        public void terminate() throws MaltChainedException {
079                diagnostics.closeDiaWriter();
080        }
081}