001 package org.maltparser.core.symbol.trie; 002 003 import java.io.BufferedReader; 004 import java.io.BufferedWriter; 005 import java.io.IOException; 006 import java.util.Set; 007 import java.util.SortedMap; 008 import java.util.TreeMap; 009 010 import org.apache.log4j.Logger; 011 import org.maltparser.core.exception.MaltChainedException; 012 import org.maltparser.core.io.dataformat.ColumnDescription; 013 import org.maltparser.core.symbol.SymbolException; 014 import org.maltparser.core.symbol.SymbolTable; 015 import org.maltparser.core.symbol.nullvalue.InputNullValues; 016 import org.maltparser.core.symbol.nullvalue.NullValues; 017 import org.maltparser.core.symbol.nullvalue.OutputNullValues; 018 import org.maltparser.core.symbol.nullvalue.NullValues.NullValueId; 019 /** 020 021 @author Johan Hall 022 @since 1.0 023 */ 024 public class TrieSymbolTable implements SymbolTable { 025 private final String name; 026 private final Trie trie; 027 private final SortedMap<Integer, TrieNode> codeTable; 028 private int columnCategory; 029 private final NullValues nullValues; 030 private int valueCounter; 031 /** Cache the hash code for the symbol table */ 032 private int cachedHash; 033 034 public TrieSymbolTable(String name, Trie trie, int columnCategory, String nullValueStrategy) throws MaltChainedException { 035 this.name = name; 036 this.trie = trie; 037 this.columnCategory = columnCategory; 038 codeTable = new TreeMap<Integer, TrieNode>(); 039 if (columnCategory == ColumnDescription.INPUT) { 040 nullValues = new InputNullValues(nullValueStrategy, this); 041 } else if (columnCategory == ColumnDescription.DEPENDENCY_EDGE_LABEL) { 042 nullValues = new OutputNullValues(nullValueStrategy, this); 043 } else { 044 nullValues = new InputNullValues(nullValueStrategy, this); 045 } 046 valueCounter = nullValues.getNextCode(); 047 } 048 049 public TrieSymbolTable(String name, Trie trie) { 050 this.name = name; 051 this.trie = trie; 052 codeTable = new TreeMap<Integer, TrieNode>(); 053 nullValues = new InputNullValues("one", this); 054 valueCounter = 1; 055 } 056 057 public int addSymbol(String symbol) throws MaltChainedException { 058 if (nullValues == null || !nullValues.isNullValue(symbol)) { 059 if (symbol == null || symbol.length() == 0) { 060 throw new SymbolException("Symbol table error: empty string cannot be added to the symbol table"); 061 } 062 final TrieNode node = trie.addValue(symbol, this, -1); 063 final int code = node.getEntry(this); 064 if (!codeTable.containsKey(code)) { 065 codeTable.put(code, node); 066 } 067 return code; 068 } else { 069 return nullValues.symbolToCode(symbol); 070 } 071 } 072 073 public int addSymbol(StringBuilder symbol) throws MaltChainedException { 074 if (nullValues == null || !nullValues.isNullValue(symbol)) { 075 if (symbol == null || symbol.length() == 0) { 076 throw new SymbolException("Symbol table error: empty string cannot be added to the symbol table"); 077 } 078 final TrieNode node = trie.addValue(symbol, this, -1); 079 final int code = node.getEntry(this); 080 if (!codeTable.containsKey(code)) { 081 codeTable.put(code, node); 082 } 083 return code; 084 } else { 085 return nullValues.symbolToCode(symbol); 086 } 087 } 088 089 public String getSymbolCodeToString(int code) throws MaltChainedException { 090 if (code >= 0) { 091 if (nullValues == null || !nullValues.isNullValue(code)) { 092 if (trie == null) { 093 throw new SymbolException("The symbol table is corrupt. "); 094 } 095 return trie.getValue(codeTable.get(code), this); 096 } else { 097 return nullValues.codeToSymbol(code); 098 } 099 } else { 100 throw new SymbolException("The symbol code '"+code+"' cannot be found in the symbol table. "); 101 } 102 } 103 104 public int getSymbolStringToCode(String symbol) throws MaltChainedException { 105 if (symbol != null) { 106 if (nullValues == null || !nullValues.isNullValue(symbol)) { 107 if (trie == null) { 108 throw new SymbolException("The symbol table is corrupt. "); 109 } 110 final Integer entry = trie.getEntry(symbol, this); 111 if (entry == null) { 112 throw new SymbolException("Could not find the symbol '"+symbol+"' in the symbol table. "); 113 } 114 return entry; //.getCode(); 115 } else { 116 return nullValues.symbolToCode(symbol); 117 } 118 } else { 119 throw new SymbolException("The symbol code '"+symbol+"' cannot be found in the symbol table. "); 120 } 121 } 122 123 public String getNullValueStrategy() { 124 if (nullValues == null) { 125 return null; 126 } 127 return nullValues.getNullValueStrategy(); 128 } 129 130 131 public int getColumnCategory() { 132 return columnCategory; 133 } 134 135 public void printSymbolTable(Logger logger) throws MaltChainedException { 136 for (Integer code : codeTable.keySet()) { 137 logger.info(code+"\t"+trie.getValue(codeTable.get(code), this)+"\n"); 138 } 139 } 140 141 public void saveHeader(BufferedWriter out) throws MaltChainedException { 142 try { 143 out.append('\t'); 144 out.append(getName()); 145 out.append('\t'); 146 out.append(Integer.toString(getColumnCategory())); 147 out.append('\t'); 148 out.append(getNullValueStrategy()); 149 out.append('\n'); 150 } catch (IOException e) { 151 throw new SymbolException("Could not save the symbol table. ", e); 152 } 153 } 154 155 public int size() { 156 return codeTable.size(); 157 } 158 159 public void save(BufferedWriter out) throws MaltChainedException { 160 try { 161 out.write(name); 162 out.write('\n'); 163 for (Integer code : codeTable.keySet()) { 164 out.write(code+""); 165 out.write('\t'); 166 out.write(trie.getValue(codeTable.get(code), this)); 167 out.write('\n'); 168 } 169 out.write('\n'); 170 } catch (IOException e) { 171 throw new SymbolException("Could not save the symbol table. ", e); 172 } 173 } 174 175 public void load(BufferedReader in) throws MaltChainedException { 176 int max = 0; 177 int index = 0; 178 String fileLine; 179 try { 180 while ((fileLine = in.readLine()) != null) { 181 if (fileLine.length() == 0 || (index = fileLine.indexOf('\t')) == -1) { 182 setValueCounter(max+1); 183 break; 184 } 185 int code = Integer.parseInt(fileLine.substring(0,index)); 186 final String str = fileLine.substring(index+1); 187 final TrieNode node = trie.addValue(str, this, code); 188 codeTable.put(node.getEntry(this), node); //.getCode(), node); 189 if (max < code) { 190 max = code; 191 } 192 } 193 } catch (NumberFormatException e) { 194 throw new SymbolException("The symbol table file (.sym) contains a non-integer value in the first column. ", e); 195 } catch (IOException e) { 196 throw new SymbolException("Could not load the symbol table. ", e); 197 } 198 } 199 200 public String getName() { 201 return name; 202 } 203 204 public int getValueCounter() { 205 return valueCounter; 206 } 207 208 private void setValueCounter(int valueCounter) { 209 this.valueCounter = valueCounter; 210 } 211 212 protected void updateValueCounter(int code) { 213 if (code > valueCounter) { 214 valueCounter = code; 215 } 216 } 217 218 protected int increaseValueCounter() { 219 return valueCounter++; 220 } 221 222 public int getNullValueCode(NullValueId nullValueIdentifier) throws MaltChainedException { 223 if (nullValues == null) { 224 throw new SymbolException("The symbol table does not have any null-values. "); 225 } 226 return nullValues.nullvalueToCode(nullValueIdentifier); 227 } 228 229 public String getNullValueSymbol(NullValueId nullValueIdentifier) throws MaltChainedException { 230 if (nullValues == null) { 231 throw new SymbolException("The symbol table does not have any null-values. "); 232 } 233 return nullValues.nullvalueToSymbol(nullValueIdentifier); 234 } 235 236 public boolean isNullValue(String symbol) throws MaltChainedException { 237 if (nullValues != null) { 238 return nullValues.isNullValue(symbol); 239 } 240 return false; 241 } 242 243 public boolean isNullValue(int code) throws MaltChainedException { 244 if (nullValues != null) { 245 return nullValues.isNullValue(code); 246 } 247 return false; 248 } 249 250 public void copy(SymbolTable fromTable) throws MaltChainedException { 251 final SortedMap<Integer, TrieNode> fromCodeTable = ((TrieSymbolTable)fromTable).getCodeTable(); 252 int max = getValueCounter()-1; 253 for (Integer code : fromCodeTable.keySet()) { 254 final String str = trie.getValue(fromCodeTable.get(code), this); 255 final TrieNode node = trie.addValue(str, this, code); 256 codeTable.put(node.getEntry(this), node); //.getCode(), node); 257 if (max < code) { 258 max = code; 259 } 260 } 261 setValueCounter(max+1); 262 } 263 264 public SortedMap<Integer, TrieNode> getCodeTable() { 265 return codeTable; 266 } 267 268 public Set<Integer> getCodes() { 269 return codeTable.keySet(); 270 } 271 272 protected Trie getTrie() { 273 return trie; 274 } 275 276 public boolean equals(Object obj) { 277 if (this == obj) 278 return true; 279 if (obj == null) 280 return false; 281 if (getClass() != obj.getClass()) 282 return false; 283 final TrieSymbolTable other = (TrieSymbolTable)obj; 284 return ((name == null) ? other.name == null : name.equals(other.name)); 285 } 286 287 public int hashCode() { 288 if (cachedHash == 0) { 289 cachedHash = 31 * 7 + (null == name ? 0 : name.hashCode()); 290 } 291 return cachedHash; 292 } 293 294 public String toString() { 295 final StringBuilder sb = new StringBuilder(); 296 sb.append(name); 297 sb.append(" "); 298 sb.append(valueCounter); 299 return sb.toString(); 300 } 301 }