001package org.maltparser.core.lw.graph;
002
003import java.util.ArrayList;
004import java.util.Collections;
005import java.util.List;
006import java.util.SortedMap;
007import java.util.SortedSet;
008import java.util.TreeSet;
009
010import org.maltparser.concurrent.graph.dataformat.ColumnDescription;
011import org.maltparser.concurrent.graph.dataformat.DataFormat;
012import org.maltparser.core.exception.MaltChainedException;
013import org.maltparser.core.helper.HashMap;
014import org.maltparser.core.symbol.SymbolTable;
015import org.maltparser.core.symbol.SymbolTableHandler;
016import org.maltparser.core.syntaxgraph.DependencyStructure;
017import org.maltparser.core.syntaxgraph.Element;
018import org.maltparser.core.syntaxgraph.LabelSet;
019import org.maltparser.core.syntaxgraph.RootLabels;
020import org.maltparser.core.syntaxgraph.edge.Edge;
021import org.maltparser.core.syntaxgraph.node.ComparableNode;
022import org.maltparser.core.syntaxgraph.node.DependencyNode;
023import org.maltparser.core.syntaxgraph.node.TokenNode;
024
025/**
026* A lightweight version of org.maltparser.core.syntaxgraph.DependencyGraph.
027*
028* @author Johan Hall
029*/
030public final class LWDependencyGraph implements DependencyStructure {
031        private static final String TAB_SIGN = "\t";
032
033        private final DataFormat dataFormat;
034        private final SymbolTableHandler symbolTables;
035        private final RootLabels rootLabels;
036        private final List<LWNode> nodes;
037        private final HashMap<Integer, ArrayList<String>> comments;
038
039        public LWDependencyGraph(DataFormat _dataFormat, SymbolTableHandler _symbolTables) throws MaltChainedException {
040                this.dataFormat = _dataFormat;
041                this.symbolTables = _symbolTables;
042                this.rootLabels = new RootLabels();
043                this.nodes = new ArrayList<LWNode>();
044                this.nodes.add(new LWNode(this, 0)); // ROOT
045                this.comments = new HashMap<Integer, ArrayList<String>>();
046        }
047
048        public LWDependencyGraph(DataFormat _dataFormat, SymbolTableHandler _symbolTables, String[] inputTokens, String defaultRootLabel) throws MaltChainedException {
049                this(_dataFormat, _symbolTables, inputTokens, defaultRootLabel, true);
050        }
051
052        public LWDependencyGraph(DataFormat _dataFormat, SymbolTableHandler _symbolTables, String[] inputTokens, String defaultRootLabel, boolean addEdges) throws MaltChainedException {
053                this.dataFormat = _dataFormat;
054                this.symbolTables = _symbolTables;
055                this.rootLabels = new RootLabels();
056                this.nodes = new ArrayList<LWNode>(inputTokens.length+1);
057                this.comments = new HashMap<Integer, ArrayList<String>>();
058                resetTokens(inputTokens, defaultRootLabel, addEdges);
059        }
060
061        public void resetTokens(String[] inputTokens, String defaultRootLabel, boolean addEdges) throws MaltChainedException {
062                nodes.clear();
063                comments.clear();
064                symbolTables.cleanUp();
065                // Add nodes
066                nodes.add(new LWNode(this, 0)); // ROOT
067                for (int i = 0; i < inputTokens.length; i++) {
068                        nodes.add(new LWNode(this, i+1));
069                }
070
071                for (int i = 0; i < inputTokens.length; i++) {
072                        nodes.get(i+1).addColumnLabels(inputTokens[i].split(TAB_SIGN), addEdges);
073                }
074                // Check graph
075                for (int i = 0; i < nodes.size(); i++) {
076                        if (nodes.get(i).getHeadIndex() >= nodes.size()) {
077                                throw new LWGraphException("Not allowed to add a head node that doesn't exists");
078                        }
079                }
080
081                for (int i = 0; i < dataFormat.numberOfColumns(); i++) {
082                        ColumnDescription column = dataFormat.getColumnDescription(i);
083                        if (!column.isInternal() && column.getCategory() == ColumnDescription.DEPENDENCY_EDGE_LABEL) {
084                                rootLabels.setDefaultRootLabel(symbolTables.getSymbolTable(column.getName()), defaultRootLabel);
085                        }
086                }
087        }
088
089        public DataFormat getDataFormat() {
090                return dataFormat;
091        }
092
093        public LWNode getNode(int nodeIndex) {
094                if (nodeIndex < 0 || nodeIndex >= nodes.size()) {
095                        return null;
096                }
097                return nodes.get(nodeIndex);
098        }
099
100        public int nNodes() {
101                return nodes.size();
102        }
103
104        protected boolean hasDependent(int nodeIndex) {
105                for (int i = 1; i < nodes.size(); i++) {
106                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
107                                return true;
108                        }
109                }
110                return false;
111        }
112
113        protected boolean hasLeftDependent(int nodeIndex) {
114                for (int i = 1; i < nodeIndex; i++) {
115                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
116                                return true;
117                        }
118                }
119                return false;
120        }
121
122        protected boolean hasRightDependent(int nodeIndex) {
123                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
124                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
125                                return true;
126                        }
127                }
128                return false;
129        }
130
131        protected List<DependencyNode> getListOfLeftDependents(int nodeIndex) {
132                List<DependencyNode> leftDependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
133                for (int i = 1; i < nodeIndex; i++) {
134                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
135                                leftDependents.add(nodes.get(i));
136                        }
137                }
138                return leftDependents;
139        }
140
141        protected SortedSet<DependencyNode> getSortedSetOfLeftDependents(int nodeIndex) {
142                SortedSet<DependencyNode> leftDependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
143                for (int i = 1; i < nodeIndex; i++) {
144                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
145                                leftDependents.add(nodes.get(i));
146                        }
147                }
148                return leftDependents;
149        }
150
151        protected List<DependencyNode> getListOfRightDependents(int nodeIndex) {
152                List<DependencyNode> rightDependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
153                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
154                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
155                                rightDependents.add(nodes.get(i));
156                        }
157                }
158                return rightDependents;
159        }
160
161        protected SortedSet<DependencyNode> getSortedSetOfRightDependents(int nodeIndex) {
162                SortedSet<DependencyNode> rightDependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
163                for (int i = nodeIndex + 1; i < nodes.size(); i++) {
164                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
165                                rightDependents.add(nodes.get(i));
166                        }
167                }
168                return rightDependents;
169        }
170
171        protected List<DependencyNode> getListOfDependents(int nodeIndex) {
172                List<DependencyNode> dependents = Collections.synchronizedList(new ArrayList<DependencyNode>());
173                for (int i = 1; i < nodes.size(); i++) {
174                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
175                                dependents.add(nodes.get(i));
176                        }
177                }
178                return dependents;
179        }
180
181        protected SortedSet<DependencyNode> getSortedSetOfDependents(int nodeIndex) {
182                SortedSet<DependencyNode> dependents = Collections.synchronizedSortedSet(new TreeSet<DependencyNode>());
183                for (int i = 1; i < nodes.size(); i++) {
184                        if (nodeIndex == nodes.get(i).getHeadIndex()) {
185                                dependents.add(nodes.get(i));
186                        }
187                }
188                return dependents;
189        }
190
191        protected int getRank(int nodeIndex) {
192                int[] components = new int[nodes.size()];
193                int[] ranks = new int[nodes.size()];
194                for (int i = 0; i < components.length; i++) {
195                        components[i] = i;
196                        ranks[i] = 0;
197                }
198                for (int i = 1; i < nodes.size(); i++) {
199                        if (nodes.get(i).hasHead()) {
200                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
201                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
202                                if (hcIndex != dcIndex) {
203                                        link(hcIndex, dcIndex, components, ranks);
204                                }
205                        }
206                }
207                return ranks[nodeIndex];
208        }
209
210        protected DependencyNode findComponent(int nodeIndex) {
211                int[] components = new int[nodes.size()];
212                int[] ranks = new int[nodes.size()];
213                for (int i = 0; i < components.length; i++) {
214                        components[i] = i;
215                        ranks[i] = 0;
216                }
217                for (int i = 1; i < nodes.size(); i++) {
218                        if (nodes.get(i).hasHead()) {
219                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
220                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
221                                if (hcIndex != dcIndex) {
222                                        link(hcIndex, dcIndex, components, ranks);
223                                }
224                        }
225                }
226                return nodes.get(findComponent(nodeIndex, components));
227        }
228
229        private int[] findComponents() {
230                int[] components = new int[nodes.size()];
231                int[] ranks = new int[nodes.size()];
232                for (int i = 0; i < components.length; i++) {
233                        components[i] = i;
234                        ranks[i] = 0;
235                }
236                for (int i = 1; i < nodes.size(); i++) {
237                        if (nodes.get(i).hasHead()) {
238                                int hcIndex = findComponent(nodes.get(i).getHead().getIndex(), components);
239                                int dcIndex = findComponent(nodes.get(i).getIndex(), components);
240                                if (hcIndex != dcIndex) {
241                                        link(hcIndex, dcIndex, components, ranks);
242                                }
243                        }
244                }
245                return components;
246        }
247
248        private int findComponent(int xIndex, int[] components) {
249                if (xIndex != components[xIndex]) {
250                        components[xIndex] = findComponent(components[xIndex], components);
251                }
252                return components[xIndex];
253        }
254
255        private int link(int xIndex, int yIndex, int[] components, int[] ranks) {
256                if (ranks[xIndex] > ranks[yIndex]) {
257                        components[yIndex] = xIndex;
258                } else {
259                        components[xIndex] = yIndex;
260
261                        if (ranks[xIndex] == ranks[yIndex]) {
262                                ranks[yIndex]++;
263                        }
264                        return yIndex;
265                }
266                return xIndex;
267        }
268
269        @Override
270        public TokenNode addTokenNode() throws MaltChainedException {
271                throw new LWGraphException("Not implemented in the light-weight dependency graph package");
272        }
273
274        @Override
275        public TokenNode addTokenNode(int index) throws MaltChainedException {
276                throw new LWGraphException("Not implemented in the light-weight dependency graph package");
277        }
278
279        @Override
280        public TokenNode getTokenNode(int index) {
281//              throw new LWGraphException("Not implemented in the light-weight dependency graph package");
282                return null;
283        }
284
285
286        @Override
287        public void addComment(String comment, int at_index) {
288                ArrayList<String> commentList = comments.get(at_index);
289                if (commentList == null) {
290                        commentList = comments.put(at_index, new ArrayList<String>());
291                }
292                commentList.add(comment);
293        }
294
295        @Override
296        public ArrayList<String> getComment(int at_index) {
297                return comments.get(at_index);
298        }
299
300        @Override
301        public boolean hasComments() {
302                return comments.size() > 0;
303        }
304
305        @Override
306        public int nTokenNode() {
307                return nodes.size()-1;
308        }
309
310        @Override
311        public SortedSet<Integer> getTokenIndices() {
312                SortedSet<Integer> indices = Collections.synchronizedSortedSet(new TreeSet<Integer>());
313                for (int i = 1; i < nodes.size(); i++) {
314                        indices.add(i);
315                }
316                return indices;
317        }
318
319        @Override
320        public int getHighestTokenIndex() {
321                return nodes.size()-1;
322        }
323
324        @Override
325        public boolean hasTokens() {
326                return nodes.size() > 1;
327        }
328
329        @Override
330        public int getSentenceID() {
331                return 0;
332        }
333
334        @Override
335        public void setSentenceID(int sentenceID) {     }
336
337        @Override
338        public void clear() throws MaltChainedException {
339                nodes.clear();
340        }
341
342        @Override
343        public SymbolTableHandler getSymbolTables() {
344                return symbolTables;
345        }
346
347        @Override
348        public void setSymbolTables(SymbolTableHandler symbolTables) { }
349
350        @Override
351        public void addLabel(Element element, String labelFunction, String label) throws MaltChainedException {
352                element.addLabel(symbolTables.addSymbolTable(labelFunction), label);
353        }
354
355        @Override
356        public LabelSet checkOutNewLabelSet() throws MaltChainedException {
357                throw new LWGraphException("Not implemented in light-weight dependency graph");
358        }
359
360        @Override
361        public void checkInLabelSet(LabelSet labelSet) throws MaltChainedException {
362                throw new LWGraphException("Not implemented in light-weight dependency graph");
363        }
364
365        @Override
366        public Edge addSecondaryEdge(ComparableNode source, ComparableNode target) throws MaltChainedException {
367                throw new LWGraphException("Not implemented in light-weight dependency graph");
368        }
369
370        @Override
371        public void removeSecondaryEdge(ComparableNode source, ComparableNode target)
372                        throws MaltChainedException {
373                throw new LWGraphException("Not implemented in light-weight dependency graph");
374        }
375
376        @Override
377        public DependencyNode addDependencyNode() throws MaltChainedException {
378                LWNode node = new LWNode(this, nodes.size());
379                nodes.add(node);
380                return node;
381        }
382
383        @Override
384        public DependencyNode addDependencyNode(int index) throws MaltChainedException {
385                if (index == 0) {
386                        return nodes.get(0);
387                } else if (index == nodes.size()) {
388                        return addDependencyNode();
389                }
390                throw new LWGraphException("Not implemented in light-weight dependency graph");
391        }
392
393        @Override
394        public DependencyNode getDependencyNode(int index) throws MaltChainedException {
395                if (index < 0 || index >= nodes.size()) {
396                        return null;
397                }
398                return nodes.get(index);
399        }
400
401        @Override
402        public int nDependencyNode() {
403                return nodes.size();
404        }
405
406        @Override
407        public int getHighestDependencyNodeIndex() {
408                return nodes.size()-1;
409        }
410
411        @Override
412        public Edge addDependencyEdge(int headIndex, int dependentIndex) throws MaltChainedException {
413                if (headIndex < 0 && headIndex >= nodes.size()) {
414                        throw new LWGraphException("The head doesn't exists");
415                }
416                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
417                        throw new LWGraphException("The dependent doesn't exists");
418                }
419                LWNode head = nodes.get(headIndex);
420                LWNode dependent = nodes.get(dependentIndex);
421                Edge headEdge = new LWEdge(head, dependent);
422                dependent.addIncomingEdge(headEdge);
423                return headEdge;
424        }
425
426        @Override
427        public Edge moveDependencyEdge(int newHeadIndex, int dependentIndex) throws MaltChainedException {
428                if (newHeadIndex < 0 && newHeadIndex >= nodes.size()) {
429                        throw new LWGraphException("The head doesn't exists");
430                }
431                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
432                        throw new LWGraphException("The dependent doesn't exists");
433                }
434
435                LWNode head = nodes.get(newHeadIndex);
436                LWNode dependent = nodes.get(dependentIndex);
437                Edge oldheadEdge = dependent.getHeadEdge();
438                Edge headEdge = new LWEdge(head, dependent);
439                headEdge.addLabel(oldheadEdge.getLabelSet());
440                dependent.addIncomingEdge(headEdge);
441                return headEdge;
442        }
443
444        @Override
445        public void removeDependencyEdge(int headIndex, int dependentIndex) throws MaltChainedException {
446                if (headIndex < 0 && headIndex >= nodes.size()) {
447                        throw new LWGraphException("The head doesn't exists");
448                }
449                if (dependentIndex < 0 && dependentIndex >= nodes.size()) {
450                        throw new LWGraphException("The dependent doesn't exists");
451                }
452                LWNode head = nodes.get(headIndex);
453                LWNode dependent = nodes.get(dependentIndex);
454                Edge headEdge = new LWEdge(head, dependent);
455                dependent.removeIncomingEdge(headEdge);
456        }
457
458        @Override
459        public void linkAllTreesToRoot() throws MaltChainedException {
460                for (int i = 0; i < nodes.size(); i++) {
461                        if (!nodes.get(i).hasHead()) {
462                                LWNode head = nodes.get(0);
463                                LWNode dependent = nodes.get(i);
464                                Edge headEdge = new LWEdge(head, dependent);
465                                headEdge.addLabel(getDefaultRootEdgeLabels());
466                                dependent.addIncomingEdge(headEdge);
467                        }
468                }
469        }
470
471        @Override
472        public int nEdges() {
473                int n = 0;
474                for (int i = 1; i < nodes.size(); i++) {
475                        if (nodes.get(i).hasHead()) {
476                                n++;
477                        }
478                }
479                return n;
480        }
481
482        @Override
483        public SortedSet<Edge> getEdges() {
484                SortedSet<Edge> edges = Collections.synchronizedSortedSet(new TreeSet<Edge>());
485                for (int i = 1; i < nodes.size(); i++) {
486                        if (nodes.get(i).hasHead()) {
487                                edges.add(nodes.get(i).getHeadEdge());
488                        }
489                }
490                return edges;
491        }
492
493        @Override
494        public SortedSet<Integer> getDependencyIndices() {
495                SortedSet<Integer> indices = Collections.synchronizedSortedSet(new TreeSet<Integer>());
496                for (int i = 0; i < nodes.size(); i++) {
497                        indices.add(i);
498                }
499                return indices;
500        }
501
502        @Override
503        public DependencyNode getDependencyRoot() {
504                return nodes.get(0);
505        }
506
507        @Override
508        public boolean hasLabeledDependency(int index) {
509                if (index < 0 || index >= nodes.size()) {
510                        return false;
511                }
512                if (!nodes.get(index).hasHead()) {
513                        return false;
514                }
515                return nodes.get(index).isHeadLabeled();
516        }
517
518        @Override
519        public boolean isConnected() {
520                int[] components = findComponents();
521                int tmp = components[0];
522                for (int i = 1; i < components.length; i++) {
523                        if (tmp != components[i]) {
524                                return false;
525                        }
526                }
527                return true;
528        }
529
530        @Override
531        public boolean isProjective() throws MaltChainedException {
532                for (int i = 1; i < nodes.size(); i++) {
533                        if (!nodes.get(i).isProjective()) {
534                                return false;
535                        }
536                }
537                return true;
538        }
539
540        @Override
541        public boolean isSingleHeaded() {
542                return true;
543        }
544
545        @Override
546        public boolean isTree() {
547                return isConnected() && isSingleHeaded();
548        }
549
550        @Override
551        public int nNonProjectiveEdges() throws MaltChainedException {
552                int c = 0;
553                for (int i = 1; i < nodes.size(); i++) {
554                        if (!nodes.get(i).isProjective()) {
555                                c++;
556                        }
557                }
558                return c;
559        }
560
561        @Override
562        public LabelSet getDefaultRootEdgeLabels() throws MaltChainedException {
563                return rootLabels.getDefaultRootLabels();
564        }
565
566        @Override
567        public String getDefaultRootEdgeLabelSymbol(SymbolTable table) throws MaltChainedException {
568                return rootLabels.getDefaultRootLabelSymbol(table);
569        }
570
571        @Override
572        public int getDefaultRootEdgeLabelCode(SymbolTable table) throws MaltChainedException {
573                return rootLabels.getDefaultRootLabelCode(table);
574        }
575
576        @Override
577        public void setDefaultRootEdgeLabel(SymbolTable table, String defaultRootSymbol) throws MaltChainedException {
578                rootLabels.setDefaultRootLabel(table, defaultRootSymbol);
579        }
580
581        @Override
582        public void setDefaultRootEdgeLabels(String rootLabelOption, SortedMap<String, SymbolTable> edgeSymbolTables) throws MaltChainedException {
583                rootLabels.setRootLabels(rootLabelOption, edgeSymbolTables);
584        }
585
586        public String toString() {
587                final StringBuilder sb = new StringBuilder();
588                for (LWNode node : nodes) {
589                        sb.append(node.toString().trim());
590                        sb.append('\n');
591                }
592                sb.append('\n');
593                return sb.toString();
594        }
595}