001 package org.maltparser.ml.lib;
002
003 import java.io.BufferedReader;
004 import java.io.EOFException;
005 import java.io.File;
006 import java.io.FileInputStream;
007 import java.io.IOException;
008 import java.io.InputStreamReader;
009 import java.io.ObjectInputStream;
010 import java.io.ObjectOutputStream;
011 import java.io.Reader;
012 import java.io.Serializable;
013 import java.nio.charset.Charset;
014 import java.util.Arrays;
015 import java.util.regex.Pattern;
016
017 import org.maltparser.core.helper.Util;
018
019 import liblinear.SolverType;
020
021 /**
022 * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package.
023 * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
024 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data
025 * structure directly on the model. </p>
026 *
027 * @author Johan Hall
028 *
029 */
030 public class MaltLiblinearModel implements Serializable, MaltLibModel {
031 private static final long serialVersionUID = 7526471155622776147L;
032 private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
033 private double bias;
034 /** label of each class */
035 private int[] labels;
036 private int nr_class;
037 private int nr_feature;
038 private SolverType solverType;
039 /** feature weight array */
040 private double[][] w;
041
042 public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) {
043 this.labels = labels;
044 this.nr_class = nr_class;
045 this.nr_feature = nr_feature;
046 this.w = w;
047 this.solverType = solverType;
048 }
049
050 public MaltLiblinearModel(Reader inputReader) throws IOException {
051 loadModel(inputReader);
052 }
053
054 public MaltLiblinearModel(File modelFile) throws IOException {
055 BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
056 loadModel(inputReader);
057 }
058
059 /**
060 * @return number of classes
061 */
062 public int getNrClass() {
063 return nr_class;
064 }
065
066 /**
067 * @return number of features
068 */
069 public int getNrFeature() {
070 return nr_feature;
071 }
072
073 public int[] getLabels() {
074 return Util.copyOf(labels, nr_class);
075 }
076
077 /**
078 * The nr_feature*nr_class array w gives feature weights. We use one
079 * against the rest for multi-class classification, so each feature
080 * index corresponds to nr_class weight values. Weights are
081 * organized in the following way
082 *
083 * <pre>
084 * +------------------+------------------+------------+
085 * | nr_class weights | nr_class weights | ...
086 * | for 1st feature | for 2nd feature |
087 * +------------------+------------------+------------+
088 * </pre>
089 *
090 * If bias >= 0, x becomes [x; bias]. The number of features is
091 * increased by one, so w is a (nr_feature+1)*nr_class array. The
092 * value of bias is stored in the variable bias.
093 * @see #getBias()
094 * @return a <b>copy of</b> the feature weight array as described
095 */
096 // public double[] getFeatureWeights() {
097 // return Util.copyOf(w, w.length);
098 // }
099
100 /**
101 * @return true for logistic regression solvers
102 */
103 public boolean isProbabilityModel() {
104 return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
105 }
106
107 public double getBias() {
108 return bias;
109 }
110
111 public int[] predict(MaltFeatureNode[] x) {
112 final double[] dec_values = new double[nr_class];
113 final int n = (bias >= 0)?nr_feature + 1:nr_feature;
114 final int nr_w = (nr_class == 2 && solverType != SolverType.MCSVM_CS)?1:nr_class;
115 final int xlen = x.length;
116 int i;
117 for (i = 0; i < nr_w; i++) {
118 dec_values[i] = 0;
119 }
120
121 for (i=0; i < xlen; i++) {
122 if (x[i].index <= n) {
123 int t = (x[i].index - 1);
124 if (w[t] != null) {
125 for (int j = 0; j < w[t].length; j++) {
126 dec_values[j] += w[t][j] * x[i].value;
127 }
128 }
129 }
130 }
131
132 final int[] predictionList = Util.copyOf(labels, nr_class);
133 double tmpDec;
134 int tmpObj;
135 int lagest;
136 final int nc = nr_class-1;
137 for (i=0; i < nc; i++) {
138 lagest = i;
139 for (int j=i; j < nr_class; j++) {
140 if (dec_values[j] > dec_values[lagest]) {
141 lagest = j;
142 }
143 }
144 tmpDec = dec_values[lagest];
145 dec_values[lagest] = dec_values[i];
146 dec_values[i] = tmpDec;
147 tmpObj = predictionList[lagest];
148 predictionList[lagest] = predictionList[i];
149 predictionList[i] = tmpObj;
150 }
151 return predictionList;
152 }
153
154 private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
155 is.defaultReadObject();
156 }
157
158 private void writeObject(ObjectOutputStream os) throws IOException {
159 os.defaultWriteObject();
160 }
161
162 private void loadModel(Reader inputReader) throws IOException {
163 labels = null;
164 Pattern whitespace = Pattern.compile("\\s+");
165 BufferedReader reader = null;
166 if (inputReader instanceof BufferedReader) {
167 reader = (BufferedReader)inputReader;
168 } else {
169 reader = new BufferedReader(inputReader);
170 }
171
172 try {
173 String line = null;
174 while ((line = reader.readLine()) != null) {
175 String[] split = whitespace.split(line);
176 if (split[0].equals("solver_type")) {
177 SolverType solver = SolverType.valueOf(split[1]);
178 if (solver == null) {
179 throw new RuntimeException("unknown solver type");
180 }
181 solverType = solver;
182 } else if (split[0].equals("nr_class")) {
183 nr_class = Util.atoi(split[1]);
184 Integer.parseInt(split[1]);
185 } else if (split[0].equals("nr_feature")) {
186 nr_feature = Util.atoi(split[1]);
187 } else if (split[0].equals("bias")) {
188 bias = Util.atof(split[1]);
189 } else if (split[0].equals("w")) {
190 break;
191 } else if (split[0].equals("label")) {
192 labels = new int[nr_class];
193 for (int i = 0; i < nr_class; i++) {
194 labels[i] = Util.atoi(split[i + 1]);
195 }
196 } else {
197 throw new RuntimeException("unknown text in model file: [" + line + "]");
198 }
199 }
200
201 int w_size = nr_feature;
202 if (bias >= 0) w_size++;
203
204 int nr_w = nr_class;
205 if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1;
206 w = new double[w_size][nr_w];
207 int[] buffer = new int[128];
208
209 for (int i = 0; i < w_size; i++) {
210 for (int j = 0; j < nr_w; j++) {
211 int b = 0;
212 while (true) {
213 int ch = reader.read();
214 if (ch == -1) {
215 throw new EOFException("unexpected EOF");
216 }
217 if (ch == ' ') {
218 w[i][j] = Util.atof(new String(buffer, 0, b));
219 break;
220 } else {
221 buffer[b++] = ch;
222 }
223 }
224 }
225 }
226 }
227 finally {
228 Util.closeQuietly(reader);
229 }
230 }
231
232 public int hashCode() {
233 final int prime = 31;
234 long temp = Double.doubleToLongBits(bias);
235 int result = prime * 1 + (int)(temp ^ (temp >>> 32));
236 result = prime * result + Arrays.hashCode(labels);
237 result = prime * result + nr_class;
238 result = prime * result + nr_feature;
239 result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
240 for (int i = 0; i < w.length; i++) {
241 result = prime * result + Arrays.hashCode(w[i]);
242 }
243 return result;
244 }
245
246 public boolean equals(Object obj) {
247 if (this == obj) return true;
248 if (obj == null) return false;
249 if (getClass() != obj.getClass()) return false;
250 MaltLiblinearModel other = (MaltLiblinearModel)obj;
251 if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
252 if (!Arrays.equals(labels, other.labels)) return false;
253 if (nr_class != other.nr_class) return false;
254 if (nr_feature != other.nr_feature) return false;
255 if (solverType == null) {
256 if (other.solverType != null) return false;
257 } else if (!solverType.equals(other.solverType)) return false;
258 for (int i = 0; i < w.length; i++) {
259 if (other.w.length <= i) return false;
260 if (!Util.equals(w[i], other.w[i])) return false;
261 }
262 return true;
263 }
264
265 public String toString() {
266 final StringBuilder sb = new StringBuilder("Model");
267 sb.append(" bias=").append(bias);
268 sb.append(" nr_class=").append(nr_class);
269 sb.append(" nr_feature=").append(nr_feature);
270 sb.append(" solverType=").append(solverType);
271 return sb.toString();
272 }
273 }