001    package org.maltparser.parser.algorithm.nivre;
002    
003    import java.util.Stack;
004    
005    import org.maltparser.core.exception.MaltChainedException;
006    import org.maltparser.core.symbol.SymbolTable;
007    import org.maltparser.core.symbol.SymbolTableHandler;
008    import org.maltparser.core.syntaxgraph.DependencyGraph;
009    import org.maltparser.core.syntaxgraph.DependencyStructure;
010    import org.maltparser.core.syntaxgraph.edge.Edge;
011    import org.maltparser.core.syntaxgraph.node.DependencyNode;
012    import org.maltparser.parser.ParserConfiguration;
013    import org.maltparser.parser.ParsingException;
014    /**
015     * @author Johan Hall
016     *
017     */
018    public class NivreConfig extends ParserConfiguration {
019            private final Stack<DependencyNode> stack;
020            private final Stack<DependencyNode> input;
021            private final DependencyStructure dependencyGraph;
022    
023            private boolean allowRoot;
024            private boolean allowReduce;
025            
026            public NivreConfig(SymbolTableHandler symbolTableHandler, boolean allowRoot, boolean allowReduce) throws MaltChainedException {
027                    super();
028                    stack = new Stack<DependencyNode>();
029                    input = new Stack<DependencyNode>();
030                    dependencyGraph = new DependencyGraph(symbolTableHandler);
031                    setAllowRoot(allowRoot);
032                    setAllowReduce(allowReduce);
033            }
034            
035            public Stack<DependencyNode> getStack() {
036                    return stack;
037            }
038            
039            public Stack<DependencyNode> getInput() {
040                    return input;
041            }
042            
043            public DependencyStructure getDependencyStructure() {
044                    return dependencyGraph;
045            }
046            
047            public boolean isTerminalState() {
048                    return input.isEmpty();
049            }
050            
051            public DependencyNode getStackNode(int index) throws MaltChainedException {
052                    if (index < 0) {
053                            throw new ParsingException("Stack index must be non-negative in feature specification. ");
054                    }
055                    if (stack.size()-index > 0) {
056                            return stack.get(stack.size()-1-index);
057                    }
058                    return null;
059            }
060            
061            public DependencyNode getInputNode(int index) throws MaltChainedException {
062                    if (index < 0) {
063                            throw new ParsingException("Input index must be non-negative in feature specification. ");
064                    }
065                    if (input.size()-index > 0) {
066                            return input.get(input.size()-1-index);
067                    }       
068                    return null;
069            }
070            
071            public void setDependencyGraph(DependencyStructure source) throws MaltChainedException {
072                    dependencyGraph.clear();
073                    for (int index : source.getTokenIndices()) {
074                            final DependencyNode gnode = source.getTokenNode(index);
075                            final DependencyNode pnode = dependencyGraph.addTokenNode(gnode.getIndex());
076                            for (SymbolTable table : gnode.getLabelTypes()) {
077                                    pnode.addLabel(table, gnode.getLabelSymbol(table));
078                            }
079                            
080                            if (gnode.hasHead()) {
081                                    final Edge s = gnode.getHeadEdge();
082                                    final Edge t = dependencyGraph.addDependencyEdge(s.getSource().getIndex(), s.getTarget().getIndex());
083                                    
084                                    for (SymbolTable table : s.getLabelTypes()) {
085                                            t.addLabel(table, s.getLabelSymbol(table));
086                                    }
087                            }
088                    }
089                    for (SymbolTable table : source.getDefaultRootEdgeLabels().keySet()) {
090                            dependencyGraph.setDefaultRootEdgeLabel(table, source.getDefaultRootEdgeLabelSymbol(table));
091                    }
092            }
093            
094            public DependencyStructure getDependencyGraph() {
095                    return dependencyGraph;
096            }
097            
098            public void initialize(ParserConfiguration parserConfiguration) throws MaltChainedException {
099                    if (parserConfiguration != null) {
100                            final NivreConfig nivreConfig = (NivreConfig)parserConfiguration;
101                            final Stack<DependencyNode> sourceStack = nivreConfig.getStack();
102                            final Stack<DependencyNode> sourceInput = nivreConfig.getInput();
103                            setDependencyGraph(nivreConfig.getDependencyGraph());
104                            for (int i = 0, n = sourceStack.size(); i < n; i++) {
105                                    stack.add(dependencyGraph.getDependencyNode(sourceStack.get(i).getIndex()));
106                            }
107                            for (int i = 0, n = sourceInput.size(); i < n; i++) {
108                                    input.add(dependencyGraph.getDependencyNode(sourceInput.get(i).getIndex()));
109                            }
110                    } else {
111                            stack.push(dependencyGraph.getDependencyRoot());
112                            for (int i = dependencyGraph.getHighestTokenIndex(); i > 0; i--) {
113                                    final DependencyNode node = dependencyGraph.getDependencyNode(i);
114                                    if (node != null && !node.hasHead()) { // added !node.hasHead()
115                                            input.push(node);
116                                    }
117                            }
118                    }
119            }
120            
121        public boolean isAllowRoot() {
122            return allowRoot;
123            }
124            
125            public void setAllowRoot(boolean allowRoot) {
126                    this.allowRoot = allowRoot;
127            }
128            
129            public boolean isAllowReduce() {
130                    return allowReduce;
131            }
132            
133            public void setAllowReduce(boolean allowReduce) {
134                    this.allowReduce = allowReduce;
135            }
136            
137            public void clear() throws MaltChainedException {
138                    dependencyGraph.clear();
139                    stack.clear();
140                    input.clear();
141                    historyNode = null;
142            }
143            
144            public boolean equals(Object obj) {
145                    if (this == obj)
146                            return true;
147                    if (obj == null)
148                            return false;
149                    if (getClass() != obj.getClass())
150                            return false;
151                    NivreConfig that = (NivreConfig)obj;
152                    
153                    if (stack.size() != that.getStack().size()) 
154                            return false;
155                    if (input.size() != that.getInput().size())
156                            return false;
157                    if (dependencyGraph.nEdges() != that.getDependencyGraph().nEdges())
158                            return false;
159                    for (int i = 0; i < stack.size(); i++) {
160                            if (stack.get(i).getIndex() != that.getStack().get(i).getIndex()) {
161                                    return false;
162                            }
163                    }
164                    for (int i = 0; i < input.size(); i++) {
165                            if (input.get(i).getIndex() != that.getInput().get(i).getIndex()) {
166                                    return false;
167                            }
168                    }               
169                    return dependencyGraph.getEdges().equals(that.getDependencyGraph().getEdges());
170            }
171            
172            public String toString() {
173                    final StringBuilder sb = new StringBuilder();
174                    sb.append(stack.size());
175                    sb.append(", ");
176                    sb.append(input.size());
177                    sb.append(", ");
178                    sb.append(dependencyGraph.nEdges());
179                    return sb.toString();
180            }
181    }