001package org.maltparser.parser.guide.instance; 002 003import java.lang.reflect.Constructor; 004import java.lang.reflect.InvocationTargetException; 005import java.util.ArrayList; 006import java.util.Formatter; 007 008import org.maltparser.core.exception.MaltChainedException; 009import org.maltparser.core.feature.FeatureModel; 010import org.maltparser.core.feature.FeatureVector; 011import org.maltparser.core.feature.function.FeatureFunction; 012import org.maltparser.core.feature.function.Modifiable; 013import org.maltparser.core.syntaxgraph.DependencyStructure; 014import org.maltparser.ml.LearningMethod; 015import org.maltparser.ml.lib.LibLinear; 016import org.maltparser.ml.lib.LibSvm; 017import org.maltparser.parser.guide.ClassifierGuide; 018import org.maltparser.parser.guide.GuideException; 019import org.maltparser.parser.guide.Model; 020import org.maltparser.parser.history.action.SingleDecision; 021 022 023/** 024 025@author Johan Hall 026*/ 027public class AtomicModel implements InstanceModel { 028 public static final Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class }; 029 private final Model parent; 030 private final String modelName; 031// private final FeatureVector featureVector; 032 private final int index; 033 private final LearningMethod method; 034 private int frequency = 0; 035 036 037 /** 038 * Constructs an atomic model. 039 * 040 * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 041 * or the master divide model) and n is number of divide models. 042 * @param parent the parent guide model. 043 * @throws MaltChainedException 044 */ 045 public AtomicModel(int index, Model parent) throws MaltChainedException { 046 this.parent = parent; 047 this.index = index; 048 if (index == -1) { 049 this.modelName = parent.getModelName()+"."; 050 } else { 051 this.modelName = parent.getModelName()+"."+new Formatter().format("%03d", index)+"."; 052 } 053// this.featureVector = featureVector; 054 this.frequency = 0; 055 Integer learnerMode = null; 056 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 057 learnerMode = LearningMethod.CLASSIFY; 058 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 059 learnerMode = LearningMethod.BATCH; 060 } 061 062 // start init learning method 063 Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner"); 064 if (clazz == org.maltparser.ml.lib.LibSvm.class) { 065 this.method = new LibSvm(this, learnerMode); 066 } else if (clazz == org.maltparser.ml.lib.LibLinear.class) { 067 this.method = new LibLinear(this, learnerMode); 068 } else { 069 Object[] arguments = {this, learnerMode}; 070 try { 071 Constructor<?> constructor = clazz.getConstructor(argTypes); 072 this.method = (LearningMethod)constructor.newInstance(arguments); 073 } catch (NoSuchMethodException e) { 074 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 075 } catch (InstantiationException e) { 076 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 077 } catch (IllegalAccessException e) { 078 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 079 } catch (InvocationTargetException e) { 080 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 081 } 082 } 083 // end init learning method 084 085 if (learnerMode == LearningMethod.BATCH && index == -1 && getGuide().getConfiguration() != null) { 086 getGuide().getConfiguration().writeInfoToConfigFile(method.toString()); 087 } 088 } 089 090 public void addInstance(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException { 091 try { 092 method.addInstance(decision, featureVector); 093 } catch (NullPointerException e) { 094 throw new GuideException("The learner cannot be found. ", e); 095 } 096 } 097 098 099 public void noMoreInstances(FeatureModel featureModel) throws MaltChainedException { 100 try { 101 method.noMoreInstances(); 102 } catch (NullPointerException e) { 103 throw new GuideException("The learner cannot be found. ", e); 104 } 105 } 106 107 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { 108 try { 109 method.finalizeSentence(dependencyGraph); 110 } catch (NullPointerException e) { 111 throw new GuideException("The learner cannot be found. ", e); 112 } 113 } 114 115 public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException { 116 try { 117 return method.predict(featureVector, decision); 118 } catch (NullPointerException e) { 119 throw new GuideException("The learner cannot be found. ", e); 120 } 121 } 122 123 public FeatureVector predictExtract(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException { 124 try { 125 126 if (method.predict(featureVector, decision)) { 127 return featureVector; 128 } 129 return null; 130 } catch (NullPointerException e) { 131 throw new GuideException("The learner cannot be found. ", e); 132 } 133 } 134 135 public FeatureVector extract(FeatureVector featureVector) throws MaltChainedException { 136 return featureVector; 137 } 138 139 public void terminate() throws MaltChainedException { 140 if (method != null) { 141 method.terminate(); 142 } 143 } 144 145 /** 146 * Moves all instance from this atomic model into the destination atomic model and add the divide feature. 147 * This method is used by the feature divide model to sum up all model below a certain threshold. 148 * 149 * @param model the destination atomic model 150 * @param divideFeature the divide feature 151 * @param divideFeatureIndexVector the divide feature index vector 152 * @throws MaltChainedException 153 */ 154 public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException { 155 if (method == null) { 156 throw new GuideException("The learner cannot be found. "); 157 } else if (model == null) { 158 throw new GuideException("The guide model cannot be found. "); 159 } else if (divideFeature == null) { 160 throw new GuideException("The divide feature cannot be found. "); 161 } else if (divideFeatureIndexVector == null) { 162 throw new GuideException("The divide feature index vector cannot be found. "); 163 } 164 ((Modifiable)divideFeature).setFeatureValue(index); 165 method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector); 166 method.terminate(); 167 } 168 169 /** 170 * Invokes the train() of the learning method 171 * 172 * @throws MaltChainedException 173 */ 174 public void train() throws MaltChainedException { 175 try { 176 method.train(); 177 method.terminate(); 178 } catch (NullPointerException e) { 179 throw new GuideException("The learner cannot be found. ", e); 180 } 181 182 183 } 184 185 /** 186 * Returns the parent guide model 187 * 188 * @return the parent guide model 189 */ 190 public Model getParent() throws MaltChainedException { 191 if (parent == null) { 192 throw new GuideException("The atomic model can only be used by a parent model. "); 193 } 194 return parent; 195 } 196 197 198 public String getModelName() { 199 return modelName; 200 } 201 202 /** 203 * Returns the feature vector used by this atomic model 204 * 205 * @return a feature vector object 206 */ 207// public FeatureVector getFeatures() { 208// return featureVector; 209// } 210 211 public ClassifierGuide getGuide() { 212 return parent.getGuide(); 213 } 214 215 /** 216 * Returns the index of the atomic model 217 * 218 * @return the index of the atomic model 219 */ 220 public int getIndex() { 221 return index; 222 } 223 224 /** 225 * Returns the frequency (number of instances) 226 * 227 * @return the frequency (number of instances) 228 */ 229 public int getFrequency() { 230 return frequency; 231 } 232 233 /** 234 * Increase the frequency by 1 235 */ 236 public void increaseFrequency() { 237 if (parent instanceof InstanceModel) { 238 ((InstanceModel)parent).increaseFrequency(); 239 } 240 frequency++; 241 } 242 243 public void decreaseFrequency() { 244 if (parent instanceof InstanceModel) { 245 ((InstanceModel)parent).decreaseFrequency(); 246 } 247 frequency--; 248 } 249 /** 250 * Sets the frequency (number of instances) 251 * 252 * @param frequency (number of instances) 253 */ 254 protected void setFrequency(int frequency) { 255 this.frequency = frequency; 256 } 257 258 /** 259 * Returns a learner object 260 * 261 * @return a learner object 262 */ 263 public LearningMethod getMethod() { 264 return method; 265 } 266 267 268 /* (non-Javadoc) 269 * @see java.lang.Object#toString() 270 */ 271 public String toString() { 272 final StringBuilder sb = new StringBuilder(); 273 sb.append(method.toString()); 274 return sb.toString(); 275 } 276}