001package org.maltparser.core.syntaxgraph.writer;
002
003import java.io.BufferedWriter;
004import java.io.FileNotFoundException;
005import java.io.FileOutputStream;
006import java.io.IOException;
007import java.io.OutputStream;
008import java.io.OutputStreamWriter;
009import java.io.UnsupportedEncodingException;
010import java.util.Iterator;
011import java.util.LinkedHashMap;
012import java.util.SortedMap;
013import java.util.TreeMap;
014import java.util.regex.PatternSyntaxException;
015
016import org.maltparser.core.exception.MaltChainedException;
017import org.maltparser.core.io.dataformat.ColumnDescription;
018import org.maltparser.core.io.dataformat.DataFormatException;
019import org.maltparser.core.io.dataformat.DataFormatInstance;
020import org.maltparser.core.symbol.SymbolTable;
021import org.maltparser.core.symbol.SymbolTableHandler;
022import org.maltparser.core.syntaxgraph.PhraseStructure;
023import org.maltparser.core.syntaxgraph.TokenStructure;
024import org.maltparser.core.syntaxgraph.edge.Edge;
025import org.maltparser.core.syntaxgraph.node.NonTerminalNode;
026import org.maltparser.core.syntaxgraph.node.PhraseStructureNode;
027/**
028*
029*
030* @author Johan Hall
031*/
032public class NegraWriter implements SyntaxGraphWriter {
033        private BufferedWriter writer; 
034        private DataFormatInstance dataFormatInstance;
035        private String optionString;
036        private int sentenceCount;
037        private LinkedHashMap<Integer, Integer> nonTerminalIndexMap;
038        private int START_ID_OF_NONTERMINALS = 500;
039        private boolean closeStream = true;
040        
041        public NegraWriter() { 
042                nonTerminalIndexMap = new LinkedHashMap<Integer, Integer>();
043        }
044        
045        public void open(String fileName, String charsetName) throws MaltChainedException {
046                try {
047                        open(new OutputStreamWriter(new FileOutputStream(fileName),charsetName));
048                } catch (FileNotFoundException e) {
049                        throw new DataFormatException("The output file '"+fileName+"' cannot be found.", e);
050                } catch (UnsupportedEncodingException e) {
051                        throw new DataFormatException("The character encoding set '"+charsetName+"' isn't supported.", e);
052                }       
053        }
054        
055        public void open(OutputStream os, String charsetName) throws MaltChainedException {
056                try {
057                        if (os == System.out || os == System.err) {
058                                closeStream = false;
059                        }
060                        open(new OutputStreamWriter(os, charsetName));
061                } catch (UnsupportedEncodingException e) {
062                        throw new DataFormatException("The character encoding set '"+charsetName+"' isn't supported.", e);
063                }
064        }
065        
066        private void open(OutputStreamWriter osw) throws MaltChainedException {
067                setWriter(new BufferedWriter(osw));
068                setSentenceCount(0);
069        }
070        
071        public void writeProlog() throws MaltChainedException { }
072        
073        public void writeSentence(TokenStructure syntaxGraph) throws MaltChainedException {
074                if (syntaxGraph == null || dataFormatInstance == null || !(syntaxGraph instanceof PhraseStructure) || !syntaxGraph.hasTokens()) {
075                        return;
076                }
077                PhraseStructure phraseStructure = (PhraseStructure)syntaxGraph;
078                sentenceCount++;
079                try {
080                        writer.write("#BOS ");
081                        if (phraseStructure.getSentenceID() != 0) {
082                                writer.write(Integer.toString(phraseStructure.getSentenceID()));
083                        } else {
084                                writer.write(Integer.toString(sentenceCount));
085                        }
086                        writer.write('\n');
087
088                        if (phraseStructure.hasNonTerminals()) {
089                                calculateIndices(phraseStructure);
090                                writeTerminals(phraseStructure);
091                                writeNonTerminals(phraseStructure);
092                        } else {
093                                writeTerminals(phraseStructure);
094                        }
095                        writer.write("#EOS ");
096                        if (phraseStructure.getSentenceID() != 0) {
097                                writer.write(Integer.toString(phraseStructure.getSentenceID()));
098                        } else {
099                                writer.write(Integer.toString(sentenceCount));
100                        }
101                        writer.write('\n');
102                } catch (IOException e) {
103                        throw new DataFormatException("Could not write to the output file. ", e);
104                }
105        }
106        public void writeEpilog() throws MaltChainedException { }
107        
108
109        private void calculateIndices(PhraseStructure phraseStructure) throws MaltChainedException {
110                final SortedMap<Integer,Integer> heights = new TreeMap<Integer,Integer>();
111                for (int index : phraseStructure.getNonTerminalIndices()) {
112                        heights.put(index, ((NonTerminalNode)phraseStructure.getNonTerminalNode(index)).getHeight());
113                }
114                
115                boolean done = false;
116                int h = 1;
117                int ntid = START_ID_OF_NONTERMINALS;
118                nonTerminalIndexMap.clear();
119                while (!done) {
120                        done = true;
121                        for (int index : phraseStructure.getNonTerminalIndices()) {
122                                if (heights.get(index) == h) {
123                                        NonTerminalNode nt = (NonTerminalNode)phraseStructure.getNonTerminalNode(index);
124                                        nonTerminalIndexMap.put(nt.getIndex(), ntid++);
125//                                      nonTerminalIndexMap.put(nt.getIndex(), nt.getIndex()+START_ID_OF_NONTERMINALS-1);
126                                        done = false;
127                                }
128                        }
129                        h++;
130                }
131                
132//              boolean done = false;
133//              int h = 1;
134////            int ntid = START_ID_OF_NONTERMINALS;
135////            nonTerminalIndexMap.clear();
136//              while (!done) {
137//                      done = true;
138//                      for (int index : phraseStructure.getNonTerminalIndices()) {
139//                              if (heights.get(index) == h) {
140//                                      NonTerminalNode nt = (NonTerminalNode)phraseStructure.getNonTerminalNode(index);
141////                                    nonTerminalIndexMap.put(nt.getIndex(), ntid++);
142//                                      nonTerminalIndexMap.put(nt.getIndex(), nt.getIndex()+START_ID_OF_NONTERMINALS-1);
143//                                      done = false;
144//                              }
145//                      }
146//                      h++;
147//              }
148        }
149        
150        private void writeTerminals(PhraseStructure phraseStructure) throws MaltChainedException {
151                try {
152                        final SymbolTableHandler symbolTables = phraseStructure.getSymbolTables();
153                        for (int index : phraseStructure.getTokenIndices()) {
154                                final PhraseStructureNode terminal = phraseStructure.getTokenNode(index);
155                                final Iterator<ColumnDescription> columns = dataFormatInstance.iterator();
156                                ColumnDescription column = null;
157                                int ti = 1;
158                                while (columns.hasNext()) {
159                                        column = columns.next();
160                                        if (column.getCategory() == ColumnDescription.INPUT) {
161                                                SymbolTable table = symbolTables.getSymbolTable(column.getName());
162                                                writer.write(terminal.getLabelSymbol(table));
163                                                int nTabs = 1;
164                                                if (ti == 1 || ti == 2) {
165                                                        nTabs = 3 - (terminal.getLabelSymbol(table).length() / 8);
166                                                } else if (ti == 3) {
167                                                        nTabs = 1;
168                                                } else if (ti == 4) {
169                                                        nTabs = 2 - (terminal.getLabelSymbol(table).length() / 8);
170                                                }
171                                                if (nTabs < 1) {
172                                                        nTabs = 1;
173                                                }
174                                                for (int j = 0; j < nTabs; j++) {
175                                                        writer.write('\t');
176                                                }
177                                                ti++;
178                                        } else if (column.getCategory() == ColumnDescription.PHRASE_STRUCTURE_EDGE_LABEL) {
179                                                SymbolTable table = symbolTables.getSymbolTable(column.getName());
180                                                if (terminal.getParent() != null && terminal.hasParentEdgeLabel(table)) {
181                                                        writer.write(terminal.getParentEdgeLabelSymbol(table));
182                                                        writer.write('\t');
183                                                } else {
184                                                        writer.write("--\t");
185                                                }
186                                        } else if (column.getCategory() == ColumnDescription.PHRASE_STRUCTURE_NODE_LABEL) { 
187                                                if (terminal.getParent() == null || terminal.getParent() == phraseStructure.getPhraseStructureRoot()) {
188                                                        writer.write('0');
189                                                } else {
190                                                        writer.write(Integer.toString(nonTerminalIndexMap.get(terminal.getParent().getIndex())));
191//                                                      writer.write(Integer.toString(terminal.getParent().getIndex()+START_ID_OF_NONTERMINALS-1));
192                                                }
193                                        }
194                                }
195                                SymbolTable table = symbolTables.getSymbolTable(column.getName());
196                                for (Edge e : terminal.getIncomingSecondaryEdges()) {
197                                        if (e.hasLabel(table)) {
198                                                writer.write('\t');
199                                                writer.write(e.getLabelSymbol(table));
200                                                writer.write('\t');
201                                                if (e.getSource() instanceof NonTerminalNode) {
202                                                        writer.write(Integer.toString(nonTerminalIndexMap.get(e.getSource().getIndex())));
203//                                                      writer.write(Integer.toString(e.getSource().getIndex()+START_ID_OF_NONTERMINALS-1));
204                                                } else {
205                                                        writer.write(Integer.toString(e.getSource().getIndex()));
206                                                }
207                                        }
208                                }
209                                writer.write("\n");
210                        }
211
212                } catch (IOException e) {
213                        throw new DataFormatException("The Negra writer is not able to write. ", e);
214                }
215        }
216        
217        private void writeNonTerminals(PhraseStructure phraseStructure) throws MaltChainedException {
218                final SymbolTableHandler symbolTables = phraseStructure.getSymbolTables();
219                
220                for (int index : nonTerminalIndexMap.keySet()) {
221//              for (int index : phraseStructure.getNonTerminalIndices()) {
222                        NonTerminalNode nonTerminal = (NonTerminalNode)phraseStructure.getNonTerminalNode(index);
223        
224                        if (nonTerminal == null || nonTerminal.isRoot()) {
225                                return;
226                        }
227                        try {
228                                writer.write('#');
229//                              writer.write(Integer.toString(index+START_ID_OF_NONTERMINALS-1));
230                                writer.write(Integer.toString(nonTerminalIndexMap.get(index)));
231                                writer.write("\t\t\t--\t\t\t");
232                                if (nonTerminal.hasLabel(symbolTables.getSymbolTable("CAT"))) {
233                                        writer.write(nonTerminal.getLabelSymbol(symbolTables.getSymbolTable("CAT")));
234                                } else {
235                                        writer.write("--");
236                                }
237                                writer.write("\t--\t\t");
238                                if (nonTerminal.hasParentEdgeLabel(symbolTables.getSymbolTable("LABEL"))) {
239                                        writer.write(nonTerminal.getParentEdgeLabelSymbol(symbolTables.getSymbolTable("LABEL")));
240                                } else {
241                                        writer.write("--");
242                                }
243                                writer.write('\t');
244                                if (nonTerminal.getParent() == null || nonTerminal.getParent().isRoot()) {
245                                        writer.write('0');
246                                } else {
247//                                      writer.write(Integer.toString(nonTerminal.getParent().getIndex()+START_ID_OF_NONTERMINALS-1));
248                                        writer.write(Integer.toString(nonTerminalIndexMap.get(nonTerminal.getParent().getIndex())));
249                                }
250                                for (Edge e : nonTerminal.getIncomingSecondaryEdges()) {
251                                        if (e.hasLabel(symbolTables.getSymbolTable("SECEDGELABEL"))) {
252                                                writer.write('\t');
253                                                writer.write(e.getLabelSymbol(symbolTables.getSymbolTable("SECEDGELABEL")));
254                                                writer.write('\t');
255                                                if (e.getSource() instanceof NonTerminalNode) {
256//                                                      writer.write(Integer.toString(e.getSource().getIndex()+START_ID_OF_NONTERMINALS-1));
257                                                        writer.write(Integer.toString(nonTerminalIndexMap.get(e.getSource().getIndex())));
258                                                } else {
259                                                        writer.write(Integer.toString(e.getSource().getIndex()));
260                                                }
261                                        }
262                                }
263                                writer.write("\n");
264                        } catch (IOException e) {
265                                throw new DataFormatException("The Negra writer is not able to write the non-terminals. ", e);
266                        }
267                }
268        }
269        
270        public BufferedWriter getWriter() {
271                return writer;
272        }
273
274        public void setWriter(BufferedWriter writer) {
275                this.writer = writer;
276        }
277        
278        public int getSentenceCount() {
279                return sentenceCount;
280        }
281
282        public void setSentenceCount(int sentenceCount) {
283                this.sentenceCount = sentenceCount;
284        }
285        
286        public DataFormatInstance getDataFormatInstance() {
287                return dataFormatInstance;
288        }
289
290        public void setDataFormatInstance(DataFormatInstance dataFormatInstance) {
291                this.dataFormatInstance = dataFormatInstance;
292        }
293
294        public String getOptions() {
295                return optionString;
296        }
297        
298        public void setOptions(String optionString) throws MaltChainedException {
299                this.optionString = optionString;
300                String[] argv;
301                try {
302                        argv = optionString.split("[_\\p{Blank}]");
303                } catch (PatternSyntaxException e) {
304                        throw new DataFormatException("Could not split the penn writer option '"+optionString+"'. ", e);
305                }
306                for (int i=0; i < argv.length-1; i++) {
307                        if(argv[i].charAt(0) != '-') {
308                                throw new DataFormatException("The argument flag should start with the following character '-', not with "+argv[i].charAt(0));
309                        }
310                        if(++i>=argv.length) {
311                                throw new DataFormatException("The last argument does not have any value. ");
312                        }
313                        switch(argv[i-1].charAt(1)) {
314                        case 's': 
315                                try {
316                                        START_ID_OF_NONTERMINALS = Integer.parseInt(argv[i]);
317                                } catch (NumberFormatException e){
318                                        throw new MaltChainedException("The TigerXML Reader option -s must be an integer value. ");
319                                }
320                                break;
321                        default:
322                                throw new DataFormatException("Unknown svm parameter: '"+argv[i-1]+"' with value '"+argv[i]+"'. ");             
323                        }
324                }       
325        }
326        
327        public void close() throws MaltChainedException {
328                try {
329                        if (writer != null) {
330                                writer.flush();
331                                if (closeStream) {
332                                        writer.close();
333                                }
334                                writer = null;
335                        }
336                }   catch (IOException e) {
337                        throw new DataFormatException("Could not close the output file. ", e);
338                } 
339        }
340}