001    package org.maltparser.core.symbol.trie;
002    
003    import java.io.BufferedReader;
004    import java.io.BufferedWriter;
005    import java.io.IOException;
006    import java.util.Set;
007    import java.util.SortedMap;
008    import java.util.TreeMap;
009    
010    import org.apache.log4j.Logger;
011    import org.maltparser.core.exception.MaltChainedException;
012    import org.maltparser.core.io.dataformat.ColumnDescription;
013    import org.maltparser.core.symbol.SymbolException;
014    import org.maltparser.core.symbol.SymbolTable;
015    import org.maltparser.core.symbol.nullvalue.InputNullValues;
016    import org.maltparser.core.symbol.nullvalue.NullValues;
017    import org.maltparser.core.symbol.nullvalue.OutputNullValues;
018    import org.maltparser.core.symbol.nullvalue.NullValues.NullValueId;
019    /**
020    
021    @author Johan Hall
022    @since 1.0
023    */
024    public class TrieSymbolTable implements SymbolTable {
025            private final String name;
026            private final Trie trie;
027            private final SortedMap<Integer, TrieNode> codeTable;
028            private int columnCategory;
029            private final NullValues nullValues;
030            private int valueCounter;
031        /** Cache the hash code for the symbol table */
032        private int cachedHash;
033        
034            public TrieSymbolTable(String name, Trie trie, int columnCategory, String nullValueStrategy) throws MaltChainedException {
035                    this.name = name;
036                    this.trie = trie;
037                    this.columnCategory = columnCategory;
038                    codeTable = new TreeMap<Integer, TrieNode>();
039                    if (columnCategory == ColumnDescription.INPUT) {
040                            nullValues = new InputNullValues(nullValueStrategy, this);
041                    } else if (columnCategory == ColumnDescription.DEPENDENCY_EDGE_LABEL) {
042                            nullValues = new OutputNullValues(nullValueStrategy, this);
043                    } else {
044                            nullValues = new InputNullValues(nullValueStrategy, this);
045                    }
046                    valueCounter = nullValues.getNextCode();
047            }
048            
049            public TrieSymbolTable(String name, Trie trie) {
050                    this.name = name;
051                    this.trie = trie;
052                    codeTable = new TreeMap<Integer, TrieNode>();
053                    nullValues = new InputNullValues("one", this);
054                    valueCounter = 1;
055            }
056            
057            public int addSymbol(String symbol) throws MaltChainedException {
058                    if (nullValues == null || !nullValues.isNullValue(symbol)) {
059                            if (symbol == null || symbol.length() == 0) {
060                                    throw new SymbolException("Symbol table error: empty string cannot be added to the symbol table");
061                            }
062                            final TrieNode node = trie.addValue(symbol, this, -1);
063                            final int code = node.getEntry(this); 
064                            if (!codeTable.containsKey(code)) {
065                                    codeTable.put(code, node);
066                            }
067                            return code;
068                    } else {
069                            return nullValues.symbolToCode(symbol);
070                    }
071            }
072            
073            public int addSymbol(StringBuilder symbol) throws MaltChainedException {
074                    if (nullValues == null || !nullValues.isNullValue(symbol)) {
075                            if (symbol == null || symbol.length() == 0) {
076                                    throw new SymbolException("Symbol table error: empty string cannot be added to the symbol table");
077                            }
078                            final TrieNode node = trie.addValue(symbol, this, -1);
079                            final int code = node.getEntry(this);
080                            if (!codeTable.containsKey(code)) {
081                                    codeTable.put(code, node);
082                            }
083                            return code;
084                    } else {
085                            return nullValues.symbolToCode(symbol);
086                    }
087            }
088            
089            public String getSymbolCodeToString(int code) throws MaltChainedException {
090                    if (code >= 0) {
091                            if (nullValues == null || !nullValues.isNullValue(code)) {
092                                    if (trie == null) {
093                                            throw new SymbolException("The symbol table is corrupt. ");
094                                    }
095                                    return trie.getValue(codeTable.get(code), this);
096                            } else {
097                                    return nullValues.codeToSymbol(code);
098                            }
099                    } else {
100                            throw new SymbolException("The symbol code '"+code+"' cannot be found in the symbol table. ");
101                    }
102            }
103            
104            public int getSymbolStringToCode(String symbol) throws MaltChainedException {
105                    if (symbol != null) {
106                            if (nullValues == null || !nullValues.isNullValue(symbol)) {
107                                    if (trie == null) {
108                                            throw new SymbolException("The symbol table is corrupt. ");
109                                    } 
110                                    final Integer entry = trie.getEntry(symbol, this);
111                                    if (entry == null) {
112                                            throw new SymbolException("Could not find the symbol '"+symbol+"' in the symbol table. ");
113                                    }
114                                    return entry; //.getCode();                             
115                            } else {
116                                    return nullValues.symbolToCode(symbol);
117                            }
118                    } else {
119                            throw new SymbolException("The symbol code '"+symbol+"' cannot be found in the symbol table. ");
120                    }
121            }
122    
123            public String getNullValueStrategy() {
124                    if (nullValues == null) {
125                            return null;
126                    }
127                    return nullValues.getNullValueStrategy();
128            }
129            
130            
131            public int getColumnCategory() {
132                    return columnCategory;
133            }
134            
135            public void printSymbolTable(Logger logger) throws MaltChainedException {
136                    for (Integer code : codeTable.keySet()) {
137                            logger.info(code+"\t"+trie.getValue(codeTable.get(code), this)+"\n");
138                    }
139            }
140            
141            public void saveHeader(BufferedWriter out) throws MaltChainedException  {
142                    try {
143                            out.append('\t');
144                            out.append(getName());
145                            out.append('\t');
146                            out.append(Integer.toString(getColumnCategory()));
147                            out.append('\t');
148                            out.append(getNullValueStrategy());
149                            out.append('\n');
150                    } catch (IOException e) {
151                            throw new SymbolException("Could not save the symbol table. ", e);
152                    }
153            }
154            
155            public int size() {
156                    return codeTable.size();
157            }
158            
159            public void save(BufferedWriter out) throws MaltChainedException  {
160                    try {
161                            out.write(name);
162                            out.write('\n');
163                            for (Integer code : codeTable.keySet()) {
164                                    out.write(code+"");
165                                    out.write('\t');
166                                    out.write(trie.getValue(codeTable.get(code), this));
167                                    out.write('\n');
168                            }
169                            out.write('\n');
170                    } catch (IOException e) {
171                            throw new SymbolException("Could not save the symbol table. ", e);
172                    }
173            }
174            
175            public void load(BufferedReader in) throws MaltChainedException {
176                    int max = 0;
177                    int index = 0;
178                    String fileLine;
179                    try {
180                            while ((fileLine = in.readLine()) != null) {
181                                    if (fileLine.length() == 0 || (index = fileLine.indexOf('\t')) == -1) {
182                                            setValueCounter(max+1);
183                                            break;
184                                    }
185                                    int code = Integer.parseInt(fileLine.substring(0,index));
186                                    final String str = fileLine.substring(index+1);
187                                    final TrieNode node = trie.addValue(str, this, code);
188                                    codeTable.put(node.getEntry(this), node); //.getCode(), node);
189                                    if (max < code) {
190                                            max = code;
191                                    }
192                            }
193                    } catch (NumberFormatException e) {
194                            throw new SymbolException("The symbol table file (.sym) contains a non-integer value in the first column. ", e);
195                    } catch (IOException e) {
196                            throw new SymbolException("Could not load the symbol table. ", e);
197                    }
198            }
199            
200            public String getName() {
201                    return name;
202            }
203    
204            public int getValueCounter() {
205                    return valueCounter;
206            }
207    
208            private void setValueCounter(int valueCounter) {
209                    this.valueCounter = valueCounter;
210            }
211            
212            protected void updateValueCounter(int code) {
213                    if (code > valueCounter) {
214                            valueCounter = code;
215                    }
216            }
217            
218            protected int increaseValueCounter() {
219                    return valueCounter++;
220            }
221            
222            public int getNullValueCode(NullValueId nullValueIdentifier) throws MaltChainedException {
223                    if (nullValues == null) {
224                            throw new SymbolException("The symbol table does not have any null-values. ");
225                    }
226                    return nullValues.nullvalueToCode(nullValueIdentifier);
227            }
228            
229            public String getNullValueSymbol(NullValueId nullValueIdentifier) throws MaltChainedException {
230                    if (nullValues == null) {
231                            throw new SymbolException("The symbol table does not have any null-values. ");
232                    }
233                    return nullValues.nullvalueToSymbol(nullValueIdentifier);
234            }
235            
236            public boolean isNullValue(String symbol) throws MaltChainedException {
237                    if (nullValues != null) {
238                            return nullValues.isNullValue(symbol);
239                    } 
240                    return false;
241            }
242            
243            public boolean isNullValue(int code) throws MaltChainedException {
244                    if (nullValues != null) {
245                            return nullValues.isNullValue(code);
246                    } 
247                    return false;
248            }
249            
250            public void copy(SymbolTable fromTable) throws MaltChainedException {
251                    final SortedMap<Integer, TrieNode> fromCodeTable =  ((TrieSymbolTable)fromTable).getCodeTable();
252                    int max = getValueCounter()-1;
253                    for (Integer code : fromCodeTable.keySet()) {
254                            final String str = trie.getValue(fromCodeTable.get(code), this);
255                            final TrieNode node = trie.addValue(str, this, code);
256                            codeTable.put(node.getEntry(this), node); //.getCode(), node);
257                            if (max < code) {
258                                    max = code;
259                            }
260                    }
261                    setValueCounter(max+1);
262            }
263    
264            public SortedMap<Integer, TrieNode> getCodeTable() {
265                    return codeTable;
266            }
267            
268            public Set<Integer> getCodes() {
269                    return codeTable.keySet();
270            }
271            
272            protected Trie getTrie() {
273                    return trie;
274            }
275            
276            public boolean equals(Object obj) {
277                    if (this == obj)
278                            return true;
279                    if (obj == null)
280                            return false;
281                    if (getClass() != obj.getClass())
282                            return false;
283                    final TrieSymbolTable other = (TrieSymbolTable)obj;
284                    return ((name == null) ? other.name == null : name.equals(other.name));
285            }
286    
287            public int hashCode() {
288                    if (cachedHash == 0) {
289                            cachedHash = 217 + (null == name ? 0 : name.hashCode());
290                    }
291                    return cachedHash;
292            }
293            
294            public String toString() {
295                    final StringBuilder sb = new StringBuilder();
296                    sb.append(name);
297                    sb.append(' ');
298                    sb.append(valueCounter);
299                    return sb.toString();
300            }
301    }