001package org.maltparser.parser.guide.instance;
002
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.Locale;

import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.feature.FeatureModel;
import org.maltparser.core.feature.FeatureVector;
import org.maltparser.core.feature.function.FeatureFunction;
import org.maltparser.core.feature.function.Modifiable;
import org.maltparser.core.syntaxgraph.DependencyStructure;
import org.maltparser.ml.LearningMethod;
import org.maltparser.ml.lib.LibLinear;
import org.maltparser.ml.lib.LibSvm;
import org.maltparser.parser.guide.ClassifierGuide;
import org.maltparser.parser.guide.GuideException;
import org.maltparser.parser.guide.Model;
import org.maltparser.parser.history.action.SingleDecision;
021
022
023/**
024
025@author Johan Hall
026*/
027public class AtomicModel implements InstanceModel {
028        public static final Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
029        private final Model parent;
030        private final String modelName;
031//      private final FeatureVector featureVector;
032        private final int index;
033        private final LearningMethod method;
034        private int frequency = 0;
035
036        
037        /**
038         * Constructs an atomic model.
039         * 
040         * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 
041         * or the master divide model) and n is number of divide models.
042         * @param parent the parent guide model.
043         * @throws MaltChainedException
044         */
045        public AtomicModel(int index, Model parent) throws MaltChainedException {
046                this.parent = parent;
047                this.index = index;
048                if (index == -1) {
049                        this.modelName = parent.getModelName()+".";
050                } else {
051                        this.modelName = parent.getModelName()+"."+new Formatter().format("%03d", index)+".";
052                }
053//              this.featureVector = featureVector;
054                this.frequency = 0;
055                Integer learnerMode = null;
056                if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
057                        learnerMode = LearningMethod.CLASSIFY;
058                } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
059                        learnerMode = LearningMethod.BATCH;
060                }
061                
062                // start init learning method
063                Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
064                if (clazz == org.maltparser.ml.lib.LibSvm.class) {
065                        this.method = new LibSvm(this, learnerMode);
066                } else if (clazz == org.maltparser.ml.lib.LibLinear.class) {
067                        this.method = new LibLinear(this, learnerMode);
068                } else {
069                        Object[] arguments = {this, learnerMode};
070                        try {   
071                                Constructor<?> constructor = clazz.getConstructor(argTypes);
072                                this.method = (LearningMethod)constructor.newInstance(arguments);
073                        } catch (NoSuchMethodException e) {
074                                throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
075                        } catch (InstantiationException e) {
076                                throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
077                        } catch (IllegalAccessException e) {
078                                throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
079                        } catch (InvocationTargetException e) {
080                                throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
081                        }
082                }
083                // end init learning method
084                
085                if (learnerMode == LearningMethod.BATCH && index == -1 && getGuide().getConfiguration() != null) {
086                        getGuide().getConfiguration().writeInfoToConfigFile(method.toString());
087                }
088        }
089        
090        public void addInstance(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
091                try {
092                        method.addInstance(decision, featureVector);
093                } catch (NullPointerException e) {
094                        throw new GuideException("The learner cannot be found. ", e);
095                }
096        }
097
098        
099        public void noMoreInstances(FeatureModel featureModel) throws MaltChainedException {
100                try {
101                        method.noMoreInstances();
102                } catch (NullPointerException e) {
103                        throw new GuideException("The learner cannot be found. ", e);
104                }
105        }
106        
107        public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
108                try {
109                        method.finalizeSentence(dependencyGraph);
110                } catch (NullPointerException e) {
111                        throw new GuideException("The learner cannot be found. ", e);
112                }
113        }
114
115        public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
116                try {
117                        return method.predict(featureVector, decision);
118                } catch (NullPointerException e) {
119                        throw new GuideException("The learner cannot be found. ", e);
120                }
121        }
122
123        public FeatureVector predictExtract(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
124                try {
125
126                        if (method.predict(featureVector, decision)) {
127                                return featureVector;
128                        }
129                        return null;
130                } catch (NullPointerException e) {
131                        throw new GuideException("The learner cannot be found. ", e);
132                }
133        }
134        
135        public FeatureVector extract(FeatureVector featureVector) throws MaltChainedException {
136                return featureVector;
137        }
138        
139        public void terminate() throws MaltChainedException {
140                if (method != null) {
141                        method.terminate();
142                }
143        }
144        
145        /**
146         * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
147         * This method is used by the feature divide model to sum up all model below a certain threshold.
148         * 
149         * @param model the destination atomic model 
150         * @param divideFeature the divide feature
151         * @param divideFeatureIndexVector the divide feature index vector
152         * @throws MaltChainedException
153         */
154        public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
155                if (method == null) {
156                        throw new GuideException("The learner cannot be found. ");
157                } else if (model == null) {
158                        throw new GuideException("The guide model cannot be found. ");
159                } else if (divideFeature == null) {
160                        throw new GuideException("The divide feature cannot be found. ");
161                } else if (divideFeatureIndexVector == null) {
162                        throw new GuideException("The divide feature index vector cannot be found. ");
163                }
164                ((Modifiable)divideFeature).setFeatureValue(index);
165                method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
166                method.terminate();
167        }
168        
169        /**
170         * Invokes the train() of the learning method 
171         * 
172         * @throws MaltChainedException
173         */
174        public void train() throws MaltChainedException {
175                try {
176                        method.train();
177                        method.terminate();
178                } catch (NullPointerException e) {      
179                        throw new GuideException("The learner cannot be found. ", e);
180                }
181                
182
183        }
184        
185        /**
186         * Returns the parent guide model
187         * 
188         * @return the parent guide model
189         */
190        public Model getParent() throws MaltChainedException {
191                if (parent == null) {
192                        throw new GuideException("The atomic model can only be used by a parent model. ");
193                }
194                return parent;
195        }
196
197
198        public String getModelName() {
199                return modelName;
200        }
201
202        /**
203         * Returns the feature vector used by this atomic model
204         * 
205         * @return a feature vector object
206         */
207//      public FeatureVector getFeatures() {
208//              return featureVector;
209//      }
210
211        public ClassifierGuide getGuide() {
212                return parent.getGuide();
213        }
214        
215        /**
216         * Returns the index of the atomic model
217         * 
218         * @return the index of the atomic model
219         */
220        public int getIndex() {
221                return index;
222        }
223
224        /**
225         * Returns the frequency (number of instances)
226         * 
227         * @return the frequency (number of instances)
228         */
229        public int getFrequency() {
230                return frequency;
231        }
232        
233        /**
234         * Increase the frequency by 1
235         */
236        public void increaseFrequency() {
237                if (parent instanceof InstanceModel) {
238                        ((InstanceModel)parent).increaseFrequency();
239                }
240                frequency++;
241        }
242        
243        public void decreaseFrequency() {
244                if (parent instanceof InstanceModel) {
245                        ((InstanceModel)parent).decreaseFrequency();
246                }
247                frequency--;
248        }
249        /**
250         * Sets the frequency (number of instances)
251         * 
252         * @param frequency (number of instances)
253         */
254        protected void setFrequency(int frequency) {
255                this.frequency = frequency;
256        } 
257        
258        /**
259         * Returns a learner object
260         * 
261         * @return a learner object
262         */
263        public LearningMethod getMethod() {
264                return method;
265        }
266        
267        
268        /* (non-Javadoc)
269         * @see java.lang.Object#toString()
270         */
271        public String toString() {
272                final StringBuilder sb = new StringBuilder();
273                sb.append(method.toString());
274                return sb.toString();
275        }
276}