View Javadoc
1   package au.gov.amsa.spark.ais;
2   
3   import java.io.File;
4   import java.io.FileOutputStream;
5   import java.io.IOException;
6   import java.util.Arrays;
7   import java.util.HashMap;
8   import java.util.List;
9   import java.util.Map;
10  
11  import org.apache.commons.io.FileUtils;
12  import org.apache.spark.SparkConf;
13  import org.apache.spark.api.java.JavaRDD;
14  import org.apache.spark.api.java.JavaSparkContext;
15  import org.apache.spark.api.java.function.Function;
16  import org.apache.spark.mllib.regression.LabeledPoint;
17  import org.apache.spark.mllib.tree.DecisionTree;
18  import org.apache.spark.mllib.tree.model.DecisionTreeModel;
19  import org.apache.spark.mllib.util.MLUtils;
20  
21  public class AnchoredTrainerMain {
22  
23      public static void main(String[] args) throws IOException {
24  
25          SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
26          // just run this locally
27          sparkConf.setMaster("local[" + Runtime.getRuntime().availableProcessors() + "]");
28          JavaSparkContext sc = new JavaSparkContext(sparkConf);
29  
30          // Load and parse the data file.
31          String datapath = "/media/an/fixes.libsvm";
32  
33          // the feature names are substituted into the model debugString later to
34          // make it readable
35          List<String> names = Arrays.asList("lat", "lon", "speedKnots", "courseHeadingDiff",
36                  "preEffectiveSpeedKnots", "preError", "postEffectiveSpeedKnots", "postError");
37          List<String> classifications = Arrays.asList("other", "moored", "anchored");
38  
39          JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
40          // Split the data into training and test sets (30% held out for testing)
41          JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] { 0.7, 0.3 });
42          JavaRDD<LabeledPoint> trainingData = splits[0];
43          JavaRDD<LabeledPoint> testData = splits[1];
44  
45          // Set parameters.
46          // Empty categoricalFeaturesInfo indicates all features are continuous.
47          Integer numClassifications = classifications.size();
48          Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
49          String impurity = "gini";
50          Integer maxDepth = 8;
51          Integer maxBins = 32;
52  
53          // Train a DecisionTree model for classification.
54          final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
55                  numClassifications, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
56  
57          // Evaluate model on test instances and compute test error
58          Double testErr = (double) testData
59          // pair up actual and predicted classification numerical representation
60                  .map(toPredictionAndActual(model))
61                  // get the ones that don't match
62                  .filter(predictionWrong())
63                  // count them
64                  .count()
65          // divide by total count to get ratio failing test
66                  / testData.count();
67  
68          // Save and load model to demo possible usage in prediction mode
69          String modelPath = "target/myModelPath";
70          FileUtils.deleteDirectory(new File(modelPath));
71          model.save(sc.sc(), modelPath);
72          DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), modelPath);
73  
74          System.out.println("Test Error: " + testErr);
75  
76          String s = useNames(model.toDebugString(), names, classifications);
77  
78          System.out.println("Learned classification tree model:\n" + s);
79  
80          FileOutputStream fos = new FileOutputStream("target/model.txt");
81          fos.write(("Test Error: " + testErr + "\n").getBytes());
82          fos.write(s.getBytes());
83          fos.close();
84  
85      }
86  
87      private static String useNames(String s, List<String> names, List<String> features) {
88          String result = s;
89          for (int i = names.size() - 1; i >= 0; i--) {
90              result = result.replace("feature " + i, names.get(i));
91          }
92  
93          for (int i = features.size() - 1; i >= 0; i--) {
94              result = result.replace("Predict: " + i + ".0", "Predict: " + features.get(i));
95          }
96          return result;
97      }
98  
99      private static Function<PredictionAndActual, Boolean> predictionWrong() {
100         return p -> p.prediction != p.actual;
101     }
102 
103     private static Function<LabeledPoint, PredictionAndActual> toPredictionAndActual(
104             final DecisionTreeModel model) {
105         return p -> new PredictionAndActual(model.predict(p.features()), p.label());
106     }
107 
108     private static class PredictionAndActual {
109         final double prediction;
110         final double actual;
111 
112         PredictionAndActual(double prediction, double actual) {
113             this.prediction = prediction;
114             this.actual = actual;
115         }
116 
117     }
118 }