001    package org.maltparser.parser.guide.instance;
002    
003    import java.io.IOException;
004    import java.lang.reflect.Constructor;
005    import java.lang.reflect.InvocationTargetException;
006    import java.util.ArrayList;
007    import java.util.Formatter;
008    
009    import org.maltparser.core.exception.MaltChainedException;
010    import org.maltparser.core.feature.FeatureVector;
011    import org.maltparser.core.feature.function.FeatureFunction;
012    import org.maltparser.core.feature.function.Modifiable;
013    import org.maltparser.core.syntaxgraph.DependencyStructure;
014    import org.maltparser.ml.LearningMethod;
015    import org.maltparser.parser.guide.ClassifierGuide;
016    import org.maltparser.parser.guide.GuideException;
017    import org.maltparser.parser.guide.Model;
018    import org.maltparser.parser.history.action.SingleDecision;
019    
020    
021    /**
022    
023    @author Johan Hall
024    @since 1.0
025    */
026    public class AtomicModel implements InstanceModel {
027            private Model parent;
028            private String modelName;
029            private FeatureVector featureVector;
030            private int index;
031            private int frequency = 0;
032            private LearningMethod method;
033    
034            
035            /**
036             * Constructs an atomic model.
037             * 
038             * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 
039             * or the master divide model) and n is number of divide models.
040             * @param features the feature vector used by the atomic model.
041             * @param parent the parent guide model.
042             * @throws MaltChainedException
043             */
044            public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException {
045                    setParent(parent);
046                    setIndex(index);
047                    if (index == -1) {
048                            setModelName(parent.getModelName()+".");
049                    } else {
050                            setModelName(parent.getModelName()+"."+new Formatter().format("%03d", index)+".");
051                    }
052                    setFeatures(features);
053                    setFrequency(0);
054                    initMethod();
055                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) {
056                            try {
057                                    getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString());
058                                    getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush();
059                            } catch (IOException e) {
060                                    throw new GuideException("Could not write learner settings to the information file. ", e);
061                            }
062                    }
063            }
064            
065            public void addInstance(SingleDecision decision) throws MaltChainedException {
066                    try {
067                            method.addInstance(decision, featureVector);
068                    } catch (NullPointerException e) {
069                            throw new GuideException("The learner cannot be found. ", e);
070                    }
071            }
072    
073            
074            public void noMoreInstances() throws MaltChainedException {
075                    try {
076                            method.noMoreInstances();
077                    } catch (NullPointerException e) {
078                            throw new GuideException("The learner cannot be found. ", e);
079                    }
080            }
081            
082            public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
083                    try {
084                            method.finalizeSentence(dependencyGraph);
085                    } catch (NullPointerException e) {
086                            throw new GuideException("The learner cannot be found. ", e);
087                    }
088            }
089    
090            public boolean predict(SingleDecision decision) throws MaltChainedException {
091                    try {
092                            if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
093                                    throw new GuideException("Cannot predict during batch training. ");
094                            }
095                            return method.predict(featureVector, decision);
096                    } catch (NullPointerException e) {
097                            throw new GuideException("The learner cannot be found. ", e);
098                    }
099            }
100    
101            public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException {
102                    try {
103                            if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
104                                    throw new GuideException("Cannot predict during batch training. ");
105                            }
106                            if (method.predict(featureVector, decision)) {
107                                    return featureVector;
108                            }
109                            return null;
110                    } catch (NullPointerException e) {
111                            throw new GuideException("The learner cannot be found. ", e);
112                    }
113            }
114            
115            public FeatureVector extract() throws MaltChainedException {
116                    return featureVector;
117            }
118            
119            public void terminate() throws MaltChainedException {
120                    if (method != null) {
121                            method.terminate();
122                            method = null;
123                    }
124                    featureVector = null;
125                    parent = null;
126            }
127            
128            /**
129             * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
130             * This method is used by the feature divide model to sum up all model below a certain threshold.
131             * 
132             * @param model the destination atomic model 
133             * @param divideFeature the divide feature
134             * @param divideFeatureIndexVector the divide feature index vector
135             * @throws MaltChainedException
136             */
137            public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
138                    if (method == null) {
139                            throw new GuideException("The learner cannot be found. ");
140                    } else if (model == null) {
141                            throw new GuideException("The guide model cannot be found. ");
142                    } else if (divideFeature == null) {
143                            throw new GuideException("The divide feature cannot be found. ");
144                    } else if (divideFeatureIndexVector == null) {
145                            throw new GuideException("The divide feature index vector cannot be found. ");
146                    }
147                    ((Modifiable)divideFeature).setFeatureValue(index);
148                    method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
149                    method.terminate();
150                    method = null;
151            }
152            
153            /**
154             * Invokes the train() of the learning method 
155             * 
156             * @throws MaltChainedException
157             */
158            public void train() throws MaltChainedException {
159                    try {
160                            method.train(featureVector);
161                            method.terminate();
162                            method = null;
163                            
164                    } catch (NullPointerException e) {      
165                            throw new GuideException("The learner cannot be found. ", e);
166                    }
167                    
168    
169            }
170            
171            /**
172             * Initialize the learning method according to the option --learner-method.
173             * 
174             * @throws MaltChainedException
175             */
176            public void initMethod() throws MaltChainedException {
177                    Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
178    //              if (clazz == org.maltparser.ml.libsvm.Libsvm.class && (Boolean)getGuide().getConfiguration().getOptionValue("malt0.4", "behavior") == true) {
179    //                      try {
180    //                              clazz = Class.forName("org.maltparser.ml.libsvm.malt04.LibsvmMalt04");
181    //                      } catch (ClassNotFoundException e) {
182    //                              throw new GuideException("Could not find the class 'org.maltparser.ml.libsvm.malt04.LibsvmMalt04'. ", e);
183    //                      }
184    //              }
185                    Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
186                    Object[] arguments = new Object[2];
187                    arguments[0] = this;
188                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
189                            arguments[1] = LearningMethod.CLASSIFY;
190                    } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
191                            arguments[1] = LearningMethod.BATCH;
192                    } 
193    
194                    try {   
195                            Constructor<?> constructor = clazz.getConstructor(argTypes);
196                            this.method = (LearningMethod)constructor.newInstance(arguments);
197                    } catch (NoSuchMethodException e) {
198                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
199                    } catch (InstantiationException e) {
200                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
201                    } catch (IllegalAccessException e) {
202                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
203                    } catch (InvocationTargetException e) {
204                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
205                    }
206            }
207            
208            
209            
210            /**
211             * Returns the parent guide model
212             * 
213             * @return the parent guide model
214             */
215            public Model getParent() throws MaltChainedException {
216                    if (parent == null) {
217                            throw new GuideException("The atomic model can only be used by a parent model. ");
218                    }
219                    return parent;
220            }
221    
222            /**
223             * Sets the parent guide model
224             * 
225             * @param parent the parent guide model
226             */
227            protected void setParent(Model parent) {
228                    this.parent = parent;
229            }
230    
231            public String getModelName() {
232                    return modelName;
233            }
234    
235            /**
236             * Sets the name of the atomic model
237             * 
238             * @param modelName the name of the atomic model
239             */
240            protected void setModelName(String modelName) {
241                    this.modelName = modelName;
242            }
243    
244            /**
245             * Returns the feature vector used by this atomic model
246             * 
247             * @return a feature vector object
248             */
249            public FeatureVector getFeatures() {
250                    return featureVector;
251            }
252    
253            /**
254             * Sets the feature vector used by the atomic model.
255             * 
256             * @param features a feature vector object
257             */
258            protected void setFeatures(FeatureVector features) {
259                    this.featureVector = features;
260            }
261    
262            public ClassifierGuide getGuide() {
263                    return parent.getGuide();
264            }
265            
266            /**
267             * Returns the index of the atomic model
268             * 
269             * @return the index of the atomic model
270             */
271            public int getIndex() {
272                    return index;
273            }
274    
275            /**
276             * Sets the index of the model (-1..n), where -1 is a special value.
277             * 
278             * @param index index value (-1..n) of the atomic model
279             */
280            protected void setIndex(int index) {
281                    this.index = index;
282            }
283    
284            /**
285             * Returns the frequency (number of instances)
286             * 
287             * @return the frequency (number of instances)
288             */
289            public int getFrequency() {
290                    return frequency;
291            }
292            
293            /**
294             * Increase the frequency by 1
295             */
296            public void increaseFrequency() {
297                    if (parent instanceof InstanceModel) {
298                            ((InstanceModel)parent).increaseFrequency();
299                    }
300                    frequency++;
301            }
302            
303            public void decreaseFrequency() {
304                    if (parent instanceof InstanceModel) {
305                            ((InstanceModel)parent).decreaseFrequency();
306                    }
307                    frequency--;
308            }
309            /**
310             * Sets the frequency (number of instances)
311             * 
312             * @param frequency (number of instances)
313             */
314            protected void setFrequency(int frequency) {
315                    this.frequency = frequency;
316            } 
317            
318            /**
319             * Returns a learner object
320             * 
321             * @return a learner object
322             */
323            public LearningMethod getMethod() {
324                    return method;
325            }
326            
327            
328            /* (non-Javadoc)
329             * @see java.lang.Object#toString()
330             */
331            public String toString() {
332                    final StringBuilder sb = new StringBuilder();
333                    sb.append(method.toString());
334                    return sb.toString();
335            }
336    }