001    package org.maltparser.parser.guide.instance;
002    
003    import java.io.IOException;
004    import java.lang.reflect.Constructor;
005    import java.lang.reflect.InvocationTargetException;
006    import java.util.ArrayList;
007    import java.util.Formatter;
008    
009    import org.maltparser.core.exception.MaltChainedException;
010    import org.maltparser.core.feature.FeatureVector;
011    import org.maltparser.core.feature.function.FeatureFunction;
012    import org.maltparser.core.feature.function.Modifiable;
013    import org.maltparser.core.syntaxgraph.DependencyStructure;
014    import org.maltparser.ml.LearningMethod;
015    import org.maltparser.parser.guide.ClassifierGuide;
016    import org.maltparser.parser.guide.GuideException;
017    import org.maltparser.parser.guide.Model;
018    import org.maltparser.parser.history.action.SingleDecision;
019    
020    
021    /**
022    
023    @author Johan Hall
024    @since 1.0
025    */
026    public class AtomicModel implements InstanceModel {
027            private Model parent;
028            private String modelName;
029            private FeatureVector featureVector;
030            private int index;
031            private int frequency = 0;
032            private LearningMethod method;
033    
034            
035            /**
036             * Constructs an atomic model.
037             * 
038             * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 
039             * or the master divide model) and n is number of divide models.
040             * @param features the feature vector used by the atomic model.
041             * @param parent the parent guide model.
042             * @throws MaltChainedException
043             */
044            public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException {
045                    setParent(parent);
046                    setIndex(index);
047                    if (index == -1) {
048                            setModelName(parent.getModelName()+".");
049                    } else {
050                            setModelName(parent.getModelName()+"."+new Formatter().format("%03d", index)+".");
051                    }
052                    setFeatures(features);
053                    setFrequency(0);
054                    initMethod();
055                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) {
056                            try {
057                                    getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString());
058                                    getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush();
059                            } catch (IOException e) {
060                                    throw new GuideException("Could not write learner settings to the information file. ", e);
061                            }
062                    }
063            }
064            
065            public void addInstance(SingleDecision decision) throws MaltChainedException {
066                    try {
067                            method.addInstance(decision, featureVector);
068                    } catch (NullPointerException e) {
069                            throw new GuideException("The learner cannot be found. ", e);
070                    }
071            }
072    
073            
074            public void noMoreInstances() throws MaltChainedException {
075                    try {
076                            method.noMoreInstances();
077                    } catch (NullPointerException e) {
078                            throw new GuideException("The learner cannot be found. ", e);
079                    }
080            }
081            
082            public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
083                    try {
084                            method.finalizeSentence(dependencyGraph);
085                    } catch (NullPointerException e) {
086                            throw new GuideException("The learner cannot be found. ", e);
087                    }
088            }
089    
090            public boolean predict(SingleDecision decision) throws MaltChainedException {
091                    try {
092    //                      if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
093    //                              throw new GuideException("Cannot predict during batch training. ");
094    //                      }
095                            return method.predict(featureVector, decision);
096                    } catch (NullPointerException e) {
097                            throw new GuideException("The learner cannot be found. ", e);
098                    }
099            }
100    
101            public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException {
102                    try {
103    //                      if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
104    //                              throw new GuideException("Cannot predict during batch training. ");
105    //                      }
106                            if (method.predict(featureVector, decision)) {
107                                    return featureVector;
108                            }
109                            return null;
110                    } catch (NullPointerException e) {
111                            throw new GuideException("The learner cannot be found. ", e);
112                    }
113            }
114            
115            public FeatureVector extract() throws MaltChainedException {
116                    return featureVector;
117            }
118            
119            public void terminate() throws MaltChainedException {
120                    if (method != null) {
121                            method.terminate();
122                            method = null;
123                    }
124                    featureVector = null;
125                    parent = null;
126            }
127            
128            /**
129             * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
130             * This method is used by the feature divide model to sum up all model below a certain threshold.
131             * 
132             * @param model the destination atomic model 
133             * @param divideFeature the divide feature
134             * @param divideFeatureIndexVector the divide feature index vector
135             * @throws MaltChainedException
136             */
137            public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
138                    if (method == null) {
139                            throw new GuideException("The learner cannot be found. ");
140                    } else if (model == null) {
141                            throw new GuideException("The guide model cannot be found. ");
142                    } else if (divideFeature == null) {
143                            throw new GuideException("The divide feature cannot be found. ");
144                    } else if (divideFeatureIndexVector == null) {
145                            throw new GuideException("The divide feature index vector cannot be found. ");
146                    }
147                    ((Modifiable)divideFeature).setFeatureValue(index);
148                    method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
149                    method.terminate();
150                    method = null;
151            }
152            
153            /**
154             * Invokes the train() of the learning method 
155             * 
156             * @throws MaltChainedException
157             */
158            public void train() throws MaltChainedException {
159                    try {
160                            method.train(featureVector);
161                            method.terminate();
162                            method = null;
163                            
164                    } catch (NullPointerException e) {      
165                            throw new GuideException("The learner cannot be found. ", e);
166                    }
167                    
168    
169            }
170            
171            /**
172             * Initialize the learning method according to the option --learner-method.
173             * 
174             * @throws MaltChainedException
175             */
176            public void initMethod() throws MaltChainedException {
177                    Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
178                    Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
179                    Object[] arguments = new Object[2];
180                    arguments[0] = this;
181                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
182                            arguments[1] = LearningMethod.CLASSIFY;
183                    } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
184                            arguments[1] = LearningMethod.BATCH;
185                    } 
186    
187                    try {   
188                            Constructor<?> constructor = clazz.getConstructor(argTypes);
189                            this.method = (LearningMethod)constructor.newInstance(arguments);
190                    } catch (NoSuchMethodException e) {
191                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
192                    } catch (InstantiationException e) {
193                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
194                    } catch (IllegalAccessException e) {
195                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
196                    } catch (InvocationTargetException e) {
197                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
198                    }
199            }
200            
201            
202            
203            /**
204             * Returns the parent guide model
205             * 
206             * @return the parent guide model
207             */
208            public Model getParent() throws MaltChainedException {
209                    if (parent == null) {
210                            throw new GuideException("The atomic model can only be used by a parent model. ");
211                    }
212                    return parent;
213            }
214    
215            /**
216             * Sets the parent guide model
217             * 
218             * @param parent the parent guide model
219             */
220            protected void setParent(Model parent) {
221                    this.parent = parent;
222            }
223    
224            public String getModelName() {
225                    return modelName;
226            }
227    
228            /**
229             * Sets the name of the atomic model
230             * 
231             * @param modelName the name of the atomic model
232             */
233            protected void setModelName(String modelName) {
234                    this.modelName = modelName;
235            }
236    
237            /**
238             * Returns the feature vector used by this atomic model
239             * 
240             * @return a feature vector object
241             */
242            public FeatureVector getFeatures() {
243                    return featureVector;
244            }
245    
246            /**
247             * Sets the feature vector used by the atomic model.
248             * 
249             * @param features a feature vector object
250             */
251            protected void setFeatures(FeatureVector features) {
252                    this.featureVector = features;
253            }
254    
255            public ClassifierGuide getGuide() {
256                    return parent.getGuide();
257            }
258            
259            /**
260             * Returns the index of the atomic model
261             * 
262             * @return the index of the atomic model
263             */
264            public int getIndex() {
265                    return index;
266            }
267    
268            /**
269             * Sets the index of the model (-1..n), where -1 is a special value.
270             * 
271             * @param index index value (-1..n) of the atomic model
272             */
273            protected void setIndex(int index) {
274                    this.index = index;
275            }
276    
277            /**
278             * Returns the frequency (number of instances)
279             * 
280             * @return the frequency (number of instances)
281             */
282            public int getFrequency() {
283                    return frequency;
284            }
285            
286            /**
287             * Increase the frequency by 1
288             */
289            public void increaseFrequency() {
290                    if (parent instanceof InstanceModel) {
291                            ((InstanceModel)parent).increaseFrequency();
292                    }
293                    frequency++;
294            }
295            
296            public void decreaseFrequency() {
297                    if (parent instanceof InstanceModel) {
298                            ((InstanceModel)parent).decreaseFrequency();
299                    }
300                    frequency--;
301            }
302            /**
303             * Sets the frequency (number of instances)
304             * 
305             * @param frequency (number of instances)
306             */
307            protected void setFrequency(int frequency) {
308                    this.frequency = frequency;
309            } 
310            
311            /**
312             * Returns a learner object
313             * 
314             * @return a learner object
315             */
316            public LearningMethod getMethod() {
317                    return method;
318            }
319            
320            
321            /* (non-Javadoc)
322             * @see java.lang.Object#toString()
323             */
324            public String toString() {
325                    final StringBuilder sb = new StringBuilder();
326                    sb.append(method.toString());
327                    return sb.toString();
328            }
329    }