001package org.maltparser.core.lw.graph;
002
003import java.util.ArrayList;
004import java.util.Collections;
005import java.util.List;
006import java.util.SortedMap;
007import java.util.SortedSet;
008import java.util.TreeSet;
009
010
011import org.maltparser.concurrent.graph.dataformat.ColumnDescription;
012import org.maltparser.concurrent.graph.dataformat.DataFormat;
013import org.maltparser.core.exception.MaltChainedException;
014import org.maltparser.core.symbol.SymbolTable;
015import org.maltparser.core.symbol.SymbolTableHandler;
016import org.maltparser.core.syntaxgraph.DependencyStructure;
017import org.maltparser.core.syntaxgraph.Element;
018import org.maltparser.core.syntaxgraph.LabelSet;
019import org.maltparser.core.syntaxgraph.RootLabels;
020import org.maltparser.core.syntaxgraph.edge.Edge;
021import org.maltparser.core.syntaxgraph.node.ComparableNode;
022import org.maltparser.core.syntaxgraph.node.DependencyNode;
023import org.maltparser.core.syntaxgraph.node.TokenNode;
024
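/**
* A lightweight version of org.maltparser.core.syntaxgraph.DependencyGraph.
* <p>
* Nodes are kept in a list indexed by node position. Construction adds an
* artificial ROOT node at index 0, and token nodes follow from index 1.
* Several operations of the full DependencyStructure interface are not supported.
* Most of these throw an LWGraphException; a few are no-ops or return null.
*
* @author Johan Hall
*/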
030public final class LWDependencyGraph implements DependencyStructure {
031        private static final String TAB_SIGN = "\t";
032        
033        private final DataFormat dataFormat;
034        private final SymbolTableHandler symbolTables;
035        private final RootLabels rootLabels;
036        private final List<LWNode> nodes;
037        
038        public LWDependencyGraph(DataFormat _dataFormat, SymbolTableHandler _symbolTables) throws MaltChainedException {
039                this.dataFormat = _dataFormat;
040                this.symbolTables = _symbolTables;
041                this.rootLabels = new RootLabels();
042                this.nodes = new ArrayList<LWNode>();
043                this.nodes.add(new LWNode(this, 0)); // ROOT
044        }
045        
046        public LWDependencyGraph(DataFormat _dataFormat, SymbolTableHandler _symbolTables, String[] inputTokens, String defaultRootLabel) throws MaltChainedException {
047                this.dataFormat = _dataFormat;
048                this.symbolTables = _symbolTables;
049                this.rootLabels = new RootLabels();
050                this.nodes = new ArrayList<LWNode>(inputTokens.length+1);
051                
052                // Add nodes
053                nodes.add(new LWNode(this, 0)); // ROOT
054                for (int i = 0; i < inputTokens.length; i++) {
055                        nodes.add(new LWNode(this, i+1));
056                }
057                
058                for (int i = 0; i < inputTokens.length; i++) {
059                        nodes.get(i+1).addColumnLabels(inputTokens[i].split(TAB_SIGN));
060                }
061                // Check graph
062                for (int i = 0; i < nodes.size(); i++) {
063                        if (nodes.get(i).getHeadIndex() >= nodes.size()) {
064                                throw new LWGraphException("Not allowed to add a head node that doesn't exists");
065                        }
066                }
067                
068                for (int i = 0; i < dataFormat.numberOfColumns(); i++) {
069                        ColumnDescription column = dataFormat.getColumnDescription(i);
070                        if (!column.isInternal() && column.getCategory() == ColumnDescription.DEPENDENCY_EDGE_LABEL) {
071                                rootLabels.setDefaultRootLabel(symbolTables.getSymbolTable(column.getName()), defaultRootLabel);
072                        }
073                }
074        }
075        
076        public DataFormat getDataFormat() {
077                return dataFormat;
078        }
079        
080        public LWNode getNode(int nodeIndex) {
081                if (nodeIndex < 0 || nodeIndex >= nodes.size()) {
082                        return null;
083                }
084                return nodes.get(nodeIndex);
085        }
086        
087        public int nNodes() {
088                return nodes.size();
089        }
090        
091        protected boolean hasDependent(int nodeIndex) {
092                for (int i = 1; i < nodes.size(); i++) {
093                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
094                                return true;
095                        }
096                }
097                return false;
098        }
099        
100        protected boolean hasLeftDependent(int nodeIndex) {
101                for (int i = 1; i < nodeIndex; i++) {
102                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
103                                return true;
104                        }
105                }
106                return false;
107        }
108        
109        protected boolean hasRightDependent(int nodeIndex) {
110                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
111                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
112                                return true;
113                        }
114                }
115                return false;
116        }
117        
118        protected List<DependencyNode> getListOfLeftDependents(int nodeIndex) {
119                List<DependencyNode> leftDependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
120                for (int i = 1; i < nodeIndex; i++) {
121                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
122                                leftDependents.add(nodes.get(i));
123                        }
124                }
125                return leftDependents;
126        }
127        
128        protected SortedSet<DependencyNode> getSortedSetOfLeftDependents(int nodeIndex) {
129                SortedSet<DependencyNode> leftDependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
130                for (int i = 1; i < nodeIndex; i++) {
131                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
132                                leftDependents.add(nodes.get(i));
133                        }
134                }
135                return leftDependents;
136        }
137        
138        protected List<DependencyNode> getListOfRightDependents(int nodeIndex) {
139                List<DependencyNode> rightDependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
140                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
141                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
142                                rightDependents.add(nodes.get(i));
143                        }
144                }
145                return rightDependents;
146        }
147        
148        protected SortedSet<DependencyNode> getSortedSetOfRightDependents(int nodeIndex) {
149                SortedSet<DependencyNode> rightDependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
150                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
151                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
152                                rightDependents.add(nodes.get(i));
153                        }
154                }
155                return rightDependents;
156        }
157        
158        protected List<DependencyNode> getListOfDependents(int nodeIndex) {
159                List<DependencyNode> dependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
160                for (int i = 1; i < nodes.size(); i++) {
161                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
162                                dependents.add(nodes.get(i));
163                        }
164                }
165                return dependents;
166        }
167        
168        protected SortedSet<DependencyNode> getSortedSetOfDependents(int nodeIndex) {
169                SortedSet<DependencyNode> dependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
170                for (int i = 1; i < nodes.size(); i++) {
171                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
172                                dependents.add(nodes.get(i));
173                        }
174                }
175                return dependents;
176        }
177        
178        protected int getRank(int nodeIndex) {
179                int[] components = new int[nodes.size()];
180                int[] ranks = new int[nodes.size()];
181                for (int i = 0; i < components.length; i++) {
182                        components[i] = i;
183                        ranks[i] = 0;
184                }
185                for (int i = 1; i < nodes.size(); i++) {
186                        if (nodes.get(i).hasHead()) {
187                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
188                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
189                                if (hcIndex != dcIndex) {
190                                        link(hcIndex, dcIndex, components, ranks);              
191                                }
192                        }
193                }
194                return ranks[nodeIndex];
195        }
196        
197        protected DependencyNode findComponent(int nodeIndex) {
198                int[] components = new int[nodes.size()];
199                int[] ranks = new int[nodes.size()];
200                for (int i = 0; i < components.length; i++) {
201                        components[i] = i;
202                        ranks[i] = 0;
203                }
204                for (int i = 1; i < nodes.size(); i++) {
205                        if (nodes.get(i).hasHead()) {
206                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
207                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
208                                if (hcIndex != dcIndex) {
209                                        link(hcIndex, dcIndex, components, ranks);              
210                                }
211                        }
212                }
213                return nodes.get(findComponent(nodeIndex, components));
214        }
215        
216        private int[] findComponents() {
217                int[] components = new int[nodes.size()];
218                int[] ranks = new int[nodes.size()];
219                for (int i = 0; i < components.length; i++) {
220                        components[i] = i;
221                        ranks[i] = 0;
222                }
223                for (int i = 1; i < nodes.size(); i++) {
224                        if (nodes.get(i).hasHead()) {
225                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
226                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
227                                if (hcIndex != dcIndex) {
228                                        link(hcIndex, dcIndex, components, ranks);              
229                                }
230                        }
231                }
232                return components;
233        }
234        
235        private int findComponent(int xIndex, int[] components) {
236                if (xIndex != components[xIndex]) {
237                        components[xIndex] = findComponent(components[xIndex], components);
238                }
239                return components[xIndex];
240        }
241        
242        private int link(int xIndex, int yIndex, int[] components, int[] ranks) {
243                if (ranks[xIndex] > ranks[yIndex]) {  
244                        components[yIndex] = xIndex;
245                } else {
246                        components[xIndex] = yIndex;
247
248                        if (ranks[xIndex] == ranks[yIndex]) {
249                                ranks[yIndex]++;
250                        }
251                        return yIndex;
252                }
253                return xIndex;
254        }
255        
256        @Override
257        public TokenNode addTokenNode() throws MaltChainedException {
258                throw new LWGraphException("Not implemented in the light-weight dependency graph package");
259        }
260
261        @Override
262        public TokenNode addTokenNode(int index) throws MaltChainedException {
263                throw new LWGraphException("Not implemented in the light-weight dependency graph package");
264        }
265
266        @Override
267        public TokenNode getTokenNode(int index) {
268//              throw new LWGraphException("Not implemented in the light-weight dependency graph package");
269                return null;
270        }
271
272        @Override
273        public int nTokenNode() {
274                return nodes.size()-1;
275        }
276
277        @Override
278        public SortedSet<Integer> getTokenIndices() {
279                SortedSet<Integer> indices = Collections.synchronizedSortedSet(new TreeSet<Integer>());
280                for (int i = 1; i < nodes.size(); i++) {
281                        indices.add(i);
282                }
283                return indices;
284        }
285
286        @Override
287        public int getHighestTokenIndex() {
288                return nodes.size()-1;
289        }
290
291        @Override
292        public boolean hasTokens() {
293                return nodes.size() > 1;
294        }
295
296        @Override
297        public int getSentenceID() {
298                return 0;
299        }
300
301        @Override
302        public void setSentenceID(int sentenceID) {     }
303
304        @Override
305        public void clear() throws MaltChainedException {
306                nodes.clear();
307        }
308
309        @Override
310        public SymbolTableHandler getSymbolTables() {
311                return symbolTables;
312        }
313
314        @Override
315        public void setSymbolTables(SymbolTableHandler symbolTables) { }
316
317        @Override
318        public void addLabel(Element element, String labelFunction, String label) throws MaltChainedException {
319                element.addLabel(symbolTables.addSymbolTable(labelFunction), label);
320        }
321        
322        @Override
323        public LabelSet checkOutNewLabelSet() throws MaltChainedException {
324                throw new LWGraphException("Not implemented in light-weight dependency graph");
325        }
326
327        @Override
328        public void checkInLabelSet(LabelSet labelSet) throws MaltChainedException {
329                throw new LWGraphException("Not implemented in light-weight dependency graph");
330        }
331
332        @Override
333        public Edge addSecondaryEdge(ComparableNode source, ComparableNode target) throws MaltChainedException {
334                throw new LWGraphException("Not implemented in light-weight dependency graph");
335        }
336
337        @Override
338        public void removeSecondaryEdge(ComparableNode source, ComparableNode target)
339                        throws MaltChainedException {
340                throw new LWGraphException("Not implemented in light-weight dependency graph");
341        }
342
343        @Override
344        public DependencyNode addDependencyNode() throws MaltChainedException {
345                LWNode node = new LWNode(this, nodes.size());
346                nodes.add(node);
347                return node;
348        }
349
350        @Override
351        public DependencyNode addDependencyNode(int index) throws MaltChainedException {
352                if (index == 0) {
353                        return nodes.get(0);
354                } else if (index == nodes.size()) {
355                        return addDependencyNode();
356                }
357                throw new LWGraphException("Not implemented in light-weight dependency graph");
358        }
359
360        @Override
361        public DependencyNode getDependencyNode(int index) throws MaltChainedException {
362                if (index < 0 || index >= nodes.size()) {
363                        return null;
364                }
365                return nodes.get(index);
366        }
367
368        @Override
369        public int nDependencyNode() {
370                return nodes.size();
371        }
372
373        @Override
374        public int getHighestDependencyNodeIndex() {
375                return nodes.size()-1;
376        }
377
378        @Override
379        public Edge addDependencyEdge(int headIndex, int dependentIndex) throws MaltChainedException {
380                if (headIndex < 0 && headIndex >= nodes.size()) {
381                        throw new LWGraphException("The head doesn't exists");
382                }
383                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
384                        throw new LWGraphException("The dependent doesn't exists");
385                }
386                LWNode head = nodes.get(headIndex);
387                LWNode dependent = nodes.get(dependentIndex);
388                Edge headEdge = new LWEdge(head, dependent);
389                dependent.addIncomingEdge(headEdge);
390                return headEdge;
391        }
392
393        @Override
394        public Edge moveDependencyEdge(int newHeadIndex, int dependentIndex) throws MaltChainedException {
395                if (newHeadIndex < 0 && newHeadIndex >= nodes.size()) {
396                        throw new LWGraphException("The head doesn't exists");
397                }
398                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
399                        throw new LWGraphException("The dependent doesn't exists");
400                }
401
402                LWNode head = nodes.get(newHeadIndex);
403                LWNode dependent = nodes.get(dependentIndex);
404                Edge oldheadEdge = dependent.getHeadEdge();
405                Edge headEdge = new LWEdge(head, dependent);
406                headEdge.addLabel(oldheadEdge.getLabelSet());
407                dependent.addIncomingEdge(headEdge);
408                return headEdge;
409        }
410
411        @Override
412        public void removeDependencyEdge(int headIndex, int dependentIndex) throws MaltChainedException {
413                if (headIndex < 0 && headIndex >= nodes.size()) {
414                        throw new LWGraphException("The head doesn't exists");
415                }
416                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
417                        throw new LWGraphException("The dependent doesn't exists");
418                }
419                LWNode head = nodes.get(headIndex);
420                LWNode dependent = nodes.get(dependentIndex);
421                Edge headEdge = new LWEdge(head, dependent);
422                dependent.removeIncomingEdge(headEdge);
423        }
424
425        @Override
426        public void linkAllTreesToRoot() throws MaltChainedException {
427                for (int i = 0; i < nodes.size(); i++) {
428                        if (!nodes.get(i).hasHead()) {
429                                LWNode head = nodes.get(0);
430                                LWNode dependent = nodes.get(i);
431                                Edge headEdge = new LWEdge(head, dependent);
432                                headEdge.addLabel(getDefaultRootEdgeLabels());
433                                dependent.addIncomingEdge(headEdge);
434                        }
435                }
436        }
437        
438        @Override
439        public int nEdges() {
440                int n = 0;
441                for (int i = 1; i < nodes.size(); i++) {
442                        if (nodes.get(i).hasHead()) {
443                                n++;
444                        }
445                }
446                return n;
447        }
448
449        @Override
450        public SortedSet<Edge> getEdges() {
451                SortedSet<Edge> edges = Collections.synchronizedSortedSet(new TreeSet<Edge>());
452                for (int i = 1; i < nodes.size(); i++) {
453                        if (nodes.get(i).hasHead()) {
454                                edges.add(nodes.get(i).getHeadEdge());
455                        }
456                }
457                return edges;
458        }
459
460        @Override
461        public SortedSet<Integer> getDependencyIndices() {
462                SortedSet<Integer> indices = Collections.synchronizedSortedSet(new TreeSet<Integer>());
463                for (int i = 0; i < nodes.size(); i++) {
464                        indices.add(i);
465                }
466                return indices;
467        }
468
469        @Override
470        public DependencyNode getDependencyRoot() {
471                return nodes.get(0);
472        }
473
474        @Override
475        public boolean hasLabeledDependency(int index) {
476                if (index < 0 || index >= nodes.size()) {
477                        return false;
478                }
479                if (!nodes.get(index).hasHead()) {
480                        return false;
481                }
482                return nodes.get(index).isHeadLabeled();
483        }
484
485        @Override
486        public boolean isConnected() {
487                int[] components = findComponents();
488                int tmp = components[0];
489                for (int i = 1; i < components.length; i++) {
490                        if (tmp != components[i]) {
491                                return false;
492                        }
493                }
494                return true;
495        }
496
497        @Override
498        public boolean isProjective() throws MaltChainedException {
499                for (int i = 1; i < nodes.size(); i++) {
500                        if (!nodes.get(i).isProjective()) {
501                                return false;
502                        }
503                }
504                return true;
505        }
506
507        @Override
508        public boolean isSingleHeaded() {
509                return true;
510        }
511
512        @Override
513        public boolean isTree() {
514                return isConnected() && isSingleHeaded();
515        }
516
517        @Override
518        public int nNonProjectiveEdges() throws MaltChainedException {
519                int c = 0;
520                for (int i = 1; i < nodes.size(); i++) {
521                        if (!nodes.get(i).isProjective()) {
522                                c++;
523                        }
524                }
525                return c;
526        }
527
528        @Override
529        public LabelSet getDefaultRootEdgeLabels() throws MaltChainedException {
530                return rootLabels.getDefaultRootLabels();
531        }
532        
533        @Override
534        public String getDefaultRootEdgeLabelSymbol(SymbolTable table) throws MaltChainedException {
535                return rootLabels.getDefaultRootLabelSymbol(table);
536        }
537        
538        @Override
539        public int getDefaultRootEdgeLabelCode(SymbolTable table) throws MaltChainedException {
540                return rootLabels.getDefaultRootLabelCode(table);
541        }
542        
543        @Override
544        public void setDefaultRootEdgeLabel(SymbolTable table, String defaultRootSymbol) throws MaltChainedException {
545                rootLabels.setDefaultRootLabel(table, defaultRootSymbol);
546        }
547        
548        @Override
549        public void setDefaultRootEdgeLabels(String rootLabelOption, SortedMap<String, SymbolTable> edgeSymbolTables) throws MaltChainedException {
550                rootLabels.setRootLabels(rootLabelOption, edgeSymbolTables);
551        }
552}