001 package org.maltparser.ml.lib;
002
003 import java.io.BufferedReader;
004 import java.io.EOFException;
005 import java.io.File;
006 import java.io.FileInputStream;
007 import java.io.IOException;
008 import java.io.InputStreamReader;
009 import java.io.ObjectInputStream;
010 import java.io.ObjectOutputStream;
011 import java.io.Reader;
012 import java.io.Serializable;
013 import java.nio.charset.Charset;
014 import java.util.Arrays;
015 import java.util.regex.Pattern;
016
017 import org.maltparser.core.helper.Util;
018
019 import liblinear.Model;
020 import liblinear.SolverType;
021
022 /**
023 * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package.
024 * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
025 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data
026 * structure directly on the model. </p>
027 *
028 * @author Johan Hall
029 *
030 */
031 public class MaltLiblinearModel implements Serializable, MaltLibModel {
032 private static final long serialVersionUID = 7526471155622776147L;
033 private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
034 private double bias;
035 /** label of each class */
036 private int[] labels;
037 private int nr_class;
038 private int nr_feature;
039 private SolverType solverType;
040 /** feature weight array */
041 private double[] w;
042
043 public MaltLiblinearModel(Model model, SolverType solverType) {
044 labels = model.getLabels();
045 nr_class = model.getNrClass();
046 nr_feature = model.getNrFeature();
047 this.solverType = solverType;
048 w = model.getFeatureWeights();
049 }
050
051 public MaltLiblinearModel(Reader inputReader) throws IOException {
052 loadModel(inputReader);
053 }
054
055 public MaltLiblinearModel(File modelFile) throws IOException {
056 BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
057 loadModel(inputReader);
058 }
059
060 /**
061 * @return number of classes
062 */
063 public int getNrClass() {
064 return nr_class;
065 }
066
067 /**
068 * @return number of features
069 */
070 public int getNrFeature() {
071 return nr_feature;
072 }
073
074 public int[] getLabels() {
075 return Util.copyOf(labels, nr_class);
076 }
077
078 /**
079 * The nr_feature*nr_class array w gives feature weights. We use one
080 * against the rest for multi-class classification, so each feature
081 * index corresponds to nr_class weight values. Weights are
082 * organized in the following way
083 *
084 * <pre>
085 * +------------------+------------------+------------+
086 * | nr_class weights | nr_class weights | ...
087 * | for 1st feature | for 2nd feature |
088 * +------------------+------------------+------------+
089 * </pre>
090 *
091 * If bias >= 0, x becomes [x; bias]. The number of features is
092 * increased by one, so w is a (nr_feature+1)*nr_class array. The
093 * value of bias is stored in the variable bias.
094 * @see #getBias()
095 * @return a <b>copy of</b> the feature weight array as described
096 */
097 public double[] getFeatureWeights() {
098 return Util.copyOf(w, w.length);
099 }
100
101 /**
102 * @return true for logistic regression solvers
103 */
104 public boolean isProbabilityModel() {
105 return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
106 }
107
108 public double getBias() {
109 return bias;
110 }
111
112 public int[] predict(MaltFeatureNode[] x) {
113 final double[] dec_values = new double[nr_class];
114 final int n = (bias >= 0)?nr_feature + 1:nr_feature;
115 final int nr_w = (nr_class == 2 && solverType != SolverType.MCSVM_CS)?1:nr_class;
116 final int xlen = x.length;
117 int i;
118 for (i = 0; i < nr_w; i++) {
119 dec_values[i] = 0;
120 }
121
122 for (i=0; i < xlen; i++) {
123 if (x[i].index <= n) {
124 int t = (x[i].index - 1) * nr_w;
125 for (int j = 0; j < nr_w; j++) {
126 dec_values[j] += w[t + j] * x[i].value;
127 }
128 }
129 }
130
131 final int[] predictionList = Util.copyOf(labels, nr_class);
132 double tmpDec;
133 int tmpObj;
134 int lagest;
135 final int nc = nr_class-1;
136 for (i=0; i < nc; i++) {
137 lagest = i;
138 for (int j=i; j < nr_class; j++) {
139 if (dec_values[j] > dec_values[lagest]) {
140 lagest = j;
141 }
142 }
143 tmpDec = dec_values[lagest];
144 dec_values[lagest] = dec_values[i];
145 dec_values[i] = tmpDec;
146 tmpObj = predictionList[lagest];
147 predictionList[lagest] = predictionList[i];
148 predictionList[i] = tmpObj;
149 }
150 return predictionList;
151 }
152
153 private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
154 is.defaultReadObject();
155 }
156
157 private void writeObject(ObjectOutputStream os) throws IOException {
158 os.defaultWriteObject();
159 }
160
161 private void loadModel(Reader inputReader) throws IOException {
162 labels = null;
163 Pattern whitespace = Pattern.compile("\\s+");
164 BufferedReader reader = null;
165 if (inputReader instanceof BufferedReader) {
166 reader = (BufferedReader)inputReader;
167 } else {
168 reader = new BufferedReader(inputReader);
169 }
170
171 try {
172 String line = null;
173 while ((line = reader.readLine()) != null) {
174 String[] split = whitespace.split(line);
175 if (split[0].equals("solver_type")) {
176 SolverType solver = SolverType.valueOf(split[1]);
177 if (solver == null) {
178 throw new RuntimeException("unknown solver type");
179 }
180 solverType = solver;
181 } else if (split[0].equals("nr_class")) {
182 nr_class = Util.atoi(split[1]);
183 Integer.parseInt(split[1]);
184 } else if (split[0].equals("nr_feature")) {
185 nr_feature = Util.atoi(split[1]);
186 } else if (split[0].equals("bias")) {
187 bias = Util.atof(split[1]);
188 } else if (split[0].equals("w")) {
189 break;
190 } else if (split[0].equals("label")) {
191 labels = new int[nr_class];
192 for (int i = 0; i < nr_class; i++) {
193 labels[i] = Util.atoi(split[i + 1]);
194 }
195 } else {
196 throw new RuntimeException("unknown text in model file: [" + line + "]");
197 }
198 }
199
200 int w_size = nr_feature;
201 if (bias >= 0) w_size++;
202
203 int nr_w = nr_class;
204 if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1;
205
206 w = new double[w_size * nr_w];
207 int[] buffer = new int[128];
208
209 for (int i = 0; i < w_size; i++) {
210 for (int j = 0; j < nr_w; j++) {
211 int b = 0;
212 while (true) {
213 int ch = reader.read();
214 if (ch == -1) {
215 throw new EOFException("unexpected EOF");
216 }
217 if (ch == ' ') {
218 w[i * nr_w + j] = Util.atof(new String(buffer, 0, b));
219 break;
220 } else {
221 buffer[b++] = ch;
222 }
223 }
224 }
225 }
226 }
227 finally {
228 Util.closeQuietly(reader);
229 }
230 }
231
232 public int hashCode() {
233 final int prime = 31;
234 long temp = Double.doubleToLongBits(bias);
235 int result = prime * 1 + (int)(temp ^ (temp >>> 32));
236 result = prime * result + Arrays.hashCode(labels);
237 result = prime * result + nr_class;
238 result = prime * result + nr_feature;
239 result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
240 result = prime * result + Arrays.hashCode(w);
241 return result;
242 }
243
244 public boolean equals(Object obj) {
245 if (this == obj) return true;
246 if (obj == null) return false;
247 if (getClass() != obj.getClass()) return false;
248 MaltLiblinearModel other = (MaltLiblinearModel)obj;
249 if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
250 if (!Arrays.equals(labels, other.labels)) return false;
251 if (nr_class != other.nr_class) return false;
252 if (nr_feature != other.nr_feature) return false;
253 if (solverType == null) {
254 if (other.solverType != null) return false;
255 } else if (!solverType.equals(other.solverType)) return false;
256 if (!Util.equals(w, other.w)) return false;
257 return true;
258 }
259
260 public String toString() {
261 StringBuilder sb = new StringBuilder("Model");
262 sb.append(" bias=").append(bias);
263 sb.append(" nr_class=").append(nr_class);
264 sb.append(" nr_feature=").append(nr_feature);
265 sb.append(" solverType=").append(solverType);
266 return sb.toString();
267 }
268 }