package org.maltparser.core.symbol.trie;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

import org.apache.log4j.Logger;
import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.io.dataformat.ColumnDescription;
import org.maltparser.core.symbol.SymbolException;
import org.maltparser.core.symbol.SymbolTable;
import org.maltparser.core.symbol.nullvalue.InputNullValues;
import org.maltparser.core.symbol.nullvalue.NullValues;
import org.maltparser.core.symbol.nullvalue.NullValues.NullValueId;
import org.maltparser.core.symbol.nullvalue.OutputNullValues;

/**
 * A {@link SymbolTable} that stores its symbols in a {@link Trie} and keeps a
 * sorted map from integer code to the trie node holding the symbol.
 *
 * <p>Symbols recognised by the table's {@link NullValues} object are never
 * stored in the trie; every lookup is delegated to that object instead.
 *
 * <p>Two tables are considered equal when they have the same name.
 *
 * @author Johan Hall
 * @since 1.0
 */
public class TrieSymbolTable implements SymbolTable {
    private final String name;
    private final Trie trie;
    /** Code -> trie node, iterated in ascending code order. */
    private final SortedMap<Integer, TrieNode> codeTable;
    private int columnCategory;
    private NullValues nullValues;
    /** Counter handed out by {@link #increaseValueCounter()}. */
    private int valueCounter;
    /** Cache the hash code for the symbol table */
    private int cachedHash;

    /**
     * Creates a symbol table without a root label.
     *
     * @param name the name of the symbol table
     * @param trie the trie that stores the symbols
     * @param columnCategory one of the {@link ColumnDescription} category constants
     * @param nullValueStrategy the strategy passed on to the null-value object
     * @throws MaltChainedException if the null-value object cannot be created
     */
    public TrieSymbolTable(String name, Trie trie, int columnCategory, String nullValueStrategy) throws MaltChainedException {
        this(name, trie, columnCategory, nullValueStrategy, null);
    }

    /**
     * Creates a symbol table with an optional root label.
     *
     * <p>{@link ColumnDescription#DEPENDENCY_EDGE_LABEL} tables get
     * {@link OutputNullValues} (which receives {@code rootLabel}); every other
     * category gets {@link InputNullValues}, so {@code nullValues} is never
     * left unset before the value counter is initialised from it.
     *
     * @param name the name of the symbol table
     * @param trie the trie that stores the symbols
     * @param columnCategory one of the {@link ColumnDescription} category constants
     * @param nullValueStrategy the strategy passed on to the null-value object
     * @param rootLabel the root label handed to {@link OutputNullValues}; may be {@code null}
     * @throws MaltChainedException if the null-value object cannot be created
     */
    public TrieSymbolTable(String name, Trie trie, int columnCategory, String nullValueStrategy, String rootLabel) throws MaltChainedException {
        this.name = name;
        this.trie = trie;
        this.columnCategory = columnCategory;
        codeTable = new TreeMap<Integer, TrieNode>();
        if (columnCategory == ColumnDescription.DEPENDENCY_EDGE_LABEL) {
            nullValues = new OutputNullValues(nullValueStrategy, this, rootLabel);
        } else {
            // INPUT and any other category share the input null values.
            nullValues = new InputNullValues(nullValueStrategy, this);
        }
        valueCounter = nullValues.getNextCode();
    }

    /**
     * Creates a symbol table that uses {@link InputNullValues} with the
     * strategy {@code "one"} and starts its value counter at 1. The column
     * category is left at its default value.
     *
     * @param name the name of the symbol table
     * @param trie the trie that stores the symbols
     */
    public TrieSymbolTable(String name, Trie trie) {
        this.name = name;
        this.trie = trie;
        codeTable = new TreeMap<Integer, TrieNode>();
        nullValues = new InputNullValues("one", this);
        valueCounter = 1;
    }

    /**
     * Adds a symbol to the table and returns its code. Null-value symbols are
     * not stored; their code is obtained from the null-value object.
     *
     * @param symbol the symbol to add
     * @return the code of the symbol
     * @throws MaltChainedException if the symbol cannot be added
     */
    public int addSymbol(String symbol) throws MaltChainedException {
        if (nullValues != null && nullValues.isNullValue(symbol)) {
            return nullValues.symbolToCode(symbol);
        }
        // -1 asks the trie to pick the code itself.
        final TrieNode node = trie.addValue(symbol, this, -1);
        return register(node);
    }

