package org.maltparser.ml.lib;

import java.io.BufferedReader;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Reader;
import java.io.Serializable;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.regex.Pattern;

import org.maltparser.core.helper.Util;

import de.bwaldvogel.liblinear.SolverType;

/**
 * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package.
 * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code
 * the model is more integrated into MaltParser. Instead of moving features from MaltParser's internal data
 * structures to liblinear's data structure, it uses MaltParser's data structure directly on the model.</p>
 *
 * @author Johan Hall
 */
public class MaltLiblinearModel implements Serializable, MaltLibModel {
    private static final long serialVersionUID = 7526471155622776147L;
    private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
    /** Separator used to tokenize the header lines of a model file; compiled once. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    private double bias;
    /** label of each class */
    private int[] labels;
    private int nr_class;
    private int nr_feature;
    private SolverType solverType;
    /**
     * Feature weight matrix. Row {@code t} holds the weights applied to the feature whose
     * 1-based index is {@code t + 1}; a row may be {@code null}, in which case the feature
     * contributes nothing to a prediction. Column {@code j} contributes to the decision value
     * of the {@code j}-th label.
     * <p>
     * When the model is loaded from a file and {@code bias >= 0}, one extra row is allocated
     * after the {@code nr_feature} feature rows. A loaded model has a single column when
     * {@code nr_class == 2} and the solver is not {@code MCSVM_CS}; otherwise it has
     * {@code nr_class} columns.
     *
     * @see #getBias()
     */
    private double[][] w;

