001 package org.maltparser.ml.lib; 002 003 import java.io.BufferedReader; 004 import java.io.EOFException; 005 import java.io.File; 006 import java.io.FileInputStream; 007 import java.io.IOException; 008 import java.io.InputStreamReader; 009 import java.io.ObjectInputStream; 010 import java.io.ObjectOutputStream; 011 import java.io.Reader; 012 import java.io.Serializable; 013 import java.nio.charset.Charset; 014 import java.util.Arrays; 015 import java.util.regex.Pattern; 016 017 import org.maltparser.core.helper.Util; 018 019 import liblinear.SolverType; 020 021 /** 022 * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package. 023 * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to 024 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 025 * structure directly on the model. </p> 026 * 027 * @author Johan Hall 028 * 029 */ 030 public class MaltLiblinearModel implements Serializable, MaltLibModel { 031 private static final long serialVersionUID = 7526471155622776147L; 032 private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1"); 033 private double bias; 034 /** label of each class */ 035 private int[] labels; 036 private int nr_class; 037 private int nr_feature; 038 private SolverType solverType; 039 /** feature weight array */ 040 private double[][] w; 041 042 public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) { 043 this.labels = labels; 044 this.nr_class = nr_class; 045 this.nr_feature = nr_feature; 046 this.w = w; 047 this.solverType = solverType; 048 } 049 050 public MaltLiblinearModel(Reader inputReader) throws IOException { 051 loadModel(inputReader); 052 } 053 054 public MaltLiblinearModel(File modelFile) throws IOException { 055 BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET)); 056 loadModel(inputReader); 057 } 058 059 /** 060 * @return number of classes 061 */ 062 public int getNrClass() { 063 return nr_class; 064 } 065 066 /** 067 * @return number of features 068 */ 069 public int getNrFeature() { 070 return nr_feature; 071 } 072 073 public int[] getLabels() { 074 return Util.copyOf(labels, nr_class); 075 } 076 077 /** 078 * The nr_feature*nr_class array w gives feature weights. We use one 079 * against the rest for multi-class classification, so each feature 080 * index corresponds to nr_class weight values. Weights are 081 * organized in the following way 082 * 083 * <pre> 084 * +------------------+------------------+------------+ 085 * | nr_class weights | nr_class weights | ... 086 * | for 1st feature | for 2nd feature | 087 * +------------------+------------------+------------+ 088 * </pre> 089 * 090 * If bias >= 0, x becomes [x; bias]. The number of features is 091 * increased by one, so w is a (nr_feature+1)*nr_class array. The 092 * value of bias is stored in the variable bias. 093 * @see #getBias() 094 * @return a <b>copy of</b> the feature weight array as described 095 */ 096 // public double[] getFeatureWeights() { 097 // return Util.copyOf(w, w.length); 098 // } 099 100 /** 101 * @return true for logistic regression solvers 102 */ 103 public boolean isProbabilityModel() { 104 return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR); 105 } 106 107 public double getBias() { 108 return bias; 109 } 110 111 public int[] predict(MaltFeatureNode[] x) { 112 final double[] dec_values = new double[nr_class]; 113 final int n = (bias >= 0)?nr_feature + 1:nr_feature; 114 final int nr_w = (nr_class == 2 && solverType != SolverType.MCSVM_CS)?1:nr_class; 115 final int xlen = x.length; 116 int i; 117 for (i = 0; i < nr_w; i++) { 118 dec_values[i] = 0; 119 } 120 121 for (i=0; i < xlen; i++) { 122 if (x[i].index <= n) { 123 int t = (x[i].index - 1); 124 if (w[t] != null) { 125 for (int j = 0; j < w[t].length; j++) { 126 dec_values[j] += w[t][j] * x[i].value; 127 } 128 } 129 } 130 } 131 132 final int[] predictionList = Util.copyOf(labels, nr_class); 133 double tmpDec; 134 int tmpObj; 135 int lagest; 136 final int nc = nr_class-1; 137 for (i=0; i < nc; i++) { 138 lagest = i; 139 for (int j=i; j < nr_class; j++) { 140 if (dec_values[j] > dec_values[lagest]) { 141 lagest = j; 142 } 143 } 144 tmpDec = dec_values[lagest]; 145 dec_values[lagest] = dec_values[i]; 146 dec_values[i] = tmpDec; 147 tmpObj = predictionList[lagest]; 148 predictionList[lagest] = predictionList[i]; 149 predictionList[i] = tmpObj; 150 } 151 return predictionList; 152 } 153 154 private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException { 155 is.defaultReadObject(); 156 } 157 158 private void writeObject(ObjectOutputStream os) throws IOException { 159 os.defaultWriteObject(); 160 } 161 162 private void loadModel(Reader inputReader) throws IOException { 163 labels = null; 164 Pattern whitespace = Pattern.compile("\\s+"); 165 BufferedReader reader = null; 166 if (inputReader instanceof BufferedReader) { 167 reader = (BufferedReader)inputReader; 168 } else { 169 reader = new BufferedReader(inputReader); 170 } 171 172 try { 173 String line = null; 174 while ((line = reader.readLine()) != null) { 175 String[] split = whitespace.split(line); 176 if (split[0].equals("solver_type")) { 177 SolverType solver = SolverType.valueOf(split[1]); 178 if (solver == null) { 179 throw new RuntimeException("unknown solver type"); 180 } 181 solverType = solver; 182 } else if (split[0].equals("nr_class")) { 183 nr_class = Util.atoi(split[1]); 184 Integer.parseInt(split[1]); 185 } else if (split[0].equals("nr_feature")) { 186 nr_feature = Util.atoi(split[1]); 187 } else if (split[0].equals("bias")) { 188 bias = Util.atof(split[1]); 189 } else if (split[0].equals("w")) { 190 break; 191 } else if (split[0].equals("label")) { 192 labels = new int[nr_class]; 193 for (int i = 0; i < nr_class; i++) { 194 labels[i] = Util.atoi(split[i + 1]); 195 } 196 } else { 197 throw new RuntimeException("unknown text in model file: [" + line + "]"); 198 } 199 } 200 201 int w_size = nr_feature; 202 if (bias >= 0) w_size++; 203 204 int nr_w = nr_class; 205 if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1; 206 w = new double[w_size][nr_w]; 207 int[] buffer = new int[128]; 208 209 for (int i = 0; i < w_size; i++) { 210 for (int j = 0; j < nr_w; j++) { 211 int b = 0; 212 while (true) { 213 int ch = reader.read(); 214 if (ch == -1) { 215 throw new EOFException("unexpected EOF"); 216 } 217 if (ch == ' ') { 218 w[i][j] = Util.atof(new String(buffer, 0, b)); 219 break; 220 } else { 221 buffer[b++] = ch; 222 } 223 } 224 } 225 } 226 } 227 finally { 228 Util.closeQuietly(reader); 229 } 230 } 231 232 public int hashCode() { 233 final int prime = 31; 234 long temp = Double.doubleToLongBits(bias); 235 int result = prime * 1 + (int)(temp ^ (temp >>> 32)); 236 result = prime * result + Arrays.hashCode(labels); 237 result = prime * result + nr_class; 238 result = prime * result + nr_feature; 239 result = prime * result + ((solverType == null) ? 0 : solverType.hashCode()); 240 for (int i = 0; i < w.length; i++) { 241 result = prime * result + Arrays.hashCode(w[i]); 242 } 243 return result; 244 } 245 246 public boolean equals(Object obj) { 247 if (this == obj) return true; 248 if (obj == null) return false; 249 if (getClass() != obj.getClass()) return false; 250 MaltLiblinearModel other = (MaltLiblinearModel)obj; 251 if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false; 252 if (!Arrays.equals(labels, other.labels)) return false; 253 if (nr_class != other.nr_class) return false; 254 if (nr_feature != other.nr_feature) return false; 255 if (solverType == null) { 256 if (other.solverType != null) return false; 257 } else if (!solverType.equals(other.solverType)) return false; 258 for (int i = 0; i < w.length; i++) { 259 if (other.w.length <= i) return false; 260 if (!Util.equals(w[i], other.w[i])) return false; 261 } 262 return true; 263 } 264 265 public String toString() { 266 final StringBuilder sb = new StringBuilder("Model"); 267 sb.append(" bias=").append(bias); 268 sb.append(" nr_class=").append(nr_class); 269 sb.append(" nr_feature=").append(nr_feature); 270 sb.append(" solverType=").append(solverType); 271 return sb.toString(); 272 } 273 }