    /**
     * Adds a symbol to the table and returns its code. Null-value symbols are
     * not stored; their code is obtained from the null-value object.
     *
     * @param symbol the symbol to add
     * @return the code of the symbol
     * @throws MaltChainedException if the symbol cannot be added
     */
    public int addSymbol(StringBuilder symbol) throws MaltChainedException {
        if (nullValues != null && nullValues.isNullValue(symbol)) {
            return nullValues.symbolToCode(symbol);
        }
        // -1 asks the trie to pick the code itself.
        final TrieNode node = trie.addValue(symbol, this, -1);
        return register(node);
    }

    /** Records the node under its code unless that code is already mapped; returns the code. */
    private int register(TrieNode node) {
        final int code = node.getEntry(this).getCode();
        if (!codeTable.containsKey(code)) {
            codeTable.put(code, node);
        }
        return code;
    }

    /**
     * Returns the symbol that belongs to a code.
     *
     * @param code the symbol code
     * @return the symbol string
     * @throws MaltChainedException if the code is negative or the table has no trie
     */
    public String getSymbolCodeToString(int code) throws MaltChainedException {
        if (code < 0) {
            throw new SymbolException("The symbol code '"+code+"' cannot be found in the symbol table. ");
        }
        if (nullValues != null && nullValues.isNullValue(code)) {
            return nullValues.codeToSymbol(code);
        }
        if (trie == null) {
            throw new SymbolException("The symbol table is corrupt. ");
        }
        // NOTE(review): for a code that was never registered, codeTable.get(code)
        // is null and is handed to the trie unchanged - confirm how
        // Trie.getValue treats a null node.
        return trie.getValue(codeTable.get(code), this);
    }

    /**
     * Returns the code that belongs to a symbol.
     *
     * @param symbol the symbol string
     * @return the symbol code
     * @throws MaltChainedException if the symbol is {@code null}, is not in the
     *         table, or the table has no trie
     */
    public int getSymbolStringToCode(String symbol) throws MaltChainedException {
        if (symbol == null) {
            throw new SymbolException("The symbol code '"+symbol+"' cannot be found in the symbol table. ");
        }
        if (nullValues != null && nullValues.isNullValue(symbol)) {
            return nullValues.symbolToCode(symbol);
        }
        if (trie == null) {
            throw new SymbolException("The symbol table is corrupt. ");
        }
        final TrieEntry entry = trie.getEntry(symbol, this);
        if (entry == null) {
            throw new SymbolException("Could not find the symbol '"+symbol+"' in the symbol table. ");
        }
        return entry.getCode();
    }

    /**
     * @return the null-value strategy, or {@code null} if the table has no null values
     */
    public String getNullValueStrategy() {
        if (nullValues == null) {
            return null;
        }
        return nullValues.getNullValueStrategy();
    }

    /**
     * @return the column category this table was created with
     */
    public int getColumnCategory() {
        return columnCategory;
    }

    /**
     * Tells whether the symbol with the given code is flagged as known.
     *
     * @param code the symbol code
     * @return {@code true} for null-value codes and for entries flagged as
     *         known; {@code false} for negative codes, for codes that are not
     *         in the table and for entries not flagged as known
     */
    public boolean getKnown(int code) {
        if (code < 0) {
            return false;
        }
        if (nullValues != null && nullValues.isNullValue(code)) {
            return true;
        }
        final TrieEntry entry = entryForCode(code);
        // An unregistered code is reported as unknown, like an unregistered symbol.
        return entry != null && entry.isKnown();
    }

    /**
     * Tells whether the given symbol is flagged as known.
     *
     * @param symbol the symbol string
     * @return {@code true} for null-value symbols and for entries flagged as
     *         known; {@code false} if the symbol is not in the table or is not
     *         flagged as known
     */
    public boolean getKnown(String symbol) {
        if (nullValues != null && nullValues.isNullValue(symbol)) {
            return true;
        }
        final TrieEntry entry = trie.getEntry(symbol, this);
        return entry != null && entry.isKnown();
    }