    /**
     * Creates a model from already-trained data. The arrays are stored by reference (no copy is made).
     * The bias is not set by this constructor and therefore keeps its default value of {@code 0.0}.
     *
     * @param labels label of each class
     * @param nr_class number of classes
     * @param nr_feature number of features
     * @param w feature weight matrix, indexed as {@code w[featureIndex - 1][classColumn]}
     * @param solverType the solver that produced the model
     */
    public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) {
        this.labels = labels;
        this.nr_class = nr_class;
        this.nr_feature = nr_feature;
        this.w = w;
        this.solverType = solverType;
    }

    /**
     * Loads a model from a textual model description. The reader is closed when loading
     * finishes, whether it succeeds or fails.
     *
     * @param inputReader source of the model text
     * @throws IOException if reading fails or the input ends before all weights were read
     */
    public MaltLiblinearModel(Reader inputReader) throws IOException {
        loadModel(inputReader);
    }

    /**
     * Loads a model from a file that is decoded as ISO-8859-1.
     *
     * @param modelFile the model file
     * @throws IOException if the file cannot be opened or read, or ends before all weights were read
     */
    public MaltLiblinearModel(File modelFile) throws IOException {
        // loadModel() takes ownership of the reader and closes it in all cases.
        BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
        loadModel(inputReader);
    }

    /**
     * @return number of classes
     */
    public int getNrClass() {
        return nr_class;
    }

    /**
     * @return number of features
     */
    public int getNrFeature() {
        return nr_feature;
    }

    /**
     * @return a copy of the class labels, with length {@code nr_class}
     */
    public int[] getLabels() {
        return Util.copyOf(labels, nr_class);
    }

    /**
     * @return true for logistic regression solvers
     */
    public boolean isProbabilityModel() {
        return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
    }

    /**
     * @return the bias value of the model
     */
    public double getBias() {
        return bias;
    }

    /**
     * Scores every class against the given features and returns the class labels ordered
     * from the highest to the lowest decision value.
     * <p>
     * Features whose index is outside the range covered by the model (smaller than 1, larger than the
     * number of features plus the optional bias term, or beyond the rows actually present in the weight
     * matrix) are ignored, as are features whose weight row is {@code null}.
     *
     * @param x the active features; each node supplies a 1-based {@code index} and a {@code value}
     * @return a new array containing all labels, best scoring label first
     */
    public int[] predict(MaltFeatureNode[] x) {
        final double[] decValues = new double[nr_class];
        final int[] predictionList = Util.copyOf(labels, nr_class);
        final int n = (bias >= 0) ? nr_feature + 1 : nr_feature;

        for (int i = 0; i < x.length; i++) {
            final MaltFeatureNode node = x[i];
            final int t = node.index - 1;
            // Guard against the declared feature count AND the real size of w: a model built through
            // the array constructor may have fewer rows than n, which would otherwise overflow.
            if (node.index > n || t < 0 || t >= w.length) {
                continue;
            }
            final double[] row = w[t];
            if (row == null) {
                continue;
            }
            for (int j = 0; j < row.length; j++) {
                decValues[j] += row[j] * node.value;
            }
        }

        // Selection sort in descending order of decision value; the labels are permuted in lockstep.
        // The strict comparison keeps the earliest class first among equal scores.
        final int last = nr_class - 1;
        for (int i = 0; i < last; i++) {
            int largest = i;
            for (int j = i + 1; j < nr_class; j++) {
                if (decValues[j] > decValues[largest]) {
                    largest = j;
                }
            }
            final double tmpDec = decValues[largest];
            decValues[largest] = decValues[i];
            decValues[i] = tmpDec;
            final int tmpLabel = predictionList[largest];
            predictionList[largest] = predictionList[i];
            predictionList[i] = tmpLabel;
        }
        return predictionList;
    }

    /** Serialization hook; uses the default field-by-field deserialization. */
    private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
        is.defaultReadObject();
    }

    /** Serialization hook; uses the default field-by-field serialization. */
    private void writeObject(ObjectOutputStream os) throws IOException {
        os.defaultWriteObject();
    }

    /**
     * Reads a complete model (header followed by weights) and closes the reader afterwards.
     *
     * @param inputReader source of the model text; wrapped in a {@link BufferedReader} if it is not one already
     * @throws IOException if reading fails or the input ends before all weights were read
     * @throws RuntimeException if a header line is not recognized or lacks its value(s)
     */
    private void loadModel(Reader inputReader) throws IOException {
        labels = null;
        final BufferedReader reader = (inputReader instanceof BufferedReader)
                ? (BufferedReader) inputReader
                : new BufferedReader(inputReader);
        try {
            readHeader(reader);
            readWeights(reader);
        } finally {
            Util.closeQuietly(reader);
        }
    }

    /** Parses the "key value" header lines up to and including the line that starts with {@code w}. */
    private void readHeader(BufferedReader reader) throws IOException {
        String line;
        while ((line = reader.readLine()) != null) {
            final String[] split = WHITESPACE.split(line);
            final String key = split[0];
            if (key.equals("w")) {
                return; // the weights follow immediately after this line
            }
            if (key.equals("solver_type")) {
                final SolverType solver = SolverType.valueOf(headerValue(split, line));
                if (solver == null) {
                    throw new RuntimeException("unknown solver type");
                }
                solverType = solver;
            } else if (key.equals("nr_class")) {
                nr_class = Util.atoi(headerValue(split, line));
            } else if (key.equals("nr_feature")) {
                nr_feature = Util.atoi(headerValue(split, line));
            } else if (key.equals("bias")) {
                bias = Util.atof(headerValue(split, line));
            } else if (key.equals("label")) {
                // One label per class is expected, so nr_class must already have been read.
                if (split.length < nr_class + 1) {
                    throw new RuntimeException("too few labels in model file: [" + line + "]");
                }
                labels = new int[nr_class];
                for (int i = 0; i < nr_class; i++) {
                    labels[i] = Util.atoi(split[i + 1]);
                }
            } else {
                throw new RuntimeException("unknown text in model file: [" + line + "]");
            }
        }
    }

    /** Returns the value token of a header line, failing with a descriptive message when it is absent. */
    private static String headerValue(String[] split, String line) {
        if (split.length < 2) {
            throw new RuntimeException("missing value in model file: [" + line + "]");
        }
        return split[1];
    }

    /**
     * Reads the weight matrix: one token per weight, each token terminated by a single space.
     * Any other character (including a line break) is kept as part of the next token and left
     * for {@code Util.atof} to deal with.
     */
    private void readWeights(BufferedReader reader) throws IOException {
        final int rows = (bias >= 0) ? nr_feature + 1 : nr_feature;
        final int columns = (nr_class == 2 && solverType != SolverType.MCSVM_CS) ? 1 : nr_class;
        w = new double[rows][columns];

        // A growable buffer, so that arbitrarily long tokens cannot overflow it.
        final StringBuilder token = new StringBuilder(32);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                token.setLength(0);
                int ch;
                while ((ch = reader.read()) != ' ') {
                    if (ch == -1) {
                        throw new EOFException("unexpected EOF");
                    }
                    token.append((char) ch);
                }
                w[i][j] = Util.atof(token.toString());
            }
        }
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        final long biasBits = Double.doubleToLongBits(bias);
        int result = prime * 1 + (int) (biasBits ^ (biasBits >>> 32));
        result = prime * result + Arrays.hashCode(labels);
        result = prime * result + nr_class;
        result = prime * result + nr_feature;
        result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
        if (w != null) {
            for (int i = 0; i < w.length; i++) {
                result = prime * result + Arrays.hashCode(w[i]);
            }
        }
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) return true;
        if (obj == null) return false;
        if (getClass() != obj.getClass()) return false;
        final MaltLiblinearModel other = (MaltLiblinearModel) obj;
        if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
        if (!Arrays.equals(labels, other.labels)) return false;
        if (nr_class != other.nr_class) return false;
        if (nr_feature != other.nr_feature) return false;
        if (solverType == null) {
            if (other.solverType != null) return false;
        } else if (!solverType.equals(other.solverType)) return false;
        if (w == null || other.w == null) return w == other.w;
        // The row counts must match in both directions; otherwise a.equals(b) and b.equals(a) could differ.
        if (w.length != other.w.length) return false;
        for (int i = 0; i < w.length; i++) {
            if (!Util.equals(w[i], other.w[i])) return false;
        }
        return true;
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder("Model");
        sb.append(" bias=").append(bias);
        sb.append(" nr_class=").append(nr_class);
        sb.append(" nr_feature=").append(nr_feature);
        sb.append(" solverType=").append(solverType);
        return sb.toString();
    }
}