001 package org.maltparser.parser.guide.instance; 002 003 import java.io.IOException; 004 import java.lang.reflect.Constructor; 005 import java.lang.reflect.InvocationTargetException; 006 import java.util.ArrayList; 007 import java.util.Formatter; 008 009 import org.maltparser.core.exception.MaltChainedException; 010 import org.maltparser.core.feature.FeatureVector; 011 import org.maltparser.core.feature.function.FeatureFunction; 012 import org.maltparser.core.feature.function.Modifiable; 013 import org.maltparser.core.syntaxgraph.DependencyStructure; 014 import org.maltparser.ml.LearningMethod; 015 import org.maltparser.parser.guide.ClassifierGuide; 016 import org.maltparser.parser.guide.GuideException; 017 import org.maltparser.parser.guide.Model; 018 import org.maltparser.parser.history.action.SingleDecision; 019 020 021 /** 022 023 @author Johan Hall 024 @since 1.0 025 */ 026 public class AtomicModel implements InstanceModel { 027 private Model parent; 028 private String modelName; 029 private FeatureVector featureVector; 030 private int index; 031 private int frequency = 0; 032 private LearningMethod method; 033 034 035 /** 036 * Constructs an atomic model. 037 * 038 * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 039 * or the master divide model) and n is number of divide models. 040 * @param features the feature vector used by the atomic model. 041 * @param parent the parent guide model. 042 * @throws MaltChainedException 043 */ 044 public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException { 045 setParent(parent); 046 setIndex(index); 047 if (index == -1) { 048 setModelName(parent.getModelName()+"."); 049 } else { 050 setModelName(parent.getModelName()+"."+new Formatter().format("%03d", index)+"."); 051 } 052 setFeatures(features); 053 setFrequency(0); 054 initMethod(); 055 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) { 056 try { 057 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString()); 058 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush(); 059 } catch (IOException e) { 060 throw new GuideException("Could not write learner settings to the information file. ", e); 061 } 062 } 063 } 064 065 public void addInstance(SingleDecision decision) throws MaltChainedException { 066 try { 067 method.addInstance(decision, featureVector); 068 } catch (NullPointerException e) { 069 throw new GuideException("The learner cannot be found. ", e); 070 } 071 } 072 073 074 public void noMoreInstances() throws MaltChainedException { 075 try { 076 method.noMoreInstances(); 077 } catch (NullPointerException e) { 078 throw new GuideException("The learner cannot be found. ", e); 079 } 080 } 081 082 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { 083 try { 084 method.finalizeSentence(dependencyGraph); 085 } catch (NullPointerException e) { 086 throw new GuideException("The learner cannot be found. ", e); 087 } 088 } 089 090 public boolean predict(SingleDecision decision) throws MaltChainedException { 091 try { 092 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 093 throw new GuideException("Cannot predict during batch training. "); 094 } 095 return method.predict(featureVector, decision); 096 } catch (NullPointerException e) { 097 throw new GuideException("The learner cannot be found. ", e); 098 } 099 } 100 101 public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException { 102 try { 103 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 104 throw new GuideException("Cannot predict during batch training. "); 105 } 106 if (method.predict(featureVector, decision)) { 107 return featureVector; 108 } 109 return null; 110 } catch (NullPointerException e) { 111 throw new GuideException("The learner cannot be found. ", e); 112 } 113 } 114 115 public FeatureVector extract() throws MaltChainedException { 116 return featureVector; 117 } 118 119 public void terminate() throws MaltChainedException { 120 if (method != null) { 121 method.terminate(); 122 method = null; 123 } 124 featureVector = null; 125 parent = null; 126 } 127 128 /** 129 * Moves all instance from this atomic model into the destination atomic model and add the divide feature. 130 * This method is used by the feature divide model to sum up all model below a certain threshold. 131 * 132 * @param model the destination atomic model 133 * @param divideFeature the divide feature 134 * @param divideFeatureIndexVector the divide feature index vector 135 * @throws MaltChainedException 136 */ 137 public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException { 138 if (method == null) { 139 throw new GuideException("The learner cannot be found. "); 140 } else if (model == null) { 141 throw new GuideException("The guide model cannot be found. "); 142 } else if (divideFeature == null) { 143 throw new GuideException("The divide feature cannot be found. "); 144 } else if (divideFeatureIndexVector == null) { 145 throw new GuideException("The divide feature index vector cannot be found. "); 146 } 147 ((Modifiable)divideFeature).setFeatureValue(index); 148 method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector); 149 method.terminate(); 150 method = null; 151 } 152 153 /** 154 * Invokes the train() of the learning method 155 * 156 * @throws MaltChainedException 157 */ 158 public void train() throws MaltChainedException { 159 try { 160 method.train(featureVector); 161 method.terminate(); 162 method = null; 163 164 } catch (NullPointerException e) { 165 throw new GuideException("The learner cannot be found. ", e); 166 } 167 168 169 } 170 171 /** 172 * Initialize the learning method according to the option --learner-method. 173 * 174 * @throws MaltChainedException 175 */ 176 public void initMethod() throws MaltChainedException { 177 Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner"); 178 // if (clazz == org.maltparser.ml.libsvm.Libsvm.class && (Boolean)getGuide().getConfiguration().getOptionValue("malt0.4", "behavior") == true) { 179 // try { 180 // clazz = Class.forName("org.maltparser.ml.libsvm.malt04.LibsvmMalt04"); 181 // } catch (ClassNotFoundException e) { 182 // throw new GuideException("Could not find the class 'org.maltparser.ml.libsvm.malt04.LibsvmMalt04'. ", e); 183 // } 184 // } 185 Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class }; 186 Object[] arguments = new Object[2]; 187 arguments[0] = this; 188 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 189 arguments[1] = LearningMethod.CLASSIFY; 190 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 191 arguments[1] = LearningMethod.BATCH; 192 } 193 194 try { 195 Constructor<?> constructor = clazz.getConstructor(argTypes); 196 this.method = (LearningMethod)constructor.newInstance(arguments); 197 } catch (NoSuchMethodException e) { 198 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 199 } catch (InstantiationException e) { 200 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 201 } catch (IllegalAccessException e) { 202 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 203 } catch (InvocationTargetException e) { 204 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 205 } 206 } 207 208 209 210 /** 211 * Returns the parent guide model 212 * 213 * @return the parent guide model 214 */ 215 public Model getParent() throws MaltChainedException { 216 if (parent == null) { 217 throw new GuideException("The atomic model can only be used by a parent model. "); 218 } 219 return parent; 220 } 221 222 /** 223 * Sets the parent guide model 224 * 225 * @param parent the parent guide model 226 */ 227 protected void setParent(Model parent) { 228 this.parent = parent; 229 } 230 231 public String getModelName() { 232 return modelName; 233 } 234 235 /** 236 * Sets the name of the atomic model 237 * 238 * @param modelName the name of the atomic model 239 */ 240 protected void setModelName(String modelName) { 241 this.modelName = modelName; 242 } 243 244 /** 245 * Returns the feature vector used by this atomic model 246 * 247 * @return a feature vector object 248 */ 249 public FeatureVector getFeatures() { 250 return featureVector; 251 } 252 253 /** 254 * Sets the feature vector used by the atomic model. 255 * 256 * @param features a feature vector object 257 */ 258 protected void setFeatures(FeatureVector features) { 259 this.featureVector = features; 260 } 261 262 public ClassifierGuide getGuide() { 263 return parent.getGuide(); 264 } 265 266 /** 267 * Returns the index of the atomic model 268 * 269 * @return the index of the atomic model 270 */ 271 public int getIndex() { 272 return index; 273 } 274 275 /** 276 * Sets the index of the model (-1..n), where -1 is a special value. 277 * 278 * @param index index value (-1..n) of the atomic model 279 */ 280 protected void setIndex(int index) { 281 this.index = index; 282 } 283 284 /** 285 * Returns the frequency (number of instances) 286 * 287 * @return the frequency (number of instances) 288 */ 289 public int getFrequency() { 290 return frequency; 291 } 292 293 /** 294 * Increase the frequency by 1 295 */ 296 public void increaseFrequency() { 297 if (parent instanceof InstanceModel) { 298 ((InstanceModel)parent).increaseFrequency(); 299 } 300 frequency++; 301 } 302 303 public void decreaseFrequency() { 304 if (parent instanceof InstanceModel) { 305 ((InstanceModel)parent).decreaseFrequency(); 306 } 307 frequency--; 308 } 309 /** 310 * Sets the frequency (number of instances) 311 * 312 * @param frequency (number of instances) 313 */ 314 protected void setFrequency(int frequency) { 315 this.frequency = frequency; 316 } 317 318 /** 319 * Returns a learner object 320 * 321 * @return a learner object 322 */ 323 public LearningMethod getMethod() { 324 return method; 325 } 326 327 328 /* (non-Javadoc) 329 * @see java.lang.Object#toString() 330 */ 331 public String toString() { 332 final StringBuilder sb = new StringBuilder(); 333 sb.append(method.toString()); 334 return sb.toString(); 335 } 336 }