001 package org.maltparser.parser.guide.decision; 002 003 import java.lang.reflect.Constructor; 004 import java.lang.reflect.InvocationTargetException; 005 import java.util.HashMap; 006 007 import org.maltparser.core.exception.MaltChainedException; 008 import org.maltparser.core.feature.FeatureModel; 009 import org.maltparser.core.feature.FeatureVector; 010 import org.maltparser.core.syntaxgraph.DependencyStructure; 011 import org.maltparser.parser.guide.ClassifierGuide; 012 import org.maltparser.parser.guide.GuideException; 013 import org.maltparser.parser.guide.instance.AtomicModel; 014 import org.maltparser.parser.guide.instance.FeatureDivideModel; 015 import org.maltparser.parser.guide.instance.InstanceModel; 016 import org.maltparser.parser.history.action.GuideDecision; 017 import org.maltparser.parser.history.action.MultipleDecision; 018 import org.maltparser.parser.history.action.SingleDecision; 019 import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision; 020 /** 021 * 022 * @author Johan Hall 023 * @since 1.1 024 **/ 025 public class BranchedDecisionModel implements DecisionModel { 026 private ClassifierGuide guide; 027 private String modelName; 028 private FeatureModel featureModel; 029 private InstanceModel instanceModel; 030 private int decisionIndex; 031 private DecisionModel parentDecisionModel; 032 private HashMap<Integer,DecisionModel> children; 033 private String branchedDecisionSymbols; 034 035 public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException { 036 this.branchedDecisionSymbols = ""; 037 setGuide(guide); 038 setFeatureModel(featureModel); 039 setDecisionIndex(0); 040 setModelName("bdm"+decisionIndex); 041 setParentDecisionModel(null); 042 } 043 044 public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException { 045 if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) { 046 this.branchedDecisionSymbols = branchedDecisionSymbol; 047 } else { 048 this.branchedDecisionSymbols = ""; 049 } 050 setGuide(guide); 051 setParentDecisionModel(parentDecisionModel); 052 setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1); 053 setFeatureModel(parentDecisionModel.getFeatureModel()); 054 if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) { 055 setModelName("bdm"+decisionIndex+branchedDecisionSymbols); 056 } else { 057 setModelName("bdm"+decisionIndex); 058 } 059 this.parentDecisionModel = parentDecisionModel; 060 } 061 062 public void updateFeatureModel() throws MaltChainedException { 063 featureModel.update(); 064 } 065 066 public void updateCardinality() throws MaltChainedException { 067 featureModel.updateCardinality(); 068 } 069 070 071 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { 072 if (instanceModel != null) { 073 instanceModel.finalizeSentence(dependencyGraph); 074 } 075 if (children != null) { 076 for (DecisionModel child : children.values()) { 077 child.finalizeSentence(dependencyGraph); 078 } 079 } 080 } 081 082 public void noMoreInstances() throws MaltChainedException { 083 if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 084 throw new GuideException("The decision model could not create it's model. "); 085 } 086 featureModel.updateCardinality(); 087 if (instanceModel != null) { 088 instanceModel.noMoreInstances(); 089 instanceModel.train(); 090 } 091 if (children != null) { 092 for (DecisionModel child : children.values()) { 093 child.noMoreInstances(); 094 } 095 } 096 } 097 098 public void terminate() throws MaltChainedException { 099 if (instanceModel != null) { 100 instanceModel.terminate(); 101 instanceModel = null; 102 } 103 if (children != null) { 104 for (DecisionModel child : children.values()) { 105 child.terminate(); 106 } 107 } 108 } 109 110 public void addInstance(GuideDecision decision) throws MaltChainedException { 111 if (decision instanceof SingleDecision) { 112 throw new GuideException("A branched decision model expect more than one decisions. "); 113 } 114 updateFeatureModel(); 115 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 116 if (instanceModel == null) { 117 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 118 } 119 120 instanceModel.addInstance(singleDecision); 121 if (decisionIndex+1 < decision.numberOfDecisions()) { 122 if (singleDecision.continueWithNextDecision()) { 123 if (children == null) { 124 children = new HashMap<Integer,DecisionModel>(); 125 } 126 DecisionModel child = children.get(singleDecision.getDecisionCode()); 127 if (child == null) { 128 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 129 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 130 children.put(singleDecision.getDecisionCode(), child); 131 } 132 child.addInstance(decision); 133 } 134 } 135 } 136 137 public boolean predict(GuideDecision decision) throws MaltChainedException { 138 if (decision instanceof SingleDecision) { 139 throw new GuideException("A branched decision model expect more than one decisions. "); 140 } 141 updateFeatureModel(); 142 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 143 if (instanceModel == null) { 144 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 145 } 146 instanceModel.predict(singleDecision); 147 if (decisionIndex+1 < decision.numberOfDecisions()) { 148 if (singleDecision.continueWithNextDecision()) { 149 if (children == null) { 150 children = new HashMap<Integer,DecisionModel>(); 151 } 152 DecisionModel child = children.get(singleDecision.getDecisionCode()); 153 if (child == null) { 154 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 155 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 156 children.put(singleDecision.getDecisionCode(), child); 157 } 158 child.predict(decision); 159 } 160 } 161 162 return true; 163 } 164 165 public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException { 166 if (decision instanceof SingleDecision) { 167 throw new GuideException("A branched decision model expect more than one decisions. "); 168 } 169 updateFeatureModel(); 170 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 171 if (instanceModel == null) { 172 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 173 } 174 FeatureVector fv = instanceModel.predictExtract(singleDecision); 175 if (decisionIndex+1 < decision.numberOfDecisions()) { 176 if (singleDecision.continueWithNextDecision()) { 177 if (children == null) { 178 children = new HashMap<Integer,DecisionModel>(); 179 } 180 DecisionModel child = children.get(singleDecision.getDecisionCode()); 181 if (child == null) { 182 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 183 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 184 children.put(singleDecision.getDecisionCode(), child); 185 } 186 child.predictExtract(decision); 187 } 188 } 189 190 return fv; 191 } 192 193 public FeatureVector extract() throws MaltChainedException { 194 updateFeatureModel(); 195 return instanceModel.extract(); // TODO handle many feature vectors 196 } 197 198 public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException { 199 if (decision instanceof SingleDecision) { 200 throw new GuideException("A branched decision model expect more than one decisions. "); 201 } 202 203 boolean success = false; 204 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 205 if (decisionIndex+1 < decision.numberOfDecisions()) { 206 if (singleDecision.continueWithNextDecision()) { 207 if (children == null) { 208 children = new HashMap<Integer,DecisionModel>(); 209 } 210 DecisionModel child = children.get(singleDecision.getDecisionCode()); 211 if (child != null) { 212 success = child.predictFromKBestList(decision); 213 } 214 215 } 216 } 217 if (!success) { 218 success = singleDecision.updateFromKBestList(); 219 if (decisionIndex+1 < decision.numberOfDecisions()) { 220 if (singleDecision.continueWithNextDecision()) { 221 if (children == null) { 222 children = new HashMap<Integer,DecisionModel>(); 223 } 224 DecisionModel child = children.get(singleDecision.getDecisionCode()); 225 if (child == null) { 226 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 227 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 228 children.put(singleDecision.getDecisionCode(), child); 229 } 230 child.predict(decision); 231 } 232 } 233 } 234 return success; 235 } 236 237 238 public ClassifierGuide getGuide() { 239 return guide; 240 } 241 242 public String getModelName() { 243 return modelName; 244 } 245 246 public FeatureModel getFeatureModel() { 247 return featureModel; 248 } 249 250 public int getDecisionIndex() { 251 return decisionIndex; 252 } 253 254 public DecisionModel getParentDecisionModel() { 255 return parentDecisionModel; 256 } 257 258 private void setFeatureModel(FeatureModel featureModel) { 259 this.featureModel = featureModel; 260 } 261 262 private void setDecisionIndex(int decisionIndex) { 263 this.decisionIndex = decisionIndex; 264 } 265 266 private void setParentDecisionModel(DecisionModel parentDecisionModel) { 267 this.parentDecisionModel = parentDecisionModel; 268 } 269 270 private void setModelName(String modelName) { 271 this.modelName = modelName; 272 } 273 274 private void setGuide(ClassifierGuide guide) { 275 this.guide = guide; 276 } 277 278 279 private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException { 280 Class<?> decisionModelClass = null; 281 if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) { 282 decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class; 283 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) { 284 decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class; 285 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) { 286 decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class; 287 } 288 289 if (decisionModelClass == null) { 290 throw new GuideException("Could not find an appropriate decision model for the relation to the next decision"); 291 } 292 293 try { 294 Class<?>[] argTypes = { org.maltparser.parser.guide.ClassifierGuide.class, org.maltparser.parser.guide.decision.DecisionModel.class, 295 java.lang.String.class }; 296 Object[] arguments = new Object[3]; 297 arguments[0] = getGuide(); 298 arguments[1] = this; 299 arguments[2] = branchedDecisionSymbol; 300 Constructor<?> constructor = decisionModelClass.getConstructor(argTypes); 301 return (DecisionModel)constructor.newInstance(arguments); 302 } catch (NoSuchMethodException e) { 303 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 304 } catch (InstantiationException e) { 305 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 306 } catch (IllegalAccessException e) { 307 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 308 } catch (InvocationTargetException e) { 309 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 310 } 311 } 312 313 private void initInstanceModel(String subModelName) throws MaltChainedException { 314 FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName); 315 if (fv == null) { 316 fv = featureModel.getFeatureVector(subModelName); 317 } 318 if (fv == null) { 319 fv = featureModel.getMainFeatureVector(); 320 } 321 if (guide.getConfiguration().getOptionValue("guide", "data_split_column").toString().length() == 0) { 322 instanceModel = new AtomicModel(-1, fv, this); 323 } else { 324 instanceModel = new FeatureDivideModel(fv, this); 325 } 326 } 327 328 public String toString() { 329 final StringBuilder sb = new StringBuilder(); 330 sb.append(modelName + ", "); 331 for (DecisionModel model : children.values()) { 332 sb.append(model.toString() + ", "); 333 } 334 return sb.toString(); 335 } 336 }