001package org.maltparser.parser.guide.instance;
002
003import java.io.BufferedWriter;
004import java.io.IOException;
005import java.util.SortedMap;
006
007import java.util.TreeMap;
008import java.util.TreeSet;
009import java.util.regex.Pattern;
010
011import org.maltparser.core.exception.MaltChainedException;
012import org.maltparser.core.feature.FeatureModel;
013import org.maltparser.core.feature.FeatureVector;
014import org.maltparser.core.feature.value.SingleFeatureValue;
015import org.maltparser.core.syntaxgraph.DependencyStructure;
016import org.maltparser.parser.guide.ClassifierGuide;
017import org.maltparser.parser.guide.GuideException;
018import org.maltparser.parser.guide.Model;
019import org.maltparser.parser.history.action.SingleDecision;
020
021/**
022The feature divide model is used for divide the training instances into several models according to
023a divide feature. Usually this strategy decrease the training and classification time, but can also decrease 
024the accuracy of the parser.  
025
026@author Johan Hall
027*/
028public class FeatureDivideModel implements InstanceModel {
029        private final Model parent;
030        private final SortedMap<Integer,AtomicModel> divideModels;
031//      private FeatureVector masterFeatureVector;
032        private int frequency = 0;
033        private final int divideThreshold;
034        private AtomicModel masterModel;
035        
036        /**
037         * Constructs a feature divide model.
038         * 
039         * @param parent the parent guide model.
040         * @throws MaltChainedException
041         */
042        public FeatureDivideModel(Model parent) throws MaltChainedException {
043                this.parent = parent;
044                setFrequency(0);
045//              this.masterFeatureVector = featureVector;
046
047                String data_split_threshold = getGuide().getConfiguration().getOptionValue("guide", "data_split_threshold").toString().trim();
048                if (data_split_threshold != null) {
049                        try {
050                                divideThreshold = Integer.parseInt(data_split_threshold);
051                        } catch (NumberFormatException e) {
052                                throw new GuideException("The --guide-data_split_threshold option is not an integer value. ", e);
053                        }
054                } else {
055                        divideThreshold = 0;
056                }
057                divideModels = new TreeMap<Integer,AtomicModel>();
058                if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
059                        masterModel = new AtomicModel(-1, this);
060                } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
061                        load();
062                }
063        }
064        
065        public void addInstance(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
066//              featureVector.getFeatureModel().getDivideFeatureFunction().update();
067                SingleFeatureValue featureValue = (SingleFeatureValue)featureVector.getFeatureModel().getDivideFeatureFunction().getFeatureValue();
068                if (!divideModels.containsKey(featureValue.getIndexCode())) {
069                        divideModels.put(featureValue.getIndexCode(), new AtomicModel(featureValue.getIndexCode(), this));
070                }
071                FeatureVector divideFeatureVector = featureVector.getFeatureModel().getFeatureVector("/" + featureVector.getSpecSubModel().getSubModelName());
072                divideModels.get(featureValue.getIndexCode()).addInstance(divideFeatureVector, decision);
073        }
074        
075        public void noMoreInstances(FeatureModel featureModel) throws MaltChainedException {
076                for (Integer index : divideModels.keySet()) {
077                        divideModels.get(index).noMoreInstances(featureModel);
078                }
079                final TreeSet<Integer> removeSet = new TreeSet<Integer>();
080                for (Integer index : divideModels.keySet()) {
081                        if (divideModels.get(index).getFrequency() <= divideThreshold) {
082                                divideModels.get(index).moveAllInstances(masterModel, featureModel.getDivideFeatureFunction(), featureModel.getDivideFeatureIndexVector());
083                                removeSet.add(index);
084                        }
085                }
086                for (Integer index : removeSet) {
087                        divideModels.remove(index);
088                }
089                masterModel.noMoreInstances(featureModel);
090        }
091
092        public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
093                if (divideModels != null) { 
094                        for (AtomicModel divideModel : divideModels.values()) {
095                                divideModel.finalizeSentence(dependencyGraph);
096                        }
097                } else {
098                        throw new GuideException("The feature divide models cannot be found. ");
099                }
100        }
101
102        public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
103                AtomicModel model = getAtomicModel((SingleFeatureValue)featureVector.getFeatureModel().getDivideFeatureFunction().getFeatureValue());
104                if (model == null) {
105                        if (getGuide().getConfiguration().isLoggerInfoEnabled()) {
106                                getGuide().getConfiguration().logInfoMessage("Could not predict the next parser decision because there is " +
107                                                "no divide or master model that covers the divide value '"+((SingleFeatureValue)featureVector.getFeatureModel().getDivideFeatureFunction().getFeatureValue()).getIndexCode()+"', as default" +
108                                                                " class code '1' is used. ");
109                        }
110                        decision.addDecision(1); // default prediction
111                        return true;
112                }
113                return model.predict(getModelFeatureVector(model, featureVector), decision);
114        }
115
116        public FeatureVector predictExtract(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
117                AtomicModel model = getAtomicModel((SingleFeatureValue)featureVector.getFeatureModel().getDivideFeatureFunction().getFeatureValue());
118                if (model == null) {
119                        return null;
120                }
121                return model.predictExtract(getModelFeatureVector(model, featureVector), decision);
122        }
123        
124        public FeatureVector extract(FeatureVector featureVector) throws MaltChainedException {
125                AtomicModel model = getAtomicModel((SingleFeatureValue)featureVector.getFeatureModel().getDivideFeatureFunction().getFeatureValue());
126                if (model == null) {
127                        return featureVector;
128                }
129                return model.extract(getModelFeatureVector(model, featureVector));
130        }
131        
132        private FeatureVector getModelFeatureVector(AtomicModel model, FeatureVector featureVector) {
133                if (model.getIndex() == -1) {
134                        return featureVector;
135                } else {
136                        return featureVector.getFeatureModel().getFeatureVector("/" + featureVector.getSpecSubModel().getSubModelName());
137                }
138        }
139        
140        private AtomicModel getAtomicModel(SingleFeatureValue featureValue) throws MaltChainedException {
141                //((SingleFeatureValue)masterFeatureVector.getFeatureModel().getDivideFeatureFunction().getFeatureValue()).getIndexCode()
142                if (divideModels != null && divideModels.containsKey(featureValue.getIndexCode())) {
143                        return divideModels.get(featureValue.getIndexCode());
144                } else if (masterModel != null && masterModel.getFrequency() > 0) {
145                        return masterModel;
146                } 
147                return null;
148        }
149        
150        public void terminate() throws MaltChainedException {
151                if (divideModels != null) {
152                        for (AtomicModel divideModel : divideModels.values()) { 
153                                divideModel.terminate();
154                        }
155                }
156                if (masterModel != null) {
157                        masterModel.terminate();
158                }
159        }
160        
161        public void train() throws MaltChainedException {
162                for (AtomicModel divideModel : divideModels.values()) {
163                        divideModel.train();
164                }
165                masterModel.train();
166                save();
167                for (AtomicModel divideModel : divideModels.values()) {
168                        divideModel.terminate();
169                }
170                masterModel.terminate();
171        }
172        
173        /**
174         * Saves the feature divide model settings .dsm file.
175         * 
176         * @throws MaltChainedException
177         */
178        protected void save() throws MaltChainedException {
179                try {
180                        final BufferedWriter out = new BufferedWriter(getGuide().getConfiguration().getOutputStreamWriter(getModelName()+".dsm"));
181                        out.write(masterModel.getIndex() + "\t" + masterModel.getFrequency() + "\n");
182
183                        if (divideModels != null) {
184                                for (AtomicModel divideModel : divideModels.values()) {
185                                        out.write(divideModel.getIndex() + "\t" + divideModel.getFrequency() + "\n");
186                        }
187                        }
188                        out.close();
189                } catch (IOException e) {
190                        throw new GuideException("Could not write to the guide model settings file '"+getModelName()+".dsm"+"', when " +
191                                        "saving the guide model settings to file. ", e);
192                }
193        }
194        
195        protected void load() throws MaltChainedException {
196                String dsmString = getGuide().getConfiguration().getConfigFileEntryString(getModelName()+".dsm");
197                String[] lines = dsmString.split("\n");
198                Pattern tabPattern = Pattern.compile("\t");
199//              FeatureVector divideFeatureVector = featureVector.getFeatureModel().getFeatureVector("/" + featureVector.getSpecSubModel().getSubModelName());
200                for (int i = 0; i < lines.length; i++) {
201                        String[] cols = tabPattern.split(lines[i]);
202                        if (cols.length != 2) { 
203                                throw new GuideException("");
204                        }
205                        int code = -1;
206                        int freq = 0;
207                        try {
208                                code = Integer.parseInt(cols[0]);
209                                freq = Integer.parseInt(cols[1]);
210                        } catch (NumberFormatException e) {
211                                throw new GuideException("Could not convert a string value into an integer value when loading the feature divide model settings (.dsm). ", e);
212                        }
213                        if (code == -1) { 
214                                masterModel = new AtomicModel(-1, this);
215                                masterModel.setFrequency(freq);
216                        } else if (divideModels != null) {
217                                divideModels.put(code, new AtomicModel(code, this));
218                                divideModels.get(code).setFrequency(freq);
219                        }
220                        setFrequency(getFrequency()+freq);
221                }
222        }
223        
224        /**
225         * Returns the parent model
226         * 
227         * @return the parent model
228         */
229        public Model getParent() {
230                return parent;
231        }
232
233        public ClassifierGuide getGuide() {
234                return parent.getGuide();
235        }
236        
237        public String getModelName() throws MaltChainedException {
238                try {
239                        return parent.getModelName();
240                } catch (NullPointerException e) {
241                        throw new GuideException("The parent guide model cannot be found. ", e);
242                }
243        }
244        
245        /**
246         * Returns the frequency (number of instances)
247         * 
248         * @return the frequency (number of instances)
249         */
250        public int getFrequency() {
251                return frequency;
252        }
253
254        /**
255         * Increase the frequency by 1
256         */
257        public void increaseFrequency() {
258                if (parent instanceof InstanceModel) {
259                        ((InstanceModel)parent).increaseFrequency();
260                }
261                frequency++;
262        }
263        
264        public void decreaseFrequency() {
265                if (parent instanceof InstanceModel) {
266                        ((InstanceModel)parent).decreaseFrequency();
267                }
268                frequency--;
269        }
270        
271        /**
272         * Sets the frequency (number of instances)
273         * 
274         * @param frequency (number of instances)
275         */
276        protected void setFrequency(int frequency) {
277                this.frequency = frequency;
278        }
279
280
281        /* (non-Javadoc)
282         * @see java.lang.Object#toString()
283         */
284        public String toString() {
285                final StringBuilder sb = new StringBuilder();
286                //TODO
287                return sb.toString();
288        }
289}