    /**
     * Flags the symbol with the given code as known. Negative codes,
     * null-value codes and codes that are not in the table are ignored.
     *
     * @param code the symbol code
     */
    public void makeKnown(int code) {
        if (code < 0 || (nullValues != null && nullValues.isNullValue(code))) {
            return;
        }
        final TrieEntry entry = entryForCode(code);
        if (entry != null) {
            entry.setKnown(true);
        }
    }

    /** Returns this table's entry for a code, or {@code null} if the code is not registered. */
    private TrieEntry entryForCode(int code) {
        final TrieNode node = codeTable.get(code);
        return node == null ? null : node.getEntry(this);
    }

    /**
     * Logs every code/symbol pair at info level, in ascending code order.
     *
     * @param logger the logger to write to
     * @throws MaltChainedException if a symbol cannot be read from the trie
     */
    public void printSymbolTable(Logger logger) throws MaltChainedException {
        for (Map.Entry<Integer, TrieNode> entry : codeTable.entrySet()) {
            logger.info(entry.getKey()+"\t"+trie.getValue(entry.getValue(), this)+"\n");
        }
    }

    /**
     * Writes one header line: a tab, the table name, the column category, the
     * null-value strategy and the root label, separated by tabs and ended
     * with a newline. {@code #DUMMY#} is written in place of the root label
     * when the table has none.
     *
     * @param out the writer to append to
     * @throws MaltChainedException if writing fails
     */
    public void saveHeader(BufferedWriter out) throws MaltChainedException {
        try {
            out.append('\t');
            out.append(getName());
            out.append('\t');
            out.append(Integer.toString(getColumnCategory()));
            out.append('\t');
            out.append(getNullValueStrategy());
            out.append('\t');
            String rootLabel = null;
            if (nullValues instanceof OutputNullValues) {
                rootLabel = ((OutputNullValues) nullValues).getRootLabel();
            }
            out.append(rootLabel != null ? rootLabel : "#DUMMY#");
            out.append('\n');
        } catch (IOException e) {
            throw new SymbolException("Could not save the symbol table. ", e);
        }
    }

    /**
     * Writes the table name on its own line, then one {@code code<TAB>symbol}
     * line per entry in ascending code order, then an empty line that marks
     * the end of the table.
     *
     * @param out the writer to write to
     * @throws MaltChainedException if writing fails
     */
    public void save(BufferedWriter out) throws MaltChainedException {
        try {
            out.write(name);
            out.write('\n');
            for (Map.Entry<Integer, TrieNode> entry : codeTable.entrySet()) {
                out.write(entry.getKey().toString());
                out.write('\t');
                out.write(trie.getValue(entry.getValue(), this));
                out.write('\n');
            }
            out.write('\n');
        } catch (IOException e) {
            throw new SymbolException("Could not save the symbol table. ", e);
        }
    }

    /**
     * Reads {@code code<TAB>symbol} lines until an empty line, a line without
     * a tab, or the end of the stream, and adds each pair to the table.
     * Afterwards the value counter is set to one more than the largest code
     * read. At end of stream this only happens if at least one entry was
     * read, so an empty stream leaves the counter untouched.
     *
     * @param in the reader positioned at the first entry line
     * @throws MaltChainedException if a code is not an integer or reading fails
     */
    public void load(BufferedReader in) throws MaltChainedException {
        int max = 0;
        boolean terminated = false;
        boolean entryRead = false;
        String fileLine;
        try {
            while ((fileLine = in.readLine()) != null) {
                final int index = fileLine.indexOf('\t');
                if (fileLine.length() == 0 || index == -1) {
                    terminated = true;
                    break;
                }
                final int code = Integer.parseInt(fileLine.substring(0, index));
                final String str = fileLine.substring(index + 1);
                final TrieNode node = trie.addValue(str, this, code);
                codeTable.put(node.getEntry(this).getCode(), node);
                entryRead = true;
                if (max < code) {
                    max = code;
                }
            }
        } catch (NumberFormatException e) {
            throw new SymbolException("The symbol table file (.sym) contains a non-integer value in the first column. ", e);
        } catch (IOException e) {
            throw new SymbolException("Could not load the symbol table. ", e);
        }
        // A table that ends at end-of-stream must advance the counter too;
        // otherwise the counter is never moved past the loaded codes.
        if (terminated || entryRead) {
            setValueCounter(max + 1);
        }
    }

