001package org.maltparser.core.symbol.trie;
002
003import java.io.BufferedReader;
004import java.io.BufferedWriter;
005import java.io.FileInputStream;
006import java.io.FileNotFoundException;
007import java.io.FileOutputStream;
008import java.io.UnsupportedEncodingException;
009import java.io.IOException;
010import java.io.InputStreamReader;
011import java.io.OutputStreamWriter;
012import java.util.Set;
013import java.util.regex.Pattern;
014import java.util.regex.PatternSyntaxException;
015
016import org.maltparser.core.exception.MaltChainedException;
017import org.maltparser.core.helper.HashMap;
018import org.maltparser.core.symbol.SymbolException;
019import org.maltparser.core.symbol.SymbolTable;
020import org.maltparser.core.symbol.SymbolTableHandler;
021
022
023/**
024
025@author Johan Hall
026*/
027public class TrieSymbolTableHandler implements SymbolTableHandler {
028        private final Trie trie;
029        private final HashMap<String, TrieSymbolTable> symbolTables;
030
031        public TrieSymbolTableHandler() { 
032                trie = new Trie();
033                symbolTables = new HashMap<String, TrieSymbolTable>();
034        }
035
036        public TrieSymbolTable addSymbolTable(String tableName) throws MaltChainedException {
037                TrieSymbolTable symbolTable = symbolTables.get(tableName);
038                if (symbolTable == null) {
039                        symbolTable = new TrieSymbolTable(tableName, trie); 
040                        symbolTables.put(tableName, symbolTable);
041                }
042                return symbolTable;
043        }
044        
045        public TrieSymbolTable addSymbolTable(String tableName, SymbolTable parentTable) throws MaltChainedException {
046                TrieSymbolTable symbolTable = symbolTables.get(tableName);
047                if (symbolTable == null) {
048                        TrieSymbolTable trieParentTable = (TrieSymbolTable)parentTable;
049                        symbolTable = new TrieSymbolTable(tableName, trie, trieParentTable.getCategory(), trieParentTable.getNullValueStrategy());
050                        symbolTables.put(tableName, symbolTable);
051                }
052                return symbolTable;
053        }
054        
055        public TrieSymbolTable addSymbolTable(String tableName, int columnCategory, int columnType, String nullValueStrategy) throws MaltChainedException {
056                TrieSymbolTable symbolTable = symbolTables.get(tableName);
057                if (symbolTable == null) {
058                        symbolTable = new TrieSymbolTable(tableName, trie, columnCategory, nullValueStrategy);
059                        symbolTables.put(tableName, symbolTable);
060                }
061                return symbolTable;
062        }
063        
064        public TrieSymbolTable getSymbolTable(String tableName) {
065                return symbolTables.get(tableName);
066        }
067        
068        public Set<String> getSymbolTableNames() {
069                return symbolTables.keySet();
070        }
071        
072        public void cleanUp() {
073        }
074        
075        public void save(OutputStreamWriter osw) throws MaltChainedException  {
076                try {
077                        BufferedWriter bout = new BufferedWriter(osw);
078                        for (TrieSymbolTable table : symbolTables.values()) {
079                                table.saveHeader(bout);
080                        }
081                        bout.write('\n');
082                        for (TrieSymbolTable table : symbolTables.values()) {
083                                table.save(bout);
084                        }
085                        bout.close();
086                } catch (IOException e) {
087                        throw new SymbolException("Could not save the symbol tables. ", e);
088                }               
089        }
090        
091        public void save(String fileName, String charSet) throws MaltChainedException  {
092                try {
093                        save(new OutputStreamWriter(new FileOutputStream(fileName), charSet));
094                } catch (FileNotFoundException e) {
095                        throw new SymbolException("The symbol table file '"+fileName+"' cannot be created. ", e);
096                } catch (UnsupportedEncodingException e) {
097                        throw new SymbolException("The char set '"+charSet+"' is not supported. ", e);
098                }
099        }
100        
101        public void loadHeader(BufferedReader bin) throws MaltChainedException {
102                String fileLine = "";
103                Pattern tabPattern = Pattern.compile("\t");
104                try {
105                        while ((fileLine = bin.readLine()) != null) {
106                                if (fileLine.length() == 0 || fileLine.charAt(0) != '\t') {
107                                        break;
108                                }
109                                String items[];
110                                try {
111                                        items = tabPattern.split(fileLine.substring(1));
112                                } catch (PatternSyntaxException e) {
113                                        throw new SymbolException("The header line of the symbol table  '"+fileLine.substring(1)+"' could not split into atomic parts. ", e);
114                                }
115                                if (items.length == 4)
116                                        addSymbolTable(items[0], Integer.parseInt(items[1]), Integer.parseInt(items[2]), items[3]);
117                                else if (items.length == 3) 
118                                        addSymbolTable(items[0], Integer.parseInt(items[1]), SymbolTable.STRING, items[2]);
119                                else
120                                        throw new SymbolException("The header line of the symbol table  '"+fileLine.substring(1)+"' must contain three or four columns. ");
121
122                        }
123                } catch (NumberFormatException e) {
124                        throw new SymbolException("The symbol table file (.sym) contains a non-integer value in the header. ", e);
125                } catch (IOException e) {
126                        throw new SymbolException("Could not load the symbol table. ", e);
127                }
128        }
129        
130        
131        public void load(InputStreamReader isr) throws MaltChainedException  {
132                try {
133                        BufferedReader bin = new BufferedReader(isr);
134                        String fileLine;
135                        SymbolTable table = null;
136                        bin.mark(2);
137                        if (bin.read() == '\t') {
138                                bin.reset();
139                                loadHeader(bin);
140                        } else {
141                                bin.reset();
142                        }
143                        while ((fileLine = bin.readLine()) != null) {
144                                if (fileLine.length() > 0) {
145                                        table = addSymbolTable(fileLine);
146                                        table.load(bin);
147                                }
148                        }
149                        bin.close();
150                } catch (IOException e) {
151                        throw new SymbolException("Could not load the symbol tables. ", e);
152                }                       
153        }
154        
155        public void load(String fileName, String charSet) throws MaltChainedException  {
156                try {
157                        load(new InputStreamReader(new FileInputStream(fileName), charSet));
158                } catch (FileNotFoundException e) {
159                        throw new SymbolException("The symbol table file '"+fileName+"' cannot be found. ", e);
160                } catch (UnsupportedEncodingException e) {
161                        throw new SymbolException("The char set '"+charSet+"' is not supported. ", e);
162                }               
163        }
164        
165        
166        public SymbolTable loadTagset(String fileName, String tableName, String charSet, int columnCategory, int columnType, String nullValueStrategy) throws MaltChainedException {
167                try {
168                        BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(fileName), charSet));
169                        String fileLine;
170                        TrieSymbolTable table = addSymbolTable(tableName, columnCategory, columnType, nullValueStrategy);
171
172                        while ((fileLine = br.readLine()) != null) {
173                                table.addSymbol(fileLine.trim());
174                        }
175                        br.close();
176                        return table;
177                } catch (FileNotFoundException e) {
178                        throw new SymbolException("The tagset file '"+fileName+"' cannot be found. ", e);
179                } catch (UnsupportedEncodingException e) {
180                        throw new SymbolException("The char set '"+charSet+"' is not supported. ", e);
181                } catch (IOException e) {
182                        throw new SymbolException("The tagset file '"+fileName+"' cannot be loaded. ", e);
183                }
184        }
185        
186//      public String printSymbolTables() throws MaltChainedException  {
187//              StringBuilder sb = new StringBuilder();
188//              for (TrieSymbolTable table : symbolTables.values()) {
189//                      sb.append(table.printSymbolTable());
190//              }
191//              return sb.toString();
192//      }
193}