001    package org.maltparser.parser.guide.decision;
002    
003    import java.lang.reflect.Constructor;
004    import java.lang.reflect.InvocationTargetException;
005    
006    
007    import org.maltparser.core.exception.MaltChainedException;
008    import org.maltparser.core.feature.FeatureModel;
009    import org.maltparser.core.feature.FeatureVector;
010    import org.maltparser.core.helper.HashMap;
011    import org.maltparser.core.syntaxgraph.DependencyStructure;
012    import org.maltparser.parser.DependencyParserConfig;
013    import org.maltparser.parser.guide.ClassifierGuide;
014    import org.maltparser.parser.guide.GuideException;
015    import org.maltparser.parser.guide.instance.AtomicModel;
016    import org.maltparser.parser.guide.instance.FeatureDivideModel;
017    import org.maltparser.parser.guide.instance.InstanceModel;
018    import org.maltparser.parser.history.action.GuideDecision;
019    import org.maltparser.parser.history.action.MultipleDecision;
020    import org.maltparser.parser.history.action.SingleDecision;
021    import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision;
022    /**
023    *
024    * @author Johan Hall
025    * @since 1.1
026    **/
027    public class BranchedDecisionModel implements DecisionModel {
028            private ClassifierGuide guide;
029            private String modelName;
030            private FeatureModel featureModel;
031            private InstanceModel instanceModel;
032            private int decisionIndex;
033            private DecisionModel parentDecisionModel;
034            private HashMap<Integer,DecisionModel> children;
035            private String branchedDecisionSymbols;
036            
037            public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException {
038                    this.branchedDecisionSymbols = "";
039                    setGuide(guide);
040                    setFeatureModel(featureModel);
041                    setDecisionIndex(0);
042                    setModelName("bdm"+decisionIndex);
043                    setParentDecisionModel(null);
044            }
045            
046            public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException {
047                    if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) {
048                            this.branchedDecisionSymbols = branchedDecisionSymbol;
049                    } else {
050                            this.branchedDecisionSymbols = "";
051                    }
052                    setGuide(guide);
053                    setParentDecisionModel(parentDecisionModel);
054                    setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1);
055                    setFeatureModel(parentDecisionModel.getFeatureModel());
056                    if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) {
057                            setModelName("bdm"+decisionIndex+branchedDecisionSymbols);
058                    } else {
059                            setModelName("bdm"+decisionIndex);
060                    }
061                    this.parentDecisionModel = parentDecisionModel;
062            }
063            
064            public void updateFeatureModel() throws MaltChainedException {
065                    featureModel.update();
066            }
067            
068    //      public void updateCardinality() throws MaltChainedException {
069    //              featureModel.updateCardinality();
070    //      }
071            
072    
073            public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
074                    if (instanceModel != null) {
075                            instanceModel.finalizeSentence(dependencyGraph);
076                    }
077                    if (children != null) {
078                            for (DecisionModel child : children.values()) {
079                                    child.finalizeSentence(dependencyGraph);
080                            }
081                    }
082            }
083            
084            public void noMoreInstances() throws MaltChainedException {
085                    if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
086                            throw new GuideException("The decision model could not create it's model. ");
087                    }
088                    if (instanceModel != null) {
089                            instanceModel.noMoreInstances();
090                            instanceModel.train();
091                    }
092                    if (children != null) {
093                            for (DecisionModel child : children.values()) {
094                                    child.noMoreInstances();
095                            }
096                    }
097            }
098    
099            public void terminate() throws MaltChainedException {
100                    if (instanceModel != null) {
101                            instanceModel.terminate();
102                            instanceModel = null;
103                    }
104                    if (children != null) {
105                            for (DecisionModel child : children.values()) {
106                                    child.terminate();
107                            }
108                    }
109            }
110            
111            public void addInstance(GuideDecision decision) throws MaltChainedException {
112                    if (decision instanceof SingleDecision) {
113                            throw new GuideException("A branched decision model expect more than one decisions. ");
114                    }
115                    featureModel.update();
116                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
117                    if (instanceModel == null) {
118                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
119                    }
120                    
121                    instanceModel.addInstance(singleDecision);
122                    if (decisionIndex+1 < decision.numberOfDecisions()) {
123                            if (singleDecision.continueWithNextDecision()) {
124                                    if (children == null) {
125                                            children = new HashMap<Integer,DecisionModel>();
126                                    }
127                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
128                                    if (child == null) {
129                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
130                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
131                                            children.put(singleDecision.getDecisionCode(), child);
132                                    }
133                                    child.addInstance(decision);
134                            }
135                    }
136            }
137            
138            public boolean predict(GuideDecision decision) throws MaltChainedException {
139    //              if (decision instanceof SingleDecision) {
140    //                      throw new GuideException("A branched decision model expect more than one decisions. ");
141    //              }
142                    featureModel.update();
143                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
144                    if (instanceModel == null) {
145                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
146                    }
147                    instanceModel.predict(singleDecision);
148                    if (decisionIndex+1 < decision.numberOfDecisions()) {
149                            if (singleDecision.continueWithNextDecision()) {
150                                    if (children == null) {
151                                            children = new HashMap<Integer,DecisionModel>();
152                                    }
153                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
154                                    if (child == null) {
155                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
156                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
157                                            children.put(singleDecision.getDecisionCode(), child);
158                                    }
159                                    child.predict(decision);
160                            }
161                    }
162    
163                    return true;
164            }
165            
166            public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException {
167                    if (decision instanceof SingleDecision) {
168                            throw new GuideException("A branched decision model expect more than one decisions. ");
169                    }
170                    featureModel.update();
171                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
172                    if (instanceModel == null) {
173                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
174                    }
175                    FeatureVector fv = instanceModel.predictExtract(singleDecision);
176                    if (decisionIndex+1 < decision.numberOfDecisions()) {
177                            if (singleDecision.continueWithNextDecision()) {
178                                    if (children == null) {
179                                            children = new HashMap<Integer,DecisionModel>();
180                                    }
181                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
182                                    if (child == null) {
183                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
184                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
185                                            children.put(singleDecision.getDecisionCode(), child);
186                                    }
187                                    child.predictExtract(decision);
188                            }
189                    }
190    
191                    return fv;
192            }
193            
194            public FeatureVector extract() throws MaltChainedException {
195                    featureModel.update();
196                    return instanceModel.extract(); // TODO handle many feature vectors
197            }
198            
199            public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException {
200                    if (decision instanceof SingleDecision) {
201                            throw new GuideException("A branched decision model expect more than one decisions. ");
202                    }
203                    
204                    boolean success = false;
205                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
206                    if (decisionIndex+1 < decision.numberOfDecisions()) {
207                            if (singleDecision.continueWithNextDecision()) {
208                                    if (children == null) {
209                                            children = new HashMap<Integer,DecisionModel>();
210                                    }
211                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
212                                    if (child != null) {
213                                            success = child.predictFromKBestList(decision);
214                                    }
215                                    
216                            }
217                    }
218                    if (!success) {
219                            success = singleDecision.updateFromKBestList();
220                            if (decisionIndex+1 < decision.numberOfDecisions()) {
221                                    if (singleDecision.continueWithNextDecision()) {
222                                            if (children == null) {
223                                                    children = new HashMap<Integer,DecisionModel>();
224                                            }
225                                            DecisionModel child = children.get(singleDecision.getDecisionCode());
226                                            if (child == null) {
227                                                    child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
228                                                                    branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
229                                                    children.put(singleDecision.getDecisionCode(), child);
230                                            }
231                                            child.predict(decision);
232                                    }
233                            }
234                    }
235                    return success;
236            }
237            
238    
239            public ClassifierGuide getGuide() {
240                    return guide;
241            }
242    
243            public String getModelName() {
244                    return modelName;
245            }
246            
247            public FeatureModel getFeatureModel() {
248                    return featureModel;
249            }
250    
251            public int getDecisionIndex() {
252                    return decisionIndex;
253            }
254    
255            public DecisionModel getParentDecisionModel() {
256                    return parentDecisionModel;
257            }
258    
259            private void setFeatureModel(FeatureModel featureModel) {
260                    this.featureModel = featureModel;
261            }
262            
263            private void setDecisionIndex(int decisionIndex) {
264                    this.decisionIndex = decisionIndex;
265            }
266            
267            private void setParentDecisionModel(DecisionModel parentDecisionModel) {
268                    this.parentDecisionModel = parentDecisionModel;
269            }
270    
271            private void setModelName(String modelName) {
272                    this.modelName = modelName;
273            }
274            
275            private void setGuide(ClassifierGuide guide) {
276                    this.guide = guide;
277            }
278            
279            
280            private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException {
281                    Class<?> decisionModelClass = null;
282                    if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) {
283                            decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class;
284                    } else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) {
285                            decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class;
286                    } else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) {
287                            decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class;
288                    }
289    
290                    if (decisionModelClass == null) {
291                            throw new GuideException("Could not find an appropriate decision model for the relation to the next decision"); 
292                    }
293                    
294                    try {
295                            Class<?>[] argTypes = { org.maltparser.parser.guide.ClassifierGuide.class, org.maltparser.parser.guide.decision.DecisionModel.class, 
296                                                    java.lang.String.class };
297                            Object[] arguments = new Object[3];
298                            arguments[0] = getGuide();
299                            arguments[1] = this;
300                            arguments[2] = branchedDecisionSymbol;
301                            Constructor<?> constructor = decisionModelClass.getConstructor(argTypes);
302                            return (DecisionModel)constructor.newInstance(arguments);
303                    } catch (NoSuchMethodException e) {
304                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
305                    } catch (InstantiationException e) {
306                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
307                    } catch (IllegalAccessException e) {
308                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
309                    } catch (InvocationTargetException e) {
310                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
311                    }
312            }
313            
314            private void initInstanceModel(String subModelName) throws MaltChainedException {
315                    FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName);
316                    if (fv == null) {
317                            fv = featureModel.getFeatureVector(subModelName);
318                    }
319                    if (fv == null) {
320                            fv = featureModel.getMainFeatureVector();
321                    }
322                    
323                    DependencyParserConfig c = guide.getConfiguration();
324                    
325    //              if (c.getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes") ||
326    //                              (c.getOptionValue("guide", "tree_split_columns")!=null &&
327    //                      c.getOptionValue("guide", "tree_split_columns").toString().length() > 0) ||
328    //                      (c.getOptionValue("guide", "tree_split_structures")!=null &&
329    //                      c.getOptionValue("guide", "tree_split_structures").toString().length() > 0)) {
330    //                      instanceModel = new DecisionTreeModel(fv, this); 
331    //              }else 
332                    if (c.getOptionValue("guide", "data_split_column").toString().length() == 0) {
333                            instanceModel = new AtomicModel(-1, fv, this);
334                    } else {
335                            instanceModel = new FeatureDivideModel(fv, this);
336                    }
337            }
338            
339            public String toString() {
340                    final StringBuilder sb = new StringBuilder();
341                    sb.append(modelName + ", ");
342                    for (DecisionModel model : children.values()) {
343                            sb.append(model.toString() + ", ");
344                    }
345                    return sb.toString();
346            }
347    }