001 package org.maltparser.parser.guide.instance;
002
003 import java.io.IOException;
004 import java.lang.reflect.Constructor;
005 import java.lang.reflect.InvocationTargetException;
006 import java.util.ArrayList;
007 import java.util.Formatter;
008
009 import org.maltparser.core.exception.MaltChainedException;
010 import org.maltparser.core.feature.FeatureVector;
011 import org.maltparser.core.feature.function.FeatureFunction;
012 import org.maltparser.core.feature.function.Modifiable;
013 import org.maltparser.core.syntaxgraph.DependencyStructure;
014 import org.maltparser.ml.LearningMethod;
015 import org.maltparser.parser.guide.ClassifierGuide;
016 import org.maltparser.parser.guide.GuideException;
017 import org.maltparser.parser.guide.Model;
018 import org.maltparser.parser.history.action.SingleDecision;
019
020
021 /**
022
023 @author Johan Hall
024 @since 1.0
025 */
026 public class AtomicModel implements InstanceModel {
027 private Model parent;
028 private String modelName;
029 private FeatureVector featureVector;
030 private int index;
031 private int frequency = 0;
032 private LearningMethod method;
033
034
035 /**
036 * Constructs an atomic model.
037 *
038 * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model
039 * or the master divide model) and n is number of divide models.
040 * @param features the feature vector used by the atomic model.
041 * @param parent the parent guide model.
042 * @throws MaltChainedException
043 */
044 public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException {
045 setParent(parent);
046 setIndex(index);
047 if (index == -1) {
048 setModelName(parent.getModelName()+".");
049 } else {
050 setModelName(parent.getModelName()+"."+new Formatter().format("%03d", index)+".");
051 }
052 setFeatures(features);
053 setFrequency(0);
054 initMethod();
055 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) {
056 try {
057 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString());
058 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush();
059 } catch (IOException e) {
060 throw new GuideException("Could not write learner settings to the information file. ", e);
061 }
062 }
063 }
064
065 public void addInstance(SingleDecision decision) throws MaltChainedException {
066 try {
067 method.addInstance(decision, featureVector);
068 } catch (NullPointerException e) {
069 throw new GuideException("The learner cannot be found. ", e);
070 }
071 }
072
073
074 public void noMoreInstances() throws MaltChainedException {
075 try {
076 method.noMoreInstances();
077 } catch (NullPointerException e) {
078 throw new GuideException("The learner cannot be found. ", e);
079 }
080 }
081
082 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
083 try {
084 method.finalizeSentence(dependencyGraph);
085 } catch (NullPointerException e) {
086 throw new GuideException("The learner cannot be found. ", e);
087 }
088 }
089
090 public boolean predict(SingleDecision decision) throws MaltChainedException {
091 try {
092 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
093 throw new GuideException("Cannot predict during batch training. ");
094 }
095 return method.predict(featureVector, decision);
096 } catch (NullPointerException e) {
097 throw new GuideException("The learner cannot be found. ", e);
098 }
099 }
100
101 public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException {
102 try {
103 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
104 throw new GuideException("Cannot predict during batch training. ");
105 }
106 if (method.predict(featureVector, decision)) {
107 return featureVector;
108 }
109 return null;
110 } catch (NullPointerException e) {
111 throw new GuideException("The learner cannot be found. ", e);
112 }
113 }
114
115 public FeatureVector extract() throws MaltChainedException {
116 return featureVector;
117 }
118
119 public void terminate() throws MaltChainedException {
120 if (method != null) {
121 method.terminate();
122 method = null;
123 }
124 featureVector = null;
125 parent = null;
126 }
127
128 /**
129 * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
130 * This method is used by the feature divide model to sum up all model below a certain threshold.
131 *
132 * @param model the destination atomic model
133 * @param divideFeature the divide feature
134 * @param divideFeatureIndexVector the divide feature index vector
135 * @throws MaltChainedException
136 */
137 public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
138 if (method == null) {
139 throw new GuideException("The learner cannot be found. ");
140 } else if (model == null) {
141 throw new GuideException("The guide model cannot be found. ");
142 } else if (divideFeature == null) {
143 throw new GuideException("The divide feature cannot be found. ");
144 } else if (divideFeatureIndexVector == null) {
145 throw new GuideException("The divide feature index vector cannot be found. ");
146 }
147 ((Modifiable)divideFeature).setFeatureValue(index);
148 method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
149 method.terminate();
150 method = null;
151 }
152
153 /**
154 * Invokes the train() of the learning method
155 *
156 * @throws MaltChainedException
157 */
158 public void train() throws MaltChainedException {
159 try {
160 method.train(featureVector);
161 method.terminate();
162 method = null;
163
164 } catch (NullPointerException e) {
165 throw new GuideException("The learner cannot be found. ", e);
166 }
167
168
169 }
170
171 /**
172 * Initialize the learning method according to the option --learner-method.
173 *
174 * @throws MaltChainedException
175 */
176 public void initMethod() throws MaltChainedException {
177 Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
178 // if (clazz == org.maltparser.ml.libsvm.Libsvm.class && (Boolean)getGuide().getConfiguration().getOptionValue("malt0.4", "behavior") == true) {
179 // try {
180 // clazz = Class.forName("org.maltparser.ml.libsvm.malt04.LibsvmMalt04");
181 // } catch (ClassNotFoundException e) {
182 // throw new GuideException("Could not find the class 'org.maltparser.ml.libsvm.malt04.LibsvmMalt04'. ", e);
183 // }
184 // }
185 Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
186 Object[] arguments = new Object[2];
187 arguments[0] = this;
188 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
189 arguments[1] = LearningMethod.CLASSIFY;
190 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
191 arguments[1] = LearningMethod.BATCH;
192 }
193
194 try {
195 Constructor<?> constructor = clazz.getConstructor(argTypes);
196 this.method = (LearningMethod)constructor.newInstance(arguments);
197 } catch (NoSuchMethodException e) {
198 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
199 } catch (InstantiationException e) {
200 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
201 } catch (IllegalAccessException e) {
202 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
203 } catch (InvocationTargetException e) {
204 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
205 }
206 }
207
208
209
210 /**
211 * Returns the parent guide model
212 *
213 * @return the parent guide model
214 */
215 public Model getParent() throws MaltChainedException {
216 if (parent == null) {
217 throw new GuideException("The atomic model can only be used by a parent model. ");
218 }
219 return parent;
220 }
221
222 /**
223 * Sets the parent guide model
224 *
225 * @param parent the parent guide model
226 */
227 protected void setParent(Model parent) {
228 this.parent = parent;
229 }
230
231 public String getModelName() {
232 return modelName;
233 }
234
235 /**
236 * Sets the name of the atomic model
237 *
238 * @param modelName the name of the atomic model
239 */
240 protected void setModelName(String modelName) {
241 this.modelName = modelName;
242 }
243
244 /**
245 * Returns the feature vector used by this atomic model
246 *
247 * @return a feature vector object
248 */
249 public FeatureVector getFeatures() {
250 return featureVector;
251 }
252
253 /**
254 * Sets the feature vector used by the atomic model.
255 *
256 * @param features a feature vector object
257 */
258 protected void setFeatures(FeatureVector features) {
259 this.featureVector = features;
260 }
261
262 public ClassifierGuide getGuide() {
263 return parent.getGuide();
264 }
265
266 /**
267 * Returns the index of the atomic model
268 *
269 * @return the index of the atomic model
270 */
271 public int getIndex() {
272 return index;
273 }
274
275 /**
276 * Sets the index of the model (-1..n), where -1 is a special value.
277 *
278 * @param index index value (-1..n) of the atomic model
279 */
280 protected void setIndex(int index) {
281 this.index = index;
282 }
283
284 /**
285 * Returns the frequency (number of instances)
286 *
287 * @return the frequency (number of instances)
288 */
289 public int getFrequency() {
290 return frequency;
291 }
292
293 /**
294 * Increase the frequency by 1
295 */
296 public void increaseFrequency() {
297 if (parent instanceof InstanceModel) {
298 ((InstanceModel)parent).increaseFrequency();
299 }
300 frequency++;
301 }
302
303 public void decreaseFrequency() {
304 if (parent instanceof InstanceModel) {
305 ((InstanceModel)parent).decreaseFrequency();
306 }
307 frequency--;
308 }
309 /**
310 * Sets the frequency (number of instances)
311 *
312 * @param frequency (number of instances)
313 */
314 protected void setFrequency(int frequency) {
315 this.frequency = frequency;
316 }
317
318 /**
319 * Returns a learner object
320 *
321 * @return a learner object
322 */
323 public LearningMethod getMethod() {
324 return method;
325 }
326
327
328 /* (non-Javadoc)
329 * @see java.lang.Object#toString()
330 */
331 public String toString() {
332 final StringBuilder sb = new StringBuilder();
333 sb.append(method.toString());
334 return sb.toString();
335 }
336 }