001package org.maltparser.ml.lib;
002
003import java.io.BufferedOutputStream;
004import java.io.BufferedReader;
005import java.io.BufferedWriter;
006import java.io.File;
007import java.io.FileNotFoundException;
008import java.io.FileOutputStream;
009import java.io.IOException;
010import java.io.InputStream;
011import java.io.InputStreamReader;
012import java.io.ObjectInputStream;
013import java.io.ObjectOutputStream;
014import java.io.OutputStream;
015
016import java.io.OutputStreamWriter;
017import java.util.ArrayList;
018
019import java.util.LinkedHashMap;
020import java.util.Set;
021import java.util.regex.Pattern;
022import java.util.regex.PatternSyntaxException;
023
024
025import org.maltparser.core.exception.MaltChainedException;
026import org.maltparser.core.feature.FeatureVector;
027import org.maltparser.core.feature.function.FeatureFunction;
028import org.maltparser.core.feature.value.FeatureValue;
029import org.maltparser.core.feature.value.MultipleFeatureValue;
030import org.maltparser.core.feature.value.SingleFeatureValue;
031import org.maltparser.core.syntaxgraph.DependencyStructure;
032import org.maltparser.ml.LearningMethod;
033import org.maltparser.ml.lib.FeatureMap;
034import org.maltparser.ml.lib.FeatureList;
035import org.maltparser.ml.lib.MaltLibModel;
036import org.maltparser.ml.lib.MaltFeatureNode;
037import org.maltparser.ml.lib.LibException;
038import org.maltparser.parser.DependencyParserConfig;
039import org.maltparser.parser.guide.instance.InstanceModel;
040import org.maltparser.parser.history.action.SingleDecision;
041
042public abstract class Lib implements LearningMethod {
043        public enum Verbostity {
044                SILENT, ERROR, ALL
045        }
046        protected final Verbostity verbosity;
047        private final InstanceModel owner;
048        private final int learnerMode;
049        private final String name;
050        protected final FeatureMap featureMap;
051        private final boolean excludeNullValues;
052        private BufferedWriter instanceOutput = null; 
053        protected MaltLibModel model = null;
054        
055        private int numberOfInstances;
056        
057        /**
058         * Constructs a Lib learner.
059         * 
060         * @param owner the guide model owner
061         * @param learnerMode the mode of the learner BATCH or CLASSIFY
062         */
063        public Lib(InstanceModel owner, Integer learnerMode, String learningMethodName) throws MaltChainedException {
064                this.owner = owner;
065                this.learnerMode = learnerMode.intValue();
066                this.name = learningMethodName;
067                if (getConfiguration().getOptionValue("lib", "verbosity") != null) {
068                        this.verbosity = Verbostity.valueOf(getConfiguration().getOptionValue("lib", "verbosity").toString().toUpperCase());
069                } else {
070                        this.verbosity = Verbostity.SILENT;
071                }
072                setNumberOfInstances(0);
073                if (getConfiguration().getOptionValue("singlemalt", "null_value") != null && getConfiguration().getOptionValue("singlemalt", "null_value").toString().equalsIgnoreCase("none")) {
074                        excludeNullValues = true;
075                } else {
076                        excludeNullValues = false;
077                }
078
079                if (learnerMode == BATCH) {
080                        featureMap = new FeatureMap();
081                        instanceOutput = new BufferedWriter(getInstanceOutputStreamWriter(".ins"));
082                } else if (learnerMode == CLASSIFY) {
083                        featureMap = (FeatureMap)getConfigFileEntryObject(".map");
084                } else {
085                        featureMap = null;
086                }
087        }
088        
089        public void addInstance(SingleDecision decision, FeatureVector featureVector) throws MaltChainedException {
090                if (featureVector == null) {
091                        throw new LibException("The feature vector cannot be found");
092                } else if (decision == null) {
093                        throw new LibException("The decision cannot be found");
094                }       
095                
096                try {
097                        final StringBuilder sb = new StringBuilder();
098                        sb.append(decision.getDecisionCode()+"\t");
099                        final int n = featureVector.size();
100                        for (int i = 0; i < n; i++) {
101                                FeatureValue featureValue = featureVector.getFeatureValue(i);
102                                if (featureValue == null || (excludeNullValues == true && featureValue.isNullValue())) {
103                                        sb.append("-1");
104                                } else {
105                                        if (!featureValue.isMultiple()) {
106                                                SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue;
107                                                if (singleFeatureValue.getValue() == 1) {
108                                                        sb.append(singleFeatureValue.getIndexCode());
109                                                } else if (singleFeatureValue.getValue() == 0) {
110                                                        sb.append("-1");
111                                                } else {
112                                                        sb.append(singleFeatureValue.getIndexCode());
113                                                        sb.append(":");
114                                                        sb.append(singleFeatureValue.getValue());
115                                                }
116                                        } else { //if (featureValue instanceof MultipleFeatureValue) {
117                                                Set<Integer> values = ((MultipleFeatureValue)featureValue).getCodes();
118                                                int j=0;
119                                                for (Integer value : values) {
120                                                        sb.append(value.toString());
121                                                        if (j != values.size()-1) {
122                                                                sb.append("|");
123                                                        }
124                                                        j++;
125                                                }
126                                        }
127//                                      else {
128//                                              throw new LibException("Don't recognize the type of feature value: "+featureValue.getClass());
129//                                      }
130                                }
131                                sb.append('\t');
132                        }
133                        sb.append('\n');
134                        instanceOutput.write(sb.toString());
135                        instanceOutput.flush();
136                        increaseNumberOfInstances();
137//                      sb.setLength(0);
138                } catch (IOException e) {
139                        throw new LibException("The learner cannot write to the instance file. ", e);
140                }
141        }
142
143        public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { }
144
145        public void moveAllInstances(LearningMethod method,
146                        FeatureFunction divideFeature,
147                        ArrayList<Integer> divideFeatureIndexVector)
148                        throws MaltChainedException { 
149                if (method == null) {
150                        throw new LibException("The learning method cannot be found. ");
151                } else if (divideFeature == null) {
152                        throw new LibException("The divide feature cannot be found. ");
153                } 
154                
155                try {
156                        final BufferedReader in = new BufferedReader(getInstanceInputStreamReader(".ins"));
157                        final BufferedWriter out = method.getInstanceWriter();
158                        final StringBuilder sb = new StringBuilder(6);
159                        int l = in.read();
160                        char c;
161                        int j = 0;
162        
163                        while(true) {
164                                if (l == -1) {
165                                        sb.setLength(0);
166                                        break;
167                                }
168                                c = (char)l; 
169                                l = in.read();
170                                if (c == '\t') {
171                                        if (divideFeatureIndexVector.contains(j-1)) {
172                                                out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
173                                                out.write('\t');
174                                        }
175                                        out.write(sb.toString());
176                                        j++;
177                                        out.write('\t');
178                                        sb.setLength(0);
179                                } else if (c == '\n') {
180                                        out.write(sb.toString());
181                                        if (divideFeatureIndexVector.contains(j-1)) {
182                                                out.write('\t');
183                                                out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
184                                        }
185                                        out.write('\n');
186                                        sb.setLength(0);
187                                        method.increaseNumberOfInstances();
188                                        this.decreaseNumberOfInstances();
189                                        j = 0;
190                                } else {
191                                        sb.append(c);
192                                }
193                        }       
194                        in.close();
195                        getFile(".ins").delete();
196                        out.flush();
197                } catch (SecurityException e) {
198                        throw new LibException("The learner cannot remove the instance file. ", e);
199                } catch (NullPointerException  e) {
200                        throw new LibException("The instance file cannot be found. ", e);
201                } catch (FileNotFoundException e) {
202                        throw new LibException("The instance file cannot be found. ", e);
203                } catch (IOException e) {
204                        throw new LibException("The learner read from the instance file. ", e);
205                }
206        }
207
208        public void noMoreInstances() throws MaltChainedException { 
209                closeInstanceWriter();
210        }
211        
212        public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
213                final FeatureList featureList = new FeatureList();
214                final int size = featureVector.size();
215                for (int i = 1; i <= size; i++) {
216                        final FeatureValue featureValue = featureVector.getFeatureValue(i-1);   
217                        if (featureValue != null && !(excludeNullValues == true && featureValue.isNullValue())) {
218                                if (!featureValue.isMultiple()) {
219                                        SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue;
220                                        final int index = featureMap.getIndex(i, singleFeatureValue.getIndexCode());
221                                        if (index != -1 && singleFeatureValue.getValue() != 0) {
222                                                featureList.add(index,singleFeatureValue.getValue());
223                                        }
224                                } 
225                                else {
226                                        for (Integer value : ((MultipleFeatureValue)featureValue).getCodes()) {
227                                                final int v = featureMap.getIndex(i, value);
228                                                if (v != -1) {
229                                                        featureList.add(v,1);
230                                                }
231                                        }
232                                } 
233                        }
234                }
235                try {
236                        decision.getKBestList().addList(model.predict(featureList.toArray()));
237                } catch (OutOfMemoryError e) {
238                        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
239                }
240                return true;
241        }
242                
243//      protected abstract int[] prediction(FeatureList featureList) throws MaltChainedException;
244        
245        public void train() throws MaltChainedException { 
246                if (owner == null) {
247                        throw new LibException("The parent guide model cannot be found. ");
248                }
249                String pathExternalTrain = null;
250                if (!getConfiguration().getOptionValue("lib", "external").toString().equals("")) {
251                        String path = getConfiguration().getOptionValue("lib", "external").toString(); 
252                        try {
253                                if (!new File(path).exists()) {
254                                        throw new LibException("The path to the external  trainer 'svm-train' is wrong.");
255                                }
256                                if (new File(path).isDirectory()) {
257                                        throw new LibException("The option --lib-external points to a directory, the path should point at the 'train' file or the 'train.exe' file in the libsvm or the liblinear package");
258                                }
259                                if (!(path.endsWith("train") ||path.endsWith("train.exe"))) {
260                                        throw new LibException("The option --lib-external does not specify the path to 'train' file or the 'train.exe' file in the libsvm or the liblinear package. ");
261                                }
262                                pathExternalTrain = path;
263                        } catch (SecurityException e) {
264                                throw new LibException("Access denied to the file specified by the option --lib-external. ", e);
265                        }
266                }
267                LinkedHashMap<String, String> libOptions = getDefaultLibOptions();
268                parseParameters(getConfiguration().getOptionValue("lib", "options").toString(), libOptions, getAllowedLibOptionFlags());
269                
270//              long startTime = System.currentTimeMillis();
271                
272//              if (configLogger.isInfoEnabled()) {
273//                      configLogger.info("\nStart training\n");
274//              }
275                if (pathExternalTrain != null) {
276                        trainExternal(pathExternalTrain, libOptions);
277                } else {
278                        trainInternal(libOptions);
279                }
280//              long elapsed = System.currentTimeMillis() - startTime;
281//              if (configLogger.isInfoEnabled()) {
282//                      configLogger.info("Time 1: " +new Formatter().format("%02d:%02d:%02d", elapsed/3600000, elapsed%3600000/60000, elapsed%60000/1000)+" ("+elapsed+" ms)\n");
283//              }
284                try {
285//                      if (configLogger.isInfoEnabled()) {
286//                              configLogger.info("\nSaving feature map "+getFile(".map").getName()+"\n");
287//                      }
288                        saveFeatureMap(new BufferedOutputStream(new FileOutputStream(getFile(".map").getAbsolutePath())), featureMap);
289                } catch (FileNotFoundException e) {
290                        throw new LibException("The learner cannot save the feature map file '"+getFile(".map").getAbsolutePath()+"'. ", e);
291                }
292//              elapsed = System.currentTimeMillis() - startTime;
293//              if (configLogger.isInfoEnabled()) {
294//                      configLogger.info("Time 2: " +new Formatter().format("%02d:%02d:%02d", elapsed/3600000, elapsed%3600000/60000, elapsed%60000/1000)+" ("+elapsed+" ms)\n");
295//              }
296        }
297        protected abstract void trainExternal(String pathExternalTrain, LinkedHashMap<String, String> libOptions) throws MaltChainedException;
298        protected abstract void trainInternal(LinkedHashMap<String, String> libOptions) throws MaltChainedException;
299        
300        public void terminate() throws MaltChainedException { 
301                closeInstanceWriter();
302//              owner = null;
303//              model = null;
304        }
305
306        public BufferedWriter getInstanceWriter() {
307                return instanceOutput;
308        }
309        
310        protected void closeInstanceWriter() throws MaltChainedException {
311                try {
312                        if (instanceOutput != null) {
313                                instanceOutput.flush();
314                                instanceOutput.close();
315                                instanceOutput = null;
316                        }
317                } catch (IOException e) {
318                        throw new LibException("The learner cannot close the instance file. ", e);
319                }
320        }
321        
322        public InstanceModel getOwner() {
323                return owner;
324        }
325        
326        public int getLearnerMode() {
327                return learnerMode;
328        }
329        
330        public String getLearningMethodName() {
331                return name;
332        }
333        
334        /**
335         * Returns the current configuration
336         * 
337         * @return the current configuration
338         * @throws MaltChainedException
339         */
340        public DependencyParserConfig getConfiguration() throws MaltChainedException {
341                return owner.getGuide().getConfiguration();
342        }
343        
344        public int getNumberOfInstances() throws MaltChainedException {
345                if(numberOfInstances!=0)
346                        return numberOfInstances;
347                else{
348                        BufferedReader reader = new BufferedReader( getInstanceInputStreamReader(".ins"));
349                        try {
350                                while(reader.readLine()!=null){
351                                        numberOfInstances++;
352                                        owner.increaseFrequency();
353                                }
354                                reader.close();
355                        } catch (IOException e) {
356                                throw new MaltChainedException("No instances found in file",e);
357                        }
358                        return numberOfInstances;
359                }
360        }
361
362        public void increaseNumberOfInstances() {
363                numberOfInstances++;
364                owner.increaseFrequency();
365        }
366        
367        public void decreaseNumberOfInstances() {
368                numberOfInstances--;
369                owner.decreaseFrequency();
370        }
371        
372        protected void setNumberOfInstances(int numberOfInstances) {
373                this.numberOfInstances = 0;
374        }
375        
376        protected OutputStreamWriter getInstanceOutputStreamWriter(String suffix) throws MaltChainedException {
377                return getConfiguration().getAppendOutputStreamWriter(owner.getModelName()+getLearningMethodName()+suffix);
378        }
379        
380        protected InputStreamReader getInstanceInputStreamReader(String suffix) throws MaltChainedException {
381                return getConfiguration().getInputStreamReader(owner.getModelName()+getLearningMethodName()+suffix);
382        }
383        
384        protected InputStream getInputStreamFromConfigFileEntry(String suffix) throws MaltChainedException {
385                return getConfiguration().getInputStreamFromConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix);
386        }
387        
388        protected File getFile(String suffix) throws MaltChainedException {
389                return getConfiguration().getFile(owner.getModelName()+getLearningMethodName()+suffix);
390        }
391        
392        protected Object getConfigFileEntryObject(String suffix) throws MaltChainedException {
393                return getConfiguration().getConfigFileEntryObject(owner.getModelName()+getLearningMethodName()+suffix);
394        }
395        
396        public String[] getLibParamStringArray(LinkedHashMap<String, String> libOptions) {
397                final ArrayList<String> params = new ArrayList<String>();
398
399                for (String key : libOptions.keySet()) {
400                        params.add("-"+key); params.add(libOptions.get(key));
401                }
402                return params.toArray(new String[params.size()]);
403        }
404        
405        public abstract LinkedHashMap<String, String> getDefaultLibOptions();
406        public abstract String getAllowedLibOptionFlags();
407        
408        public void parseParameters(String paramstring, LinkedHashMap<String, String> libOptions, String allowedLibOptionFlags) throws MaltChainedException {
409                if (paramstring == null) {
410                        return;
411                }
412                final String[] argv;
413                try {
414                        argv = paramstring.split("[_\\p{Blank}]");
415                } catch (PatternSyntaxException e) {
416                        throw new LibException("Could not split the parameter string '"+paramstring+"'. ", e);
417                }
418                for (int i=0; i < argv.length-1; i++) {
419                        if(argv[i].charAt(0) != '-') {
420                                throw new LibException("The argument flag should start with the following character '-', not with "+argv[i].charAt(0));
421                        }
422                        if(++i>=argv.length) {
423                                throw new LibException("The last argument does not have any value. ");
424                        }
425                        try {
426                                int index = allowedLibOptionFlags.indexOf(argv[i-1].charAt(1));
427                                if (index != -1) {
428                                        libOptions.put(Character.toString(argv[i-1].charAt(1)), argv[i]);
429                                } else {
430                                        throw new LibException("Unknown learner parameter: '"+argv[i-1]+"' with value '"+argv[i]+"'. ");                
431                                }
432                        } catch (ArrayIndexOutOfBoundsException e) {
433                                throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e);
434                        } catch (NumberFormatException e) {
435                                throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e);     
436                        } catch (NullPointerException e) {
437                                throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e);     
438                        }
439                }
440        }
441        
442        protected void finalize() throws Throwable {
443                try {
444                        closeInstanceWriter();
445                } finally {
446                        super.finalize();
447                }
448        }
449        
450        public String toString() {
451                final StringBuffer sb = new StringBuffer();
452                sb.append('\n');
453                sb.append(getLearningMethodName());
454                sb.append(" INTERFACE\n");
455                try {
456                        sb.append(getConfiguration().getOptionValue("lib", "options").toString());
457                } catch (MaltChainedException e) {}
458                return sb.toString();
459        }
460
461        protected int binariesInstance(String line, FeatureList featureList) throws MaltChainedException {
462                final Pattern tabPattern = Pattern.compile("\t");
463                final Pattern pipePattern = Pattern.compile("\\|");
464                int y = -1; 
465                featureList.clear();
466                try {   
467                        String[] columns = tabPattern.split(line);
468
469                        if (columns.length == 0) {
470                                return -1;
471                        }
472                        try {
473                                y = Integer.parseInt(columns[0]);
474                        } catch (NumberFormatException e) {
475                                throw new LibException("The instance file contain a non-integer value '"+columns[0]+"'", e);
476                        }
477                        for(int j = 1; j < columns.length; j++) {
478                                final String[] items = pipePattern.split(columns[j]);
479                                for (int k = 0; k < items.length; k++) {
480                                        try {
481                                                int colon = items[k].indexOf(':');
482                                                if (colon == -1) {
483                                                        if (Integer.parseInt(items[k]) != -1) {
484                                                                int v = featureMap.addIndex(j, Integer.parseInt(items[k]));
485                                                                if (v != -1) {
486                                                                        featureList.add(v,1);
487                                                                }
488                                                        }
489                                                } else {
490                                                        int index = featureMap.addIndex(j, Integer.parseInt(items[k].substring(0,colon)));
491                                                        double value;
492                                                        if (items[k].substring(colon+1).indexOf('.') != -1) {
493                                                                value = Double.parseDouble(items[k].substring(colon+1));
494                                                        } else {
495                                                                value = Integer.parseInt(items[k].substring(colon+1));
496                                                        }
497                                                        featureList.add(index,value);
498                                                }
499                                        } catch (NumberFormatException e) {
500                                                throw new LibException("The instance file contain a non-numeric value '"+items[k]+"'", e);
501                                        }
502                                }
503                        }
504                } catch (ArrayIndexOutOfBoundsException e) {
505                        throw new LibException("Couln't read from the instance file. ", e);
506                }
507                return y;
508        }
509
510        protected void binariesInstances2SVMFileFormat(InputStreamReader isr, OutputStreamWriter osw) throws MaltChainedException {
511                try {
512                        final BufferedReader in = new BufferedReader(isr);
513                        final BufferedWriter out = new BufferedWriter(osw);
514                        final FeatureList featureSet = new FeatureList();
515                        while(true) {
516                                String line = in.readLine();
517                                if(line == null) break;
518                                int y = binariesInstance(line, featureSet);
519                                if (y == -1) {
520                                        continue;
521                                }
522                                out.write(Integer.toString(y));
523                                
524                        for (int k=0; k < featureSet.size(); k++) {
525                                MaltFeatureNode x = featureSet.get(k);
526                                        out.write(' ');
527                                        out.write(Integer.toString(x.getIndex()));
528                                        out.write(':');
529                                        out.write(Double.toString(x.getValue()));         
530                                }
531                                out.write('\n');
532                        }                       
533                        in.close();     
534                        out.close();
535                } catch (NumberFormatException e) {
536                        throw new LibException("The instance file contain a non-numeric value", e);
537                } catch (IOException e) {
538                        throw new LibException("Couldn't read from the instance file, when converting the Malt instances into LIBSVM/LIBLINEAR format. ", e);
539                }
540        }
541        
542        protected void saveFeatureMap(OutputStream os, FeatureMap map) throws MaltChainedException {
543                try {
544                    ObjectOutputStream output = new ObjectOutputStream(os);
545                try{
546                  output.writeObject(map);
547                }
548                finally{
549                  output.close();
550                }
551                } catch (IOException e) {
552                        throw new LibException("Save feature map error", e);
553                }
554        }
555
556        protected FeatureMap loadFeatureMap(InputStream is) throws MaltChainedException {
557                FeatureMap map = new FeatureMap();
558                try {
559                    ObjectInputStream input = new ObjectInputStream(is);
560                    try {
561                        map = (FeatureMap)input.readObject();
562                    } finally {
563                        input.close();
564                    }
565                } catch (ClassNotFoundException e) {
566                        throw new LibException("Load feature map error", e);
567                } catch (IOException e) {
568                        throw new LibException("Load feature map error", e);
569                }
570                return map;
571        }
572}