001    package org.maltparser.core.symbol.trie;
002    
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
016    
017    import org.apache.log4j.Logger;
018    
019    import org.maltparser.core.exception.MaltChainedException;
020    import org.maltparser.core.helper.HashMap;
021    import org.maltparser.core.symbol.SymbolException;
022    import org.maltparser.core.symbol.SymbolTable;
023    import org.maltparser.core.symbol.SymbolTableHandler;
024    
025    
026    /**
027    
028    @author Johan Hall
029    @since 1.0
030    */
031    public class TrieSymbolTableHandler implements SymbolTableHandler {
032            private final Trie trie;
033            private final HashMap<String, TrieSymbolTable> symbolTables;
034            
035            public TrieSymbolTableHandler() {
036                    trie = new Trie();
037                    symbolTables = new HashMap<String, TrieSymbolTable>();
038            }
039    
040            public TrieSymbolTable addSymbolTable(String tableName) throws MaltChainedException {
041                    TrieSymbolTable symbolTable = symbolTables.get(tableName);
042                    if (symbolTable == null) {
043                            symbolTable = new TrieSymbolTable(tableName, trie);
044                            symbolTables.put(tableName, symbolTable);
045                    }
046                    return symbolTable;
047            }
048            
049            public TrieSymbolTable addSymbolTable(String tableName, SymbolTable parentTable) throws MaltChainedException {
050                    TrieSymbolTable symbolTable = symbolTables.get(tableName);
051                    if (symbolTable == null) {
052                            TrieSymbolTable trieParentTable = (TrieSymbolTable)parentTable;
053                            symbolTable = new TrieSymbolTable(tableName, trie, trieParentTable.getColumnCategory(), trieParentTable.getNullValueStrategy());
054                            symbolTables.put(tableName, symbolTable);
055                    }
056                    return symbolTable;
057            }
058            
059            public TrieSymbolTable addSymbolTable(String tableName, int columnCategory, String nullValueStrategy) throws MaltChainedException {
060                    TrieSymbolTable symbolTable = symbolTables.get(tableName);
061                    if (symbolTable == null) {
062                            symbolTable = new TrieSymbolTable(tableName, trie, columnCategory, nullValueStrategy);
063                            symbolTables.put(tableName, symbolTable);
064                    }
065                    return symbolTable;
066            }
067            
068            public TrieSymbolTable getSymbolTable(String tableName) {
069                    return symbolTables.get(tableName);
070            }
071            
072            public Set<String> getSymbolTableNames() {
073                    return symbolTables.keySet();
074            }
075            
076            public void save(OutputStreamWriter osw) throws MaltChainedException  {
077                    try {
078                            BufferedWriter bout = new BufferedWriter(osw);
079                            for (TrieSymbolTable table : symbolTables.values()) {
080                                    table.saveHeader(bout);
081                            }
082                            bout.write('\n');
083                            for (TrieSymbolTable table : symbolTables.values()) {
084                                    table.save(bout);
085                            }
086                            bout.close();
087                    } catch (IOException e) {
088                            throw new SymbolException("Could not save the symbol tables. ", e);
089                    }               
090            }
091            
092            public void save(String fileName, String charSet) throws MaltChainedException  {
093                    try {
094                            save(new OutputStreamWriter(new FileOutputStream(fileName), charSet));
095                    } catch (FileNotFoundException e) {
096                            throw new SymbolException("The symbol table file '"+fileName+"' cannot be created. ", e);
097                    } catch (UnsupportedEncodingException e) {
098                            throw new SymbolException("The char set '"+charSet+"' is not supported. ", e);
099                    }
100            }
101            
102            public void loadHeader(BufferedReader bin) throws MaltChainedException {
103                    String fileLine = "";
104                    Pattern tabPattern = Pattern.compile("\t");
105                    try {
106                            while ((fileLine = bin.readLine()) != null) {
107                                    if (fileLine.length() == 0 || fileLine.charAt(0) != '\t') {
108                                            break;
109                                    }
110                                    String items[];
111                                    try {
112                                            items = tabPattern.split(fileLine.substring(1));
113                                    } catch (PatternSyntaxException e) {
114                                            throw new SymbolException("The header line of the symbol table  '"+fileLine.substring(1)+"' could not split into atomic parts. ", e);
115                                    }
116                                    if (items.length != 3) {
117                                            throw new SymbolException("The header line of the symbol table  '"+fileLine.substring(1)+"' must contain four columns. ");
118                                    }
119                                    addSymbolTable(items[0], Integer.parseInt(items[1]), items[2]);
120                            }
121                    } catch (NumberFormatException e) {
122                            throw new SymbolException("The symbol table file (.sym) contains a non-integer value in the header. ", e);
123                    } catch (IOException e) {
124                            throw new SymbolException("Could not load the symbol table. ", e);
125                    }
126            }
127            
128            
129            public void load(InputStreamReader isr) throws MaltChainedException  {
130                    try {
131                            BufferedReader bin = new BufferedReader(isr);
132                            String fileLine;
133                            SymbolTable table = null;
134                            bin.mark(2);
135                            if (bin.read() == '\t') {
136                                    bin.reset();
137                                    loadHeader(bin);
138                            } else {
139                                    bin.reset();
140                            }
141                            while ((fileLine = bin.readLine()) != null) {
142                                    if (fileLine.length() > 0) {
143                                            table = addSymbolTable(fileLine);
144                                            table.load(bin);
145                                    }
146                            }
147                            bin.close();
148                    } catch (IOException e) {
149                            throw new SymbolException("Could not load the symbol tables. ", e);
150                    }                       
151            }
152            
153            public void load(String fileName, String charSet) throws MaltChainedException  {
154                    try {
155                            load(new InputStreamReader(new FileInputStream(fileName), charSet));
156    
157                    } catch (FileNotFoundException e) {
158                            throw new SymbolException("The symbol table file '"+fileName+"' cannot be found. ", e);
159                    } catch (UnsupportedEncodingException e) {
160                            throw new SymbolException("The char set '"+charSet+"' is not supported. ", e);
161                    }               
162            }
163            
164            
165            public SymbolTable loadTagset(String fileName, String tableName, String charSet, int columnCategory, String nullValueStrategy) throws MaltChainedException {
166                    try {
167                            BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(fileName), charSet));
168                            String fileLine;
169                            TrieSymbolTable table = addSymbolTable(tableName, columnCategory, nullValueStrategy);
170    
171                            while ((fileLine = br.readLine()) != null) {
172                                    table.addSymbol(fileLine.trim());
173                            }
174                            return table;
175                    } catch (FileNotFoundException e) {
176                            throw new SymbolException("The tagset file '"+fileName+"' cannot be found. ", e);
177                    } catch (UnsupportedEncodingException e) {
178                            throw new SymbolException("The char set '"+charSet+"' is not supported. ", e);
179                    } catch (IOException e) {
180                            throw new SymbolException("The tagset file '"+fileName+"' cannot be loaded. ", e);
181                    }
182            }
183            
184            public void printSymbolTables(Logger logger) throws MaltChainedException  {
185                    for (TrieSymbolTable table : symbolTables.values()) {
186                            table.printSymbolTable(logger);
187                    }       
188            }
189    }