001    package org.maltparser.parser.guide.decision;
002    
003    import java.lang.reflect.Constructor;
004    import java.lang.reflect.InvocationTargetException;
005    import java.util.HashMap;
006    
007    import org.maltparser.core.exception.MaltChainedException;
008    import org.maltparser.core.feature.FeatureModel;
009    import org.maltparser.core.feature.FeatureVector;
010    import org.maltparser.core.syntaxgraph.DependencyStructure;
011    import org.maltparser.parser.guide.ClassifierGuide;
012    import org.maltparser.parser.guide.GuideException;
013    import org.maltparser.parser.guide.instance.AtomicModel;
014    import org.maltparser.parser.guide.instance.FeatureDivideModel;
015    import org.maltparser.parser.guide.instance.InstanceModel;
016    import org.maltparser.parser.history.action.GuideDecision;
017    import org.maltparser.parser.history.action.MultipleDecision;
018    import org.maltparser.parser.history.action.SingleDecision;
019    import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision;
020    /**
021    *
022    * @author Johan Hall
023    * @since 1.1
024    **/
025    public class BranchedDecisionModel implements DecisionModel {
026            private ClassifierGuide guide;
027            private String modelName;
028            private FeatureModel featureModel;
029            private InstanceModel instanceModel;
030            private int decisionIndex;
031            private DecisionModel parentDecisionModel;
032            private HashMap<Integer,DecisionModel> children;
033            private String branchedDecisionSymbols;
034            
035            public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException {
036                    this.branchedDecisionSymbols = "";
037                    setGuide(guide);
038                    setFeatureModel(featureModel);
039                    setDecisionIndex(0);
040                    setModelName("bdm"+decisionIndex);
041                    setParentDecisionModel(null);
042            }
043            
044            public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException {
045                    if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) {
046                            this.branchedDecisionSymbols = branchedDecisionSymbol;
047                    } else {
048                            this.branchedDecisionSymbols = "";
049                    }
050                    setGuide(guide);
051                    setParentDecisionModel(parentDecisionModel);
052                    setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1);
053                    setFeatureModel(parentDecisionModel.getFeatureModel());
054                    if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) {
055                            setModelName("bdm"+decisionIndex+branchedDecisionSymbols);
056                    } else {
057                            setModelName("bdm"+decisionIndex);
058                    }
059                    this.parentDecisionModel = parentDecisionModel;
060            }
061            
062            public void updateFeatureModel() throws MaltChainedException {
063                    featureModel.update();
064            }
065            
066            public void updateCardinality() throws MaltChainedException {
067                    featureModel.updateCardinality();
068            }
069            
070    
071            public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
072                    if (instanceModel != null) {
073                            instanceModel.finalizeSentence(dependencyGraph);
074                    }
075                    if (children != null) {
076                            for (DecisionModel child : children.values()) {
077                                    child.finalizeSentence(dependencyGraph);
078                            }
079                    }
080            }
081            
082            public void noMoreInstances() throws MaltChainedException {
083                    if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
084                            throw new GuideException("The decision model could not create it's model. ");
085                    }
086                    featureModel.updateCardinality();
087                    if (instanceModel != null) {
088                            instanceModel.noMoreInstances();
089                            instanceModel.train();
090                    }
091                    if (children != null) {
092                            for (DecisionModel child : children.values()) {
093                                    child.noMoreInstances();
094                            }
095                    }
096            }
097    
098            public void terminate() throws MaltChainedException {
099                    if (instanceModel != null) {
100                            instanceModel.terminate();
101                            instanceModel = null;
102                    }
103                    if (children != null) {
104                            for (DecisionModel child : children.values()) {
105                                    child.terminate();
106                            }
107                    }
108            }
109            
110            public void addInstance(GuideDecision decision) throws MaltChainedException {
111                    if (decision instanceof SingleDecision) {
112                            throw new GuideException("A branched decision model expect more than one decisions. ");
113                    }
114                    updateFeatureModel();
115                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
116                    if (instanceModel == null) {
117                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
118                    }
119                    
120                    instanceModel.addInstance(singleDecision);
121                    if (decisionIndex+1 < decision.numberOfDecisions()) {
122                            if (singleDecision.continueWithNextDecision()) {
123                                    if (children == null) {
124                                            children = new HashMap<Integer,DecisionModel>();
125                                    }
126                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
127                                    if (child == null) {
128                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
129                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
130                                            children.put(singleDecision.getDecisionCode(), child);
131                                    }
132                                    child.addInstance(decision);
133                            }
134                    }
135            }
136            
137            public boolean predict(GuideDecision decision) throws MaltChainedException {
138                    if (decision instanceof SingleDecision) {
139                            throw new GuideException("A branched decision model expect more than one decisions. ");
140                    }
141                    updateFeatureModel();
142                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
143                    if (instanceModel == null) {
144                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
145                    }
146                    instanceModel.predict(singleDecision);
147                    if (decisionIndex+1 < decision.numberOfDecisions()) {
148                            if (singleDecision.continueWithNextDecision()) {
149                                    if (children == null) {
150                                            children = new HashMap<Integer,DecisionModel>();
151                                    }
152                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
153                                    if (child == null) {
154                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
155                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
156                                            children.put(singleDecision.getDecisionCode(), child);
157                                    }
158                                    child.predict(decision);
159                            }
160                    }
161    
162                    return true;
163            }
164            
165            public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException {
166                    if (decision instanceof SingleDecision) {
167                            throw new GuideException("A branched decision model expect more than one decisions. ");
168                    }
169                    updateFeatureModel();
170                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
171                    if (instanceModel == null) {
172                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
173                    }
174                    FeatureVector fv = instanceModel.predictExtract(singleDecision);
175                    if (decisionIndex+1 < decision.numberOfDecisions()) {
176                            if (singleDecision.continueWithNextDecision()) {
177                                    if (children == null) {
178                                            children = new HashMap<Integer,DecisionModel>();
179                                    }
180                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
181                                    if (child == null) {
182                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
183                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
184                                            children.put(singleDecision.getDecisionCode(), child);
185                                    }
186                                    child.predictExtract(decision);
187                            }
188                    }
189    
190                    return fv;
191            }
192            
193            public FeatureVector extract() throws MaltChainedException {
194                    updateFeatureModel();
195                    return instanceModel.extract(); // TODO handle many feature vectors
196            }
197            
198            public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException {
199                    if (decision instanceof SingleDecision) {
200                            throw new GuideException("A branched decision model expect more than one decisions. ");
201                    }
202                    
203                    boolean success = false;
204                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
205                    if (decisionIndex+1 < decision.numberOfDecisions()) {
206                            if (singleDecision.continueWithNextDecision()) {
207                                    if (children == null) {
208                                            children = new HashMap<Integer,DecisionModel>();
209                                    }
210                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
211                                    if (child != null) {
212                                            success = child.predictFromKBestList(decision);
213                                    }
214                                    
215                            }
216                    }
217                    if (!success) {
218                            success = singleDecision.updateFromKBestList();
219                            if (decisionIndex+1 < decision.numberOfDecisions()) {
220                                    if (singleDecision.continueWithNextDecision()) {
221                                            if (children == null) {
222                                                    children = new HashMap<Integer,DecisionModel>();
223                                            }
224                                            DecisionModel child = children.get(singleDecision.getDecisionCode());
225                                            if (child == null) {
226                                                    child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
227                                                                    branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
228                                                    children.put(singleDecision.getDecisionCode(), child);
229                                            }
230                                            child.predict(decision);
231                                    }
232                            }
233                    }
234                    return success;
235            }
236            
237    
238            public ClassifierGuide getGuide() {
239                    return guide;
240            }
241    
242            public String getModelName() {
243                    return modelName;
244            }
245            
246            public FeatureModel getFeatureModel() {
247                    return featureModel;
248            }
249    
250            public int getDecisionIndex() {
251                    return decisionIndex;
252            }
253    
254            public DecisionModel getParentDecisionModel() {
255                    return parentDecisionModel;
256            }
257    
258            private void setFeatureModel(FeatureModel featureModel) {
259                    this.featureModel = featureModel;
260            }
261            
262            private void setDecisionIndex(int decisionIndex) {
263                    this.decisionIndex = decisionIndex;
264            }
265            
266            private void setParentDecisionModel(DecisionModel parentDecisionModel) {
267                    this.parentDecisionModel = parentDecisionModel;
268            }
269    
270            private void setModelName(String modelName) {
271                    this.modelName = modelName;
272            }
273            
274            private void setGuide(ClassifierGuide guide) {
275                    this.guide = guide;
276            }
277            
278            
279            private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException {
280                    Class<?> decisionModelClass = null;
281                    if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) {
282                            decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class;
283                    } else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) {
284                            decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class;
285                    } else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) {
286                            decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class;
287                    }
288    
289                    if (decisionModelClass == null) {
290                            throw new GuideException("Could not find an appropriate decision model for the relation to the next decision"); 
291                    }
292                    
293                    try {
294                            Class<?>[] argTypes = { org.maltparser.parser.guide.ClassifierGuide.class, org.maltparser.parser.guide.decision.DecisionModel.class, 
295                                                    java.lang.String.class };
296                            Object[] arguments = new Object[3];
297                            arguments[0] = getGuide();
298                            arguments[1] = this;
299                            arguments[2] = branchedDecisionSymbol;
300                            Constructor<?> constructor = decisionModelClass.getConstructor(argTypes);
301                            return (DecisionModel)constructor.newInstance(arguments);
302                    } catch (NoSuchMethodException e) {
303                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
304                    } catch (InstantiationException e) {
305                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
306                    } catch (IllegalAccessException e) {
307                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
308                    } catch (InvocationTargetException e) {
309                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
310                    }
311            }
312            
313            private void initInstanceModel(String subModelName) throws MaltChainedException {
314                    FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName);
315                    if (fv == null) {
316                            fv = featureModel.getFeatureVector(subModelName);
317                    }
318                    if (fv == null) {
319                            fv = featureModel.getMainFeatureVector();
320                    }
321                    if (guide.getConfiguration().getOptionValue("guide", "data_split_column").toString().length() == 0) {
322                            instanceModel = new AtomicModel(-1, fv, this);
323                    } else {
324                            instanceModel = new FeatureDivideModel(fv, this);
325                    }
326            }
327            
328            public String toString() {
329                    final StringBuilder sb = new StringBuilder();
330                    sb.append(modelName + ", ");
331                    for (DecisionModel model : children.values()) {
332                            sb.append(model.toString() + ", ");
333                    }
334                    return sb.toString();
335            }
336    }