001package org.maltparser.parser.algorithm.stack;
002
003import java.util.ArrayList;
004import java.util.Stack;
005
006import org.maltparser.core.exception.MaltChainedException;
007import org.maltparser.core.syntaxgraph.DependencyStructure;
008import org.maltparser.core.syntaxgraph.node.DependencyNode;
009import org.maltparser.parser.DependencyParserConfig;
010import org.maltparser.parser.Oracle;
011import org.maltparser.parser.ParserConfiguration;
012import org.maltparser.parser.history.GuideUserHistory;
013import org.maltparser.parser.history.action.GuideUserAction;
014/**
015 * @author Johan Hall
016 *
017 */
018public class SwapEagerOracle extends Oracle {
019        private ArrayList<Integer> swapArray;
020        private boolean swapArrayActive = false;
021        
022        public SwapEagerOracle(DependencyParserConfig manager, GuideUserHistory history) throws MaltChainedException {
023                super(manager, history);
024                setGuideName("swapeager");
025                swapArray = new ArrayList<Integer>();
026        }
027        
028        public GuideUserAction predict(DependencyStructure gold, ParserConfiguration configuration) throws MaltChainedException {
029                StackConfig config = (StackConfig)configuration;
030                Stack<DependencyNode> stack = config.getStack();
031
032                if (!swapArrayActive) {
033                        createSwapArray(gold);
034                        swapArrayActive = true;
035                }
036                GuideUserAction action = null;
037                if (stack.size() < 2) {
038                        action = updateActionContainers(NonProjective.SHIFT, null);
039                } else {
040                        DependencyNode left = stack.get(stack.size()-2);
041                        int leftIndex = left.getIndex();
042                        int rightIndex = stack.get(stack.size()-1).getIndex();
043                        if (swapArray.get(leftIndex) > swapArray.get(rightIndex)) {
044                                action =  updateActionContainers(NonProjective.SWAP, null);
045                        } else if (!left.isRoot() && gold.getTokenNode(leftIndex).getHead().getIndex() == rightIndex
046                                        && nodeComplete(gold, config.getDependencyGraph(), leftIndex)) {
047                                action = updateActionContainers(NonProjective.LEFTARC, gold.getTokenNode(leftIndex).getHeadEdge().getLabelSet());
048                        } else if (gold.getTokenNode(rightIndex).getHead().getIndex() == leftIndex
049                                        && nodeComplete(gold, config.getDependencyGraph(), rightIndex)) {
050                                action = updateActionContainers(NonProjective.RIGHTARC, gold.getTokenNode(rightIndex).getHeadEdge().getLabelSet());
051                        } else {
052                                action = updateActionContainers(NonProjective.SHIFT, null);
053                        }
054                }
055                return action;
056        }
057        
058        private boolean nodeComplete(DependencyStructure gold, DependencyStructure parseDependencyGraph, int nodeIndex) {
059                if (gold.getTokenNode(nodeIndex).hasLeftDependent()) {
060                        if (!parseDependencyGraph.getTokenNode(nodeIndex).hasLeftDependent()) {
061                                return false;
062                        } else if (gold.getTokenNode(nodeIndex).getLeftmostDependent().getIndex() != parseDependencyGraph.getTokenNode(nodeIndex).getLeftmostDependent().getIndex()) {
063                                return false;
064                        }
065                }
066                if (gold.getTokenNode(nodeIndex).hasRightDependent()) {
067                        if (!parseDependencyGraph.getTokenNode(nodeIndex).hasRightDependent()) {
068                                return false;
069                        } else if (gold.getTokenNode(nodeIndex).getRightmostDependent().getIndex() != parseDependencyGraph.getTokenNode(nodeIndex).getRightmostDependent().getIndex()) {
070                                return false;
071                        }
072                }
073                return true;
074        }
075        
076//      private boolean checkRightDependent(DependencyStructure gold, DependencyStructure parseDependencyGraph, int index) throws MaltChainedException {
077//              if (gold.getTokenNode(index).getRightmostDependent() == null) {
078//                      return true;
079//              } else if (parseDependencyGraph.getTokenNode(index).getRightmostDependent() != null) {
080//                      if (gold.getTokenNode(index).getRightmostDependent().getIndex() == parseDependencyGraph.getTokenNode(index).getRightmostDependent().getIndex()) {
081//                              return true;
082//                      }
083//              }
084//              return false;
085//      }
086        
087        public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
088                swapArrayActive = false;
089        }
090        
091        public void terminate() throws MaltChainedException {
092        }
093        
094        private void createSwapArray(DependencyStructure goldDependencyGraph) throws MaltChainedException {
095                swapArray.clear();
096                for (int i = 0; i <= goldDependencyGraph.getHighestDependencyNodeIndex(); i++) {
097                        swapArray.add(new Integer(i));
098                }
099                createSwapArray(goldDependencyGraph.getDependencyRoot(), 0);
100        }
101        
102        private int createSwapArray(DependencyNode n, int order) {
103                int o = order; 
104                if (n != null) {
105                        for (int i=0; i < n.getLeftDependentCount(); i++) {
106                                o = createSwapArray(n.getLeftDependent(i), o);
107                        }
108                        swapArray.set(n.getIndex(), o++);
109                        for (int i=n.getRightDependentCount(); i >= 0; i--) {
110                                o = createSwapArray(n.getRightDependent(i), o);
111                        }
112                }
113                return o;
114        }
115}