001 package org.maltparser.ml.lib; 002 003 import java.io.BufferedReader; 004 import java.io.EOFException; 005 import java.io.File; 006 import java.io.FileInputStream; 007 import java.io.IOException; 008 import java.io.InputStreamReader; 009 import java.io.ObjectInputStream; 010 import java.io.ObjectOutputStream; 011 import java.io.Reader; 012 import java.io.Serializable; 013 import java.nio.charset.Charset; 014 import java.util.Arrays; 015 import java.util.regex.Pattern; 016 017 import org.maltparser.core.helper.Util; 018 019 import liblinear.Model; 020 import liblinear.SolverType; 021 022 /** 023 * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package. 024 * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to 025 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data 026 * structure directly on the model. </p> 027 * 028 * @author Johan Hall 029 * 030 */ 031 public class MaltLiblinearModel implements Serializable, MaltLibModel { 032 private static final long serialVersionUID = 7526471155622776147L; 033 private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1"); 034 private double bias; 035 /** label of each class */ 036 private int[] labels; 037 private int nr_class; 038 private int nr_feature; 039 private SolverType solverType; 040 /** feature weight array */ 041 private double[] w; 042 043 public MaltLiblinearModel(Model model, SolverType solverType) { 044 labels = model.getLabels(); 045 nr_class = model.getNrClass(); 046 nr_feature = model.getNrFeature(); 047 this.solverType = solverType; 048 w = model.getFeatureWeights(); 049 } 050 051 public MaltLiblinearModel(Reader inputReader) throws IOException { 052 loadModel(inputReader); 053 } 054 055 public MaltLiblinearModel(File modelFile) throws IOException { 056 BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET)); 057 loadModel(inputReader); 058 } 059 060 /** 061 * @return number of classes 062 */ 063 public int getNrClass() { 064 return nr_class; 065 } 066 067 /** 068 * @return number of features 069 */ 070 public int getNrFeature() { 071 return nr_feature; 072 } 073 074 public int[] getLabels() { 075 return Util.copyOf(labels, nr_class); 076 } 077 078 /** 079 * The nr_feature*nr_class array w gives feature weights. We use one 080 * against the rest for multi-class classification, so each feature 081 * index corresponds to nr_class weight values. Weights are 082 * organized in the following way 083 * 084 * <pre> 085 * +------------------+------------------+------------+ 086 * | nr_class weights | nr_class weights | ... 087 * | for 1st feature | for 2nd feature | 088 * +------------------+------------------+------------+ 089 * </pre> 090 * 091 * If bias >= 0, x becomes [x; bias]. The number of features is 092 * increased by one, so w is a (nr_feature+1)*nr_class array. The 093 * value of bias is stored in the variable bias. 094 * @see #getBias() 095 * @return a <b>copy of</b> the feature weight array as described 096 */ 097 public double[] getFeatureWeights() { 098 return Util.copyOf(w, w.length); 099 } 100 101 /** 102 * @return true for logistic regression solvers 103 */ 104 public boolean isProbabilityModel() { 105 return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR); 106 } 107 108 public double getBias() { 109 return bias; 110 } 111 112 public int[] predict(MaltFeatureNode[] x) { 113 final double[] dec_values = new double[nr_class]; 114 final int n = (bias >= 0)?nr_feature + 1:nr_feature; 115 final int nr_w = (nr_class == 2 && solverType != SolverType.MCSVM_CS)?1:nr_class; 116 final int xlen = x.length; 117 int i; 118 for (i = 0; i < nr_w; i++) { 119 dec_values[i] = 0; 120 } 121 122 for (i=0; i < xlen; i++) { 123 if (x[i].index <= n) { 124 int t = (x[i].index - 1) * nr_w; 125 for (int j = 0; j < nr_w; j++) { 126 dec_values[j] += w[t + j] * x[i].value; 127 } 128 } 129 } 130 131 final int[] predictionList = Util.copyOf(labels, nr_class); 132 double tmpDec; 133 int tmpObj; 134 int lagest; 135 final int nc = nr_class-1; 136 for (i=0; i < nc; i++) { 137 lagest = i; 138 for (int j=i; j < nr_class; j++) { 139 if (dec_values[j] > dec_values[lagest]) { 140 lagest = j; 141 } 142 } 143 tmpDec = dec_values[lagest]; 144 dec_values[lagest] = dec_values[i]; 145 dec_values[i] = tmpDec; 146 tmpObj = predictionList[lagest]; 147 predictionList[lagest] = predictionList[i]; 148 predictionList[i] = tmpObj; 149 } 150 return predictionList; 151 } 152 153 private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException { 154 is.defaultReadObject(); 155 } 156 157 private void writeObject(ObjectOutputStream os) throws IOException { 158 os.defaultWriteObject(); 159 } 160 161 private void loadModel(Reader inputReader) throws IOException { 162 labels = null; 163 Pattern whitespace = Pattern.compile("\\s+"); 164 BufferedReader reader = null; 165 if (inputReader instanceof BufferedReader) { 166 reader = (BufferedReader)inputReader; 167 } else { 168 reader = new BufferedReader(inputReader); 169 } 170 171 try { 172 String line = null; 173 while ((line = reader.readLine()) != null) { 174 String[] split = whitespace.split(line); 175 if (split[0].equals("solver_type")) { 176 SolverType solver = SolverType.valueOf(split[1]); 177 if (solver == null) { 178 throw new RuntimeException("unknown solver type"); 179 } 180 solverType = solver; 181 } else if (split[0].equals("nr_class")) { 182 nr_class = Util.atoi(split[1]); 183 Integer.parseInt(split[1]); 184 } else if (split[0].equals("nr_feature")) { 185 nr_feature = Util.atoi(split[1]); 186 } else if (split[0].equals("bias")) { 187 bias = Util.atof(split[1]); 188 } else if (split[0].equals("w")) { 189 break; 190 } else if (split[0].equals("label")) { 191 labels = new int[nr_class]; 192 for (int i = 0; i < nr_class; i++) { 193 labels[i] = Util.atoi(split[i + 1]); 194 } 195 } else { 196 throw new RuntimeException("unknown text in model file: [" + line + "]"); 197 } 198 } 199 200 int w_size = nr_feature; 201 if (bias >= 0) w_size++; 202 203 int nr_w = nr_class; 204 if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1; 205 206 w = new double[w_size * nr_w]; 207 int[] buffer = new int[128]; 208 209 for (int i = 0; i < w_size; i++) { 210 for (int j = 0; j < nr_w; j++) { 211 int b = 0; 212 while (true) { 213 int ch = reader.read(); 214 if (ch == -1) { 215 throw new EOFException("unexpected EOF"); 216 } 217 if (ch == ' ') { 218 w[i * nr_w + j] = Util.atof(new String(buffer, 0, b)); 219 break; 220 } else { 221 buffer[b++] = ch; 222 } 223 } 224 } 225 } 226 } 227 finally { 228 Util.closeQuietly(reader); 229 } 230 } 231 232 public int hashCode() { 233 final int prime = 31; 234 long temp = Double.doubleToLongBits(bias); 235 int result = prime * 1 + (int)(temp ^ (temp >>> 32)); 236 result = prime * result + Arrays.hashCode(labels); 237 result = prime * result + nr_class; 238 result = prime * result + nr_feature; 239 result = prime * result + ((solverType == null) ? 0 : solverType.hashCode()); 240 result = prime * result + Arrays.hashCode(w); 241 return result; 242 } 243 244 public boolean equals(Object obj) { 245 if (this == obj) return true; 246 if (obj == null) return false; 247 if (getClass() != obj.getClass()) return false; 248 MaltLiblinearModel other = (MaltLiblinearModel)obj; 249 if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false; 250 if (!Arrays.equals(labels, other.labels)) return false; 251 if (nr_class != other.nr_class) return false; 252 if (nr_feature != other.nr_feature) return false; 253 if (solverType == null) { 254 if (other.solverType != null) return false; 255 } else if (!solverType.equals(other.solverType)) return false; 256 if (!Util.equals(w, other.w)) return false; 257 return true; 258 } 259 260 public String toString() { 261 StringBuilder sb = new StringBuilder("Model"); 262 sb.append(" bias=").append(bias); 263 sb.append(" nr_class=").append(nr_class); 264 sb.append(" nr_feature=").append(nr_feature); 265 sb.append(" solverType=").append(solverType); 266 return sb.toString(); 267 } 268 }