001package org.maltparser.core.lw.graph;
002
003import java.util.ArrayList;
004import java.util.Collections;
005import java.util.List;
006import java.util.SortedMap;
007import java.util.SortedSet;
008import java.util.TreeSet;
009
010
011import org.maltparser.concurrent.graph.dataformat.ColumnDescription;
012import org.maltparser.concurrent.graph.dataformat.DataFormat;
013import org.maltparser.core.exception.MaltChainedException;
014import org.maltparser.core.symbol.SymbolTable;
015import org.maltparser.core.symbol.SymbolTableHandler;
016import org.maltparser.core.syntaxgraph.DependencyStructure;
017import org.maltparser.core.syntaxgraph.Element;
018import org.maltparser.core.syntaxgraph.LabelSet;
019import org.maltparser.core.syntaxgraph.RootLabels;
020import org.maltparser.core.syntaxgraph.edge.Edge;
021import org.maltparser.core.syntaxgraph.node.ComparableNode;
022import org.maltparser.core.syntaxgraph.node.DependencyNode;
023import org.maltparser.core.syntaxgraph.node.TokenNode;
024
025/**
026* A lightweight version of org.maltparser.core.syntaxgraph.DependencyGraph. 
027* 
028* @author Johan Hall
029*/
030public final class LWDependencyGraph implements DependencyStructure {
031        private static final String TAB_SIGN = "\t";
032        
033        private final DataFormat dataFormat;
034        private final SymbolTableHandler symbolTables;
035        private final RootLabels rootLabels;
036        private final List<LWNode> nodes;
037        
038        public LWDependencyGraph(DataFormat _dataFormat, SymbolTableHandler _symbolTables) throws MaltChainedException {
039                this.dataFormat = _dataFormat;
040                this.symbolTables = _symbolTables;
041                this.rootLabels = new RootLabels();
042                this.nodes = new ArrayList<LWNode>();
043                this.nodes.add(new LWNode(this, 0)); // ROOT
044        }
045        
046        public LWDependencyGraph(DataFormat _dataFormat, SymbolTableHandler _symbolTables, String[] inputTokens, String defaultRootLabel) throws MaltChainedException {
047                this(_dataFormat, _symbolTables, inputTokens, defaultRootLabel, true);
048        }
049        
050        public LWDependencyGraph(DataFormat _dataFormat, SymbolTableHandler _symbolTables, String[] inputTokens, String defaultRootLabel, boolean addEdges) throws MaltChainedException {
051                this.dataFormat = _dataFormat;
052                this.symbolTables = _symbolTables;
053                this.rootLabels = new RootLabels();
054                this.nodes = new ArrayList<LWNode>(inputTokens.length+1);
055                
056                // Add nodes
057                nodes.add(new LWNode(this, 0)); // ROOT
058                for (int i = 0; i < inputTokens.length; i++) {
059                        nodes.add(new LWNode(this, i+1));
060                }
061                
062                for (int i = 0; i < inputTokens.length; i++) {
063                        nodes.get(i+1).addColumnLabels(inputTokens[i].split(TAB_SIGN), addEdges);
064                }
065                // Check graph
066                for (int i = 0; i < nodes.size(); i++) {
067                        if (nodes.get(i).getHeadIndex() >= nodes.size()) {
068                                throw new LWGraphException("Not allowed to add a head node that doesn't exists");
069                        }
070                }
071                
072                for (int i = 0; i < dataFormat.numberOfColumns(); i++) {
073                        ColumnDescription column = dataFormat.getColumnDescription(i);
074                        if (!column.isInternal() && column.getCategory() == ColumnDescription.DEPENDENCY_EDGE_LABEL) {
075                                rootLabels.setDefaultRootLabel(symbolTables.getSymbolTable(column.getName()), defaultRootLabel);
076                        }
077                }
078        }
079        
080        public DataFormat getDataFormat() {
081                return dataFormat;
082        }
083        
084        public LWNode getNode(int nodeIndex) {
085                if (nodeIndex < 0 || nodeIndex >= nodes.size()) {
086                        return null;
087                }
088                return nodes.get(nodeIndex);
089        }
090        
091        public int nNodes() {
092                return nodes.size();
093        }
094        
095        protected boolean hasDependent(int nodeIndex) {
096                for (int i = 1; i < nodes.size(); i++) {
097                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
098                                return true;
099                        }
100                }
101                return false;
102        }
103        
104        protected boolean hasLeftDependent(int nodeIndex) {
105                for (int i = 1; i < nodeIndex; i++) {
106                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
107                                return true;
108                        }
109                }
110                return false;
111        }
112        
113        protected boolean hasRightDependent(int nodeIndex) {
114                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
115                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
116                                return true;
117                        }
118                }
119                return false;
120        }
121        
122        protected List<DependencyNode> getListOfLeftDependents(int nodeIndex) {
123                List<DependencyNode> leftDependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
124                for (int i = 1; i < nodeIndex; i++) {
125                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
126                                leftDependents.add(nodes.get(i));
127                        }
128                }
129                return leftDependents;
130        }
131        
132        protected SortedSet<DependencyNode> getSortedSetOfLeftDependents(int nodeIndex) {
133                SortedSet<DependencyNode> leftDependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
134                for (int i = 1; i < nodeIndex; i++) {
135                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
136                                leftDependents.add(nodes.get(i));
137                        }
138                }
139                return leftDependents;
140        }
141        
142        protected List<DependencyNode> getListOfRightDependents(int nodeIndex) {
143                List<DependencyNode> rightDependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
144                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
145                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
146                                rightDependents.add(nodes.get(i));
147                        }
148                }
149                return rightDependents;
150        }
151        
152        protected SortedSet<DependencyNode> getSortedSetOfRightDependents(int nodeIndex) {
153                SortedSet<DependencyNode> rightDependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
154                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
155                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
156                                rightDependents.add(nodes.get(i));
157                        }
158                }
159                return rightDependents;
160        }
161        
162        protected List<DependencyNode> getListOfDependents(int nodeIndex) {
163                List<DependencyNode> dependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
164                for (int i = 1; i < nodes.size(); i++) {
165                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
166                                dependents.add(nodes.get(i));
167                        }
168                }
169                return dependents;
170        }
171        
172        protected SortedSet<DependencyNode> getSortedSetOfDependents(int nodeIndex) {
173                SortedSet<DependencyNode> dependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
174                for (int i = 1; i < nodes.size(); i++) {
175                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
176                                dependents.add(nodes.get(i));
177                        }
178                }
179                return dependents;
180        }
181        
182        protected int getRank(int nodeIndex) {
183                int[] components = new int[nodes.size()];
184                int[] ranks = new int[nodes.size()];
185                for (int i = 0; i < components.length; i++) {
186                        components[i] = i;
187                        ranks[i] = 0;
188                }
189                for (int i = 1; i < nodes.size(); i++) {
190                        if (nodes.get(i).hasHead()) {
191                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
192                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
193                                if (hcIndex != dcIndex) {
194                                        link(hcIndex, dcIndex, components, ranks);              
195                                }
196                        }
197                }
198                return ranks[nodeIndex];
199        }
200        
201        protected DependencyNode findComponent(int nodeIndex) {
202                int[] components = new int[nodes.size()];
203                int[] ranks = new int[nodes.size()];
204                for (int i = 0; i < components.length; i++) {
205                        components[i] = i;
206                        ranks[i] = 0;
207                }
208                for (int i = 1; i < nodes.size(); i++) {
209                        if (nodes.get(i).hasHead()) {
210                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
211                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
212                                if (hcIndex != dcIndex) {
213                                        link(hcIndex, dcIndex, components, ranks);              
214                                }
215                        }
216                }
217                return nodes.get(findComponent(nodeIndex, components));
218        }
219        
220        private int[] findComponents() {
221                int[] components = new int[nodes.size()];
222                int[] ranks = new int[nodes.size()];
223                for (int i = 0; i < components.length; i++) {
224                        components[i] = i;
225                        ranks[i] = 0;
226                }
227                for (int i = 1; i < nodes.size(); i++) {
228                        if (nodes.get(i).hasHead()) {
229                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
230                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
231                                if (hcIndex != dcIndex) {
232                                        link(hcIndex, dcIndex, components, ranks);              
233                                }
234                        }
235                }
236                return components;
237        }
238        
239        private int findComponent(int xIndex, int[] components) {
240                if (xIndex != components[xIndex]) {
241                        components[xIndex] = findComponent(components[xIndex], components);
242                }
243                return components[xIndex];
244        }
245        
246        private int link(int xIndex, int yIndex, int[] components, int[] ranks) {
247                if (ranks[xIndex] > ranks[yIndex]) {  
248                        components[yIndex] = xIndex;
249                } else {
250                        components[xIndex] = yIndex;
251
252                        if (ranks[xIndex] == ranks[yIndex]) {
253                                ranks[yIndex]++;
254                        }
255                        return yIndex;
256                }
257                return xIndex;
258        }
259        
260        @Override
261        public TokenNode addTokenNode() throws MaltChainedException {
262                throw new LWGraphException("Not implemented in the light-weight dependency graph package");
263        }
264
265        @Override
266        public TokenNode addTokenNode(int index) throws MaltChainedException {
267                throw new LWGraphException("Not implemented in the light-weight dependency graph package");
268        }
269
270        @Override
271        public TokenNode getTokenNode(int index) {
272//              throw new LWGraphException("Not implemented in the light-weight dependency graph package");
273                return null;
274        }
275
276        @Override
277        public int nTokenNode() {
278                return nodes.size()-1;
279        }
280
281        @Override
282        public SortedSet<Integer> getTokenIndices() {
283                SortedSet<Integer> indices = Collections.synchronizedSortedSet(new TreeSet<Integer>());
284                for (int i = 1; i < nodes.size(); i++) {
285                        indices.add(i);
286                }
287                return indices;
288        }
289
290        @Override
291        public int getHighestTokenIndex() {
292                return nodes.size()-1;
293        }
294
295        @Override
296        public boolean hasTokens() {
297                return nodes.size() > 1;
298        }
299
300        @Override
301        public int getSentenceID() {
302                return 0;
303        }
304
305        @Override
306        public void setSentenceID(int sentenceID) {     }
307
308        @Override
309        public void clear() throws MaltChainedException {
310                nodes.clear();
311        }
312
313        @Override
314        public SymbolTableHandler getSymbolTables() {
315                return symbolTables;
316        }
317
318        @Override
319        public void setSymbolTables(SymbolTableHandler symbolTables) { }
320
321        @Override
322        public void addLabel(Element element, String labelFunction, String label) throws MaltChainedException {
323                element.addLabel(symbolTables.addSymbolTable(labelFunction), label);
324        }
325        
326        @Override
327        public LabelSet checkOutNewLabelSet() throws MaltChainedException {
328                throw new LWGraphException("Not implemented in light-weight dependency graph");
329        }
330
331        @Override
332        public void checkInLabelSet(LabelSet labelSet) throws MaltChainedException {
333                throw new LWGraphException("Not implemented in light-weight dependency graph");
334        }
335
336        @Override
337        public Edge addSecondaryEdge(ComparableNode source, ComparableNode target) throws MaltChainedException {
338                throw new LWGraphException("Not implemented in light-weight dependency graph");
339        }
340
341        @Override
342        public void removeSecondaryEdge(ComparableNode source, ComparableNode target)
343                        throws MaltChainedException {
344                throw new LWGraphException("Not implemented in light-weight dependency graph");
345        }
346
347        @Override
348        public DependencyNode addDependencyNode() throws MaltChainedException {
349                LWNode node = new LWNode(this, nodes.size());
350                nodes.add(node);
351                return node;
352        }
353
354        @Override
355        public DependencyNode addDependencyNode(int index) throws MaltChainedException {
356                if (index == 0) {
357                        return nodes.get(0);
358                } else if (index == nodes.size()) {
359                        return addDependencyNode();
360                }
361                throw new LWGraphException("Not implemented in light-weight dependency graph");
362        }
363
364        @Override
365        public DependencyNode getDependencyNode(int index) throws MaltChainedException {
366                if (index < 0 || index >= nodes.size()) {
367                        return null;
368                }
369                return nodes.get(index);
370        }
371
372        @Override
373        public int nDependencyNode() {
374                return nodes.size();
375        }
376
377        @Override
378        public int getHighestDependencyNodeIndex() {
379                return nodes.size()-1;
380        }
381
382        @Override
383        public Edge addDependencyEdge(int headIndex, int dependentIndex) throws MaltChainedException {
384                if (headIndex < 0 && headIndex >= nodes.size()) {
385                        throw new LWGraphException("The head doesn't exists");
386                }
387                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
388                        throw new LWGraphException("The dependent doesn't exists");
389                }
390                LWNode head = nodes.get(headIndex);
391                LWNode dependent = nodes.get(dependentIndex);
392                Edge headEdge = new LWEdge(head, dependent);
393                dependent.addIncomingEdge(headEdge);
394                return headEdge;
395        }
396
397        @Override
398        public Edge moveDependencyEdge(int newHeadIndex, int dependentIndex) throws MaltChainedException {
399                if (newHeadIndex < 0 && newHeadIndex >= nodes.size()) {
400                        throw new LWGraphException("The head doesn't exists");
401                }
402                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
403                        throw new LWGraphException("The dependent doesn't exists");
404                }
405
406                LWNode head = nodes.get(newHeadIndex);
407                LWNode dependent = nodes.get(dependentIndex);
408                Edge oldheadEdge = dependent.getHeadEdge();
409                Edge headEdge = new LWEdge(head, dependent);
410                headEdge.addLabel(oldheadEdge.getLabelSet());
411                dependent.addIncomingEdge(headEdge);
412                return headEdge;
413        }
414
415        @Override
416        public void removeDependencyEdge(int headIndex, int dependentIndex) throws MaltChainedException {
417                if (headIndex < 0 && headIndex >= nodes.size()) {
418                        throw new LWGraphException("The head doesn't exists");
419                }
420                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
421                        throw new LWGraphException("The dependent doesn't exists");
422                }
423                LWNode head = nodes.get(headIndex);
424                LWNode dependent = nodes.get(dependentIndex);
425                Edge headEdge = new LWEdge(head, dependent);
426                dependent.removeIncomingEdge(headEdge);
427        }
428
429        @Override
430        public void linkAllTreesToRoot() throws MaltChainedException {
431                for (int i = 0; i < nodes.size(); i++) {
432                        if (!nodes.get(i).hasHead()) {
433                                LWNode head = nodes.get(0);
434                                LWNode dependent = nodes.get(i);
435                                Edge headEdge = new LWEdge(head, dependent);
436                                headEdge.addLabel(getDefaultRootEdgeLabels());
437                                dependent.addIncomingEdge(headEdge);
438                        }
439                }
440        }
441        
442        @Override
443        public int nEdges() {
444                int n = 0;
445                for (int i = 1; i < nodes.size(); i++) {
446                        if (nodes.get(i).hasHead()) {
447                                n++;
448                        }
449                }
450                return n;
451        }
452
453        @Override
454        public SortedSet<Edge> getEdges() {
455                SortedSet<Edge> edges = Collections.synchronizedSortedSet(new TreeSet<Edge>());
456                for (int i = 1; i < nodes.size(); i++) {
457                        if (nodes.get(i).hasHead()) {
458                                edges.add(nodes.get(i).getHeadEdge());
459                        }
460                }
461                return edges;
462        }
463
464        @Override
465        public SortedSet<Integer> getDependencyIndices() {
466                SortedSet<Integer> indices = Collections.synchronizedSortedSet(new TreeSet<Integer>());
467                for (int i = 0; i < nodes.size(); i++) {
468                        indices.add(i);
469                }
470                return indices;
471        }
472
473        @Override
474        public DependencyNode getDependencyRoot() {
475                return nodes.get(0);
476        }
477
478        @Override
479        public boolean hasLabeledDependency(int index) {
480                if (index < 0 || index >= nodes.size()) {
481                        return false;
482                }
483                if (!nodes.get(index).hasHead()) {
484                        return false;
485                }
486                return nodes.get(index).isHeadLabeled();
487        }
488
489        @Override
490        public boolean isConnected() {
491                int[] components = findComponents();
492                int tmp = components[0];
493                for (int i = 1; i < components.length; i++) {
494                        if (tmp != components[i]) {
495                                return false;
496                        }
497                }
498                return true;
499        }
500
501        @Override
502        public boolean isProjective() throws MaltChainedException {
503                for (int i = 1; i < nodes.size(); i++) {
504                        if (!nodes.get(i).isProjective()) {
505                                return false;
506                        }
507                }
508                return true;
509        }
510
511        @Override
512        public boolean isSingleHeaded() {
513                return true;
514        }
515
516        @Override
517        public boolean isTree() {
518                return isConnected() && isSingleHeaded();
519        }
520
521        @Override
522        public int nNonProjectiveEdges() throws MaltChainedException {
523                int c = 0;
524                for (int i = 1; i < nodes.size(); i++) {
525                        if (!nodes.get(i).isProjective()) {
526                                c++;
527                        }
528                }
529                return c;
530        }
531
532        @Override
533        public LabelSet getDefaultRootEdgeLabels() throws MaltChainedException {
534                return rootLabels.getDefaultRootLabels();
535        }
536        
537        @Override
538        public String getDefaultRootEdgeLabelSymbol(SymbolTable table) throws MaltChainedException {
539                return rootLabels.getDefaultRootLabelSymbol(table);
540        }
541        
542        @Override
543        public int getDefaultRootEdgeLabelCode(SymbolTable table) throws MaltChainedException {
544                return rootLabels.getDefaultRootLabelCode(table);
545        }
546        
547        @Override
548        public void setDefaultRootEdgeLabel(SymbolTable table, String defaultRootSymbol) throws MaltChainedException {
549                rootLabels.setDefaultRootLabel(table, defaultRootSymbol);
550        }
551        
552        @Override
553        public void setDefaultRootEdgeLabels(String rootLabelOption, SortedMap<String, SymbolTable> edgeSymbolTables) throws MaltChainedException {
554                rootLabels.setRootLabels(rootLabelOption, edgeSymbolTables);
555        }
556        
557        public String toString() {
558                final StringBuilder sb = new StringBuilder();
559                for (LWNode node : nodes) {
560                        sb.append(node.toString().trim());
561                        sb.append('\n');
562                }
563                sb.append('\n');
564                return sb.toString();
565        }
566}