001package org.maltparser.ml.lib; 002 003import java.io.BufferedReader; 004import java.io.EOFException; 005import java.io.File; 006import java.io.FileInputStream; 007import java.io.IOException; 008import java.io.InputStreamReader; 009import java.io.ObjectInputStream; 010import java.io.ObjectOutputStream; 011import java.io.Reader; 012import java.io.Serializable; 013import java.nio.charset.Charset; 014import java.util.Arrays; 015import java.util.regex.Pattern; 016 017import org.maltparser.core.helper.Util; 018 019import de.bwaldvogel.liblinear.SolverType; 020 021/** 022 * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package. 023 * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to 024 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 025 * structure directly on the model. </p> 026 * 027 * @author Johan Hall 028 * 029 */ 030public class MaltLiblinearModel implements Serializable, MaltLibModel { 031 private static final long serialVersionUID = 7526471155622776147L; 032 private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1"); 033 private double bias; 034 /** label of each class */ 035 private int[] labels; 036 private int nr_class; 037 private int nr_feature; 038 private SolverType solverType; 039 /** feature weight array */ 040 private double[][] w; 041 042 public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) { 043 this.labels = labels; 044 this.nr_class = nr_class; 045 this.nr_feature = nr_feature; 046 this.w = w; 047 this.solverType = solverType; 048 } 049 050 public MaltLiblinearModel(Reader inputReader) throws IOException { 051 loadModel(inputReader); 052 } 053 054 public MaltLiblinearModel(File modelFile) throws IOException { 055 BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET)); 056 loadModel(inputReader); 057 } 058 059 /** 060 * @return number of classes 061 */ 062 public int getNrClass() { 063 return nr_class; 064 } 065 066 /** 067 * @return number of features 068 */ 069 public int getNrFeature() { 070 return nr_feature; 071 } 072 073 public int[] getLabels() { 074 return Util.copyOf(labels, nr_class); 075 } 076 077 /** 078 * The nr_feature*nr_class array w gives feature weights. We use one 079 * against the rest for multi-class classification, so each feature 080 * index corresponds to nr_class weight values. Weights are 081 * organized in the following way 082 * 083 * <pre> 084 * +------------------+------------------+------------+ 085 * | nr_class weights | nr_class weights | ... 086 * | for 1st feature | for 2nd feature | 087 * +------------------+------------------+------------+ 088 * </pre> 089 * 090 * If bias >= 0, x becomes [x; bias]. The number of features is 091 * increased by one, so w is a (nr_feature+1)*nr_class array. The 092 * value of bias is stored in the variable bias. 093 * @see #getBias() 094 * @return a <b>copy of</b> the feature weight array as described 095 */ 096// public double[] getFeatureWeights() { 097// return Util.copyOf(w, w.length); 098// } 099 100 /** 101 * @return true for logistic regression solvers 102 */ 103 public boolean isProbabilityModel() { 104 return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR); 105 } 106 107 public double getBias() { 108 return bias; 109 } 110 111 public int[] predict(MaltFeatureNode[] x) { 112 final double[] dec_values = new double[nr_class]; 113 final int n = (bias >= 0)?nr_feature + 1:nr_feature; 114 final int xlen = x.length; 115 116 for (int i=0; i < xlen; i++) { 117 if (x[i].index <= n) { 118 final int t = (x[i].index - 1); 119 if (w[t] != null) { 120 for (int j = 0; j < w[t].length; j++) { 121 dec_values[j] += w[t][j] * x[i].value; 122 } 123 } 124 } 125 } 126 127 128 double tmpDec; 129 int tmpObj; 130 int iMax; 131 final int[] predictionList = new int[nr_class]; 132 System.arraycopy(labels, 0, predictionList, 0, nr_class); 133 final int nc = nr_class-1; 134 for (int i=0; i < nc; i++) { 135 iMax = i; 136 for (int j=i+1; j < nr_class; j++) { 137 if (dec_values[j] > dec_values[iMax]) { 138 iMax = j; 139 } 140 } 141 if (iMax != i) { 142 tmpDec = dec_values[iMax]; 143 dec_values[iMax] = dec_values[i]; 144 dec_values[i] = tmpDec; 145 tmpObj = predictionList[iMax]; 146 predictionList[iMax] = predictionList[i]; 147 predictionList[i] = tmpObj; 148 } 149 } 150 return predictionList; 151 } 152 153 public int predict_one(MaltFeatureNode[] x) { 154 final double[] dec_values = new double[nr_class]; 155 final int n = (bias >= 0)?nr_feature + 1:nr_feature; 156 final int xlen = x.length; 157 158 for (int i=0; i < xlen; i++) { 159 if (x[i].index <= n) { 160 final int t = (x[i].index - 1); 161 if (w[t] != null) { 162 for (int j = 0; j < w[t].length; j++) { 163 dec_values[j] += w[t][j] * x[i].value; 164 } 165 } 166 } 167 } 168 169 double max = dec_values[0]; 170 int max_index = 0; 171 for (int i = 1; i < dec_values.length; i++) { 172 if (dec_values[i] > max) { 173 max = dec_values[i]; 174 max_index = i; 175 } 176 } 177 178 return labels[max_index]; 179 } 180 181 private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException { 182 is.defaultReadObject(); 183 } 184 185 private void writeObject(ObjectOutputStream os) throws IOException { 186 os.defaultWriteObject(); 187 } 188 189 private void loadModel(Reader inputReader) throws IOException { 190 labels = null; 191 Pattern whitespace = Pattern.compile("\\s+"); 192 BufferedReader reader = null; 193 if (inputReader instanceof BufferedReader) { 194 reader = (BufferedReader)inputReader; 195 } else { 196 reader = new BufferedReader(inputReader); 197 } 198 199 try { 200 String line = null; 201 while ((line = reader.readLine()) != null) { 202 String[] split = whitespace.split(line); 203 if (split[0].equals("solver_type")) { 204 SolverType solver = SolverType.valueOf(split[1]); 205 if (solver == null) { 206 throw new RuntimeException("unknown solver type"); 207 } 208 solverType = solver; 209 } else if (split[0].equals("nr_class")) { 210 nr_class = Util.atoi(split[1]); 211 Integer.parseInt(split[1]); 212 } else if (split[0].equals("nr_feature")) { 213 nr_feature = Util.atoi(split[1]); 214 } else if (split[0].equals("bias")) { 215 bias = Util.atof(split[1]); 216 } else if (split[0].equals("w")) { 217 break; 218 } else if (split[0].equals("label")) { 219 labels = new int[nr_class]; 220 for (int i = 0; i < nr_class; i++) { 221 labels[i] = Util.atoi(split[i + 1]); 222 } 223 } else { 224 throw new RuntimeException("unknown text in model file: [" + line + "]"); 225 } 226 } 227 228 int w_size = nr_feature; 229 if (bias >= 0) w_size++; 230 231 int nr_w = nr_class; 232 if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1; 233 w = new double[w_size][nr_w]; 234 int[] buffer = new int[128]; 235 236 for (int i = 0; i < w_size; i++) { 237 for (int j = 0; j < nr_w; j++) { 238 int b = 0; 239 while (true) { 240 int ch = reader.read(); 241 if (ch == -1) { 242 throw new EOFException("unexpected EOF"); 243 } 244 if (ch == ' ') { 245 w[i][j] = Util.atof(new String(buffer, 0, b)); 246 break; 247 } else { 248 buffer[b++] = ch; 249 } 250 } 251 } 252 } 253 } 254 finally { 255 Util.closeQuietly(reader); 256 } 257 } 258 259 public int hashCode() { 260 final int prime = 31; 261 long temp = Double.doubleToLongBits(bias); 262 int result = prime * 1 + (int)(temp ^ (temp >>> 32)); 263 result = prime * result + Arrays.hashCode(labels); 264 result = prime * result + nr_class; 265 result = prime * result + nr_feature; 266 result = prime * result + ((solverType == null) ? 0 : solverType.hashCode()); 267 for (int i = 0; i < w.length; i++) { 268 result = prime * result + Arrays.hashCode(w[i]); 269 } 270 return result; 271 } 272 273 public boolean equals(Object obj) { 274 if (this == obj) return true; 275 if (obj == null) return false; 276 if (getClass() != obj.getClass()) return false; 277 MaltLiblinearModel other = (MaltLiblinearModel)obj; 278 if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false; 279 if (!Arrays.equals(labels, other.labels)) return false; 280 if (nr_class != other.nr_class) return false; 281 if (nr_feature != other.nr_feature) return false; 282 if (solverType == null) { 283 if (other.solverType != null) return false; 284 } else if (!solverType.equals(other.solverType)) return false; 285 for (int i = 0; i < w.length; i++) { 286 if (other.w.length <= i) return false; 287 if (!Util.equals(w[i], other.w[i])) return false; 288 } 289 return true; 290 } 291 292 public String toString() { 293 final StringBuilder sb = new StringBuilder("Model"); 294 sb.append(" bias=").append(bias); 295 sb.append(" nr_class=").append(nr_class); 296 sb.append(" nr_feature=").append(nr_feature); 297 sb.append(" solverType=").append(solverType); 298 return sb.toString(); 299 } 300}