001 package org.maltparser.parser.guide.decision; 002 003 import java.lang.reflect.Constructor; 004 import java.lang.reflect.InvocationTargetException; 005 006 007 import org.maltparser.core.exception.MaltChainedException; 008 import org.maltparser.core.feature.FeatureModel; 009 import org.maltparser.core.feature.FeatureVector; 010 import org.maltparser.core.helper.HashMap; 011 import org.maltparser.core.syntaxgraph.DependencyStructure; 012 import org.maltparser.parser.DependencyParserConfig; 013 import org.maltparser.parser.guide.ClassifierGuide; 014 import org.maltparser.parser.guide.GuideException; 015 import org.maltparser.parser.guide.instance.AtomicModel; 016 import org.maltparser.parser.guide.instance.FeatureDivideModel; 017 import org.maltparser.parser.guide.instance.InstanceModel; 018 import org.maltparser.parser.history.action.GuideDecision; 019 import org.maltparser.parser.history.action.MultipleDecision; 020 import org.maltparser.parser.history.action.SingleDecision; 021 import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision; 022 /** 023 * 024 * @author Johan Hall 025 * @since 1.1 026 **/ 027 public class BranchedDecisionModel implements DecisionModel { 028 private ClassifierGuide guide; 029 private String modelName; 030 private FeatureModel featureModel; 031 private InstanceModel instanceModel; 032 private int decisionIndex; 033 private DecisionModel parentDecisionModel; 034 private HashMap<Integer,DecisionModel> children; 035 private String branchedDecisionSymbols; 036 037 public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException { 038 this.branchedDecisionSymbols = ""; 039 setGuide(guide); 040 setFeatureModel(featureModel); 041 setDecisionIndex(0); 042 setModelName("bdm"+decisionIndex); 043 setParentDecisionModel(null); 044 } 045 046 public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException { 047 if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) { 048 this.branchedDecisionSymbols = branchedDecisionSymbol; 049 } else { 050 this.branchedDecisionSymbols = ""; 051 } 052 setGuide(guide); 053 setParentDecisionModel(parentDecisionModel); 054 setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1); 055 setFeatureModel(parentDecisionModel.getFeatureModel()); 056 if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) { 057 setModelName("bdm"+decisionIndex+branchedDecisionSymbols); 058 } else { 059 setModelName("bdm"+decisionIndex); 060 } 061 this.parentDecisionModel = parentDecisionModel; 062 } 063 064 public void updateFeatureModel() throws MaltChainedException { 065 featureModel.update(); 066 } 067 068 // public void updateCardinality() throws MaltChainedException { 069 // featureModel.updateCardinality(); 070 // } 071 072 073 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { 074 if (instanceModel != null) { 075 instanceModel.finalizeSentence(dependencyGraph); 076 } 077 if (children != null) { 078 for (DecisionModel child : children.values()) { 079 child.finalizeSentence(dependencyGraph); 080 } 081 } 082 } 083 084 public void noMoreInstances() throws MaltChainedException { 085 if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 086 throw new GuideException("The decision model could not create it's model. "); 087 } 088 if (instanceModel != null) { 089 instanceModel.noMoreInstances(); 090 instanceModel.train(); 091 } 092 if (children != null) { 093 for (DecisionModel child : children.values()) { 094 child.noMoreInstances(); 095 } 096 } 097 } 098 099 public void terminate() throws MaltChainedException { 100 if (instanceModel != null) { 101 instanceModel.terminate(); 102 instanceModel = null; 103 } 104 if (children != null) { 105 for (DecisionModel child : children.values()) { 106 child.terminate(); 107 } 108 } 109 } 110 111 public void addInstance(GuideDecision decision) throws MaltChainedException { 112 if (decision instanceof SingleDecision) { 113 throw new GuideException("A branched decision model expect more than one decisions. "); 114 } 115 featureModel.update(); 116 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 117 if (instanceModel == null) { 118 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 119 } 120 121 instanceModel.addInstance(singleDecision); 122 if (decisionIndex+1 < decision.numberOfDecisions()) { 123 if (singleDecision.continueWithNextDecision()) { 124 if (children == null) { 125 children = new HashMap<Integer,DecisionModel>(); 126 } 127 DecisionModel child = children.get(singleDecision.getDecisionCode()); 128 if (child == null) { 129 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 130 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 131 children.put(singleDecision.getDecisionCode(), child); 132 } 133 child.addInstance(decision); 134 } 135 } 136 } 137 138 public boolean predict(GuideDecision decision) throws MaltChainedException { 139 // if (decision instanceof SingleDecision) { 140 // throw new GuideException("A branched decision model expect more than one decisions. "); 141 // } 142 featureModel.update(); 143 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 144 if (instanceModel == null) { 145 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 146 } 147 instanceModel.predict(singleDecision); 148 if (decisionIndex+1 < decision.numberOfDecisions()) { 149 if (singleDecision.continueWithNextDecision()) { 150 if (children == null) { 151 children = new HashMap<Integer,DecisionModel>(); 152 } 153 DecisionModel child = children.get(singleDecision.getDecisionCode()); 154 if (child == null) { 155 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 156 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 157 children.put(singleDecision.getDecisionCode(), child); 158 } 159 child.predict(decision); 160 } 161 } 162 163 return true; 164 } 165 166 public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException { 167 if (decision instanceof SingleDecision) { 168 throw new GuideException("A branched decision model expect more than one decisions. "); 169 } 170 featureModel.update(); 171 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 172 if (instanceModel == null) { 173 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 174 } 175 FeatureVector fv = instanceModel.predictExtract(singleDecision); 176 if (decisionIndex+1 < decision.numberOfDecisions()) { 177 if (singleDecision.continueWithNextDecision()) { 178 if (children == null) { 179 children = new HashMap<Integer,DecisionModel>(); 180 } 181 DecisionModel child = children.get(singleDecision.getDecisionCode()); 182 if (child == null) { 183 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 184 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 185 children.put(singleDecision.getDecisionCode(), child); 186 } 187 child.predictExtract(decision); 188 } 189 } 190 191 return fv; 192 } 193 194 public FeatureVector extract() throws MaltChainedException { 195 featureModel.update(); 196 return instanceModel.extract(); // TODO handle many feature vectors 197 } 198 199 public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException { 200 if (decision instanceof SingleDecision) { 201 throw new GuideException("A branched decision model expect more than one decisions. "); 202 } 203 204 boolean success = false; 205 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 206 if (decisionIndex+1 < decision.numberOfDecisions()) { 207 if (singleDecision.continueWithNextDecision()) { 208 if (children == null) { 209 children = new HashMap<Integer,DecisionModel>(); 210 } 211 DecisionModel child = children.get(singleDecision.getDecisionCode()); 212 if (child != null) { 213 success = child.predictFromKBestList(decision); 214 } 215 216 } 217 } 218 if (!success) { 219 success = singleDecision.updateFromKBestList(); 220 if (decisionIndex+1 < decision.numberOfDecisions()) { 221 if (singleDecision.continueWithNextDecision()) { 222 if (children == null) { 223 children = new HashMap<Integer,DecisionModel>(); 224 } 225 DecisionModel child = children.get(singleDecision.getDecisionCode()); 226 if (child == null) { 227 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 228 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 229 children.put(singleDecision.getDecisionCode(), child); 230 } 231 child.predict(decision); 232 } 233 } 234 } 235 return success; 236 } 237 238 239 public ClassifierGuide getGuide() { 240 return guide; 241 } 242 243 public String getModelName() { 244 return modelName; 245 } 246 247 public FeatureModel getFeatureModel() { 248 return featureModel; 249 } 250 251 public int getDecisionIndex() { 252 return decisionIndex; 253 } 254 255 public DecisionModel getParentDecisionModel() { 256 return parentDecisionModel; 257 } 258 259 private void setFeatureModel(FeatureModel featureModel) { 260 this.featureModel = featureModel; 261 } 262 263 private void setDecisionIndex(int decisionIndex) { 264 this.decisionIndex = decisionIndex; 265 } 266 267 private void setParentDecisionModel(DecisionModel parentDecisionModel) { 268 this.parentDecisionModel = parentDecisionModel; 269 } 270 271 private void setModelName(String modelName) { 272 this.modelName = modelName; 273 } 274 275 private void setGuide(ClassifierGuide guide) { 276 this.guide = guide; 277 } 278 279 280 private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException { 281 Class<?> decisionModelClass = null; 282 if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) { 283 decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class; 284 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) { 285 decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class; 286 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) { 287 decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class; 288 } 289 290 if (decisionModelClass == null) { 291 throw new GuideException("Could not find an appropriate decision model for the relation to the next decision"); 292 } 293 294 try { 295 Class<?>[] argTypes = { org.maltparser.parser.guide.ClassifierGuide.class, org.maltparser.parser.guide.decision.DecisionModel.class, 296 java.lang.String.class }; 297 Object[] arguments = new Object[3]; 298 arguments[0] = getGuide(); 299 arguments[1] = this; 300 arguments[2] = branchedDecisionSymbol; 301 Constructor<?> constructor = decisionModelClass.getConstructor(argTypes); 302 return (DecisionModel)constructor.newInstance(arguments); 303 } catch (NoSuchMethodException e) { 304 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 305 } catch (InstantiationException e) { 306 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 307 } catch (IllegalAccessException e) { 308 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 309 } catch (InvocationTargetException e) { 310 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 311 } 312 } 313 314 private void initInstanceModel(String subModelName) throws MaltChainedException { 315 FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName); 316 if (fv == null) { 317 fv = featureModel.getFeatureVector(subModelName); 318 } 319 if (fv == null) { 320 fv = featureModel.getMainFeatureVector(); 321 } 322 323 DependencyParserConfig c = guide.getConfiguration(); 324 325 // if (c.getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes") || 326 // (c.getOptionValue("guide", "tree_split_columns")!=null && 327 // c.getOptionValue("guide", "tree_split_columns").toString().length() > 0) || 328 // (c.getOptionValue("guide", "tree_split_structures")!=null && 329 // c.getOptionValue("guide", "tree_split_structures").toString().length() > 0)) { 330 // instanceModel = new DecisionTreeModel(fv, this); 331 // }else 332 if (c.getOptionValue("guide", "data_split_column").toString().length() == 0) { 333 instanceModel = new AtomicModel(-1, fv, this); 334 } else { 335 instanceModel = new FeatureDivideModel(fv, this); 336 } 337 } 338 339 public String toString() { 340 final StringBuilder sb = new StringBuilder(); 341 sb.append(modelName + ", "); 342 for (DecisionModel model : children.values()) { 343 sb.append(model.toString() + ", "); 344 } 345 return sb.toString(); 346 } 347 }