    /**
     * @return the name of the symbol table
     */
    public String getName() {
        return name;
    }

    /**
     * @return the current value of the code counter
     */
    public int getValueCounter() {
        return valueCounter;
    }

    private void setValueCounter(int valueCounter) {
        this.valueCounter = valueCounter;
    }

    /**
     * Raises the value counter to {@code code} if {@code code} is larger than
     * the current counter; otherwise does nothing.
     *
     * @param code the code to compare against the counter
     */
    protected void updateValueCounter(int code) {
        if (code > valueCounter) {
            valueCounter = code;
        }
    }

    /**
     * Returns the current counter value and then increments the counter.
     *
     * @return the counter value before the increment
     */
    protected int increaseValueCounter() {
        return valueCounter++;
    }

    /**
     * @param nullValueIdentifier the null-value identifier
     * @return the code of the given null value
     * @throws MaltChainedException if the table has no null values
     */
    public int getNullValueCode(NullValueId nullValueIdentifier) throws MaltChainedException {
        if (nullValues == null) {
            throw new SymbolException("The symbol table does not have any null-values. ");
        }
        return nullValues.nullvalueToCode(nullValueIdentifier);
    }

    /**
     * @param nullValueIdentifier the null-value identifier
     * @return the symbol of the given null value
     * @throws MaltChainedException if the table has no null values
     */
    public String getNullValueSymbol(NullValueId nullValueIdentifier) throws MaltChainedException {
        if (nullValues == null) {
            throw new SymbolException("The symbol table does not have any null-values. ");
        }
        return nullValues.nullvalueToSymbol(nullValueIdentifier);
    }

    /**
     * @param symbol the symbol string
     * @return {@code true} if the table has null values and the symbol is one of them
     * @throws MaltChainedException declared for interface compatibility
     */
    public boolean isNullValue(String symbol) throws MaltChainedException {
        return nullValues != null && nullValues.isNullValue(symbol);
    }

    /**
     * @param code the symbol code
     * @return {@code true} if the table has null values and the code is one of them
     * @throws MaltChainedException declared for interface compatibility
     */
    public boolean isNullValue(int code) throws MaltChainedException {
        return nullValues != null && nullValues.isNullValue(code);
    }

    /**
     * Copies every entry of another {@code TrieSymbolTable} into this table,
     * keeping the codes, and then sets the value counter to one more than the
     * largest code seen (never lowering it below its current value).
     *
     * @param fromTable the table to copy from; must be a {@code TrieSymbolTable}
     * @throws MaltChainedException if a symbol cannot be read or added
     */
    public void copy(SymbolTable fromTable) throws MaltChainedException {
        final SortedMap<Integer, TrieNode> fromCodeTable = ((TrieSymbolTable) fromTable).getCodeTable();
        int max = getValueCounter() - 1;
        for (Map.Entry<Integer, TrieNode> entry : fromCodeTable.entrySet()) {
            final int code = entry.getKey();
            final String str = trie.getValue(entry.getValue(), this);
            final TrieNode node = trie.addValue(str, this, code);
            codeTable.put(node.getEntry(this).getCode(), node);
            if (max < code) {
                max = code;
            }
        }
        setValueCounter(max + 1);
    }

    /**
     * @return the live code-to-node map of this table (not a copy)
     */
    public SortedMap<Integer, TrieNode> getCodeTable() {
        return codeTable;
    }

    /**
     * @return the key set of the code table, i.e. all codes stored in this table
     */
    public Set<Integer> getCodes() {
        return codeTable.keySet();
    }

    protected Trie getTrie() {
        return trie;
    }

    /**
     * Two symbol tables are equal when they are of the same class and have
     * the same name.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        final String otherName = ((TrieSymbolTable) obj).name;
        return (name == null) ? otherName == null : name.equals(otherName);
    }

    /**
     * Hash code derived from the name only. The value is cached in
     * {@code cachedHash}; a computed value of 0 is simply recomputed on the
     * next call.
     */
    @Override
    public int hashCode() {
        if (cachedHash == 0) {
            cachedHash = 31 * 7 + (null == name ? 0 : name.hashCode());
        }
        return cachedHash;
    }

    /**
     * @return the table name followed by a space and the current value counter
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(name);
        sb.append(" ");
        sb.append(valueCounter);
        return sb.toString();
    }
}