package au.gov.amsa.spark.ais;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
public class AnchoredTrainerMain {

    public static void main(String[] args) throws IOException {

        SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");

        // run locally with one worker thread per available core
        sparkConf.setMaster("local[" + Runtime.getRuntime().availableProcessors() + "]");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        // path to the labelled vessel fixes in LibSVM format
        String datapath = "/media/an/fixes.libsvm";
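        // LibSVM format is a class label followed by 1-based index:value pairs,
        // one record per line. An illustrative (made-up) line for the eight
        // features below might look like:
        //   2 1:-33.85 2:151.21 3:0.1 4:5.2 5:0.3 6:0.05 7:0.2 8:0.04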

        // feature names (in column order) and class labels, used later to make
        // the learned tree human-readable
        List<String> names = Arrays.asList("lat", "lon", "speedKnots", "courseHeadingDiff",
                "preEffectiveSpeedKnots", "preError", "postEffectiveSpeedKnots", "postError");
        List<String> classifications = Arrays.asList("other", "moored", "anchored");

        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

        // split the data into training (70%) and test (30%) sets
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] { 0.7, 0.3 });
        JavaRDD<LabeledPoint> trainingData = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];
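        // Note: randomSplit is non-deterministic as called here; if reproducible
        // runs matter, the overload taking a seed could be used instead, e.g.
        // data.randomSplit(new double[] { 0.7, 0.3 }, 11L) with an arbitrary seed.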

        // decision tree training parameters; an empty categoricalFeaturesInfo
        // map declares all features continuous
        int numClasses = classifications.size();
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        String impurity = "gini";
        int maxDepth = 8;
        int maxBins = 32;

        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
                numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
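
        // If any feature were categorical it would be declared in the map above
        // as featureIndex -> numberOfCategories; for example the (hypothetical)
        // call categoricalFeaturesInfo.put(3, 4) would mark feature 3 as
        // categorical with 4 distinct values.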

        // test error = fraction of test records whose predicted class differs
        // from the actual class
        double testErr = (double) testData
                .map(toPredictionAndActual(model))
                .filter(predictionWrong())
                .count()
                / testData.count();
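
        // For richer evaluation (per-class precision/recall, confusion matrix),
        // org.apache.spark.mllib.evaluation.MulticlassMetrics could be applied
        // to the prediction/label pairs instead of this single error rate.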

        // persist the model, then reload it to confirm the save round-trips
        String modelPath = "target/myModelPath";
        FileUtils.deleteDirectory(new File(modelPath));
        model.save(sc.sc(), modelPath);
        DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), modelPath);
        System.out.println("Reloaded model with depth " + sameModel.depth());

        System.out.println("Test Error: " + testErr);

        // substitute feature names and class labels into the debug dump of the tree
        String s = useNames(model.toDebugString(), names, classifications);

        System.out.println("Learned classification tree model:\n" + s);

        try (FileOutputStream fos = new FileOutputStream("target/model.txt")) {
            fos.write(("Test Error: " + testErr + "\n").getBytes(StandardCharsets.UTF_8));
            fos.write(s.getBytes(StandardCharsets.UTF_8));
        }

        sc.stop();
    }
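
    // Example of the renaming performed by useNames, given the lists in main:
    //   "If (feature 2 <= 0.5)"  becomes  "If (speedKnots <= 0.5)"
    //   "Predict: 2.0"           becomes  "Predict: anchored"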
    private static String useNames(String s, List<String> names, List<String> classifications) {
        // iterate from the highest index down so that e.g. "feature 1" is
        // never matched inside "feature 10"
        String result = s;
        for (int i = names.size() - 1; i >= 0; i--) {
            result = result.replace("feature " + i, names.get(i));
        }
        for (int i = classifications.size() - 1; i >= 0; i--) {
            result = result.replace("Predict: " + i + ".0", "Predict: " + classifications.get(i));
        }
        return result;
    }

    // true when the model's prediction differs from the labelled class; MLlib
    // class labels are exact small doubles (0.0, 1.0, 2.0) so != is safe here
    private static Function<PredictionAndActual, Boolean> predictionWrong() {
        return p -> p.prediction != p.actual;
    }

    // pairs the model's prediction for each point with its actual label
    private static Function<LabeledPoint, PredictionAndActual> toPredictionAndActual(
            final DecisionTreeModel model) {
        return p -> new PredictionAndActual(model.predict(p.features()), p.label());
    }

    // simple value holder for a prediction paired with its actual label;
    // Serializable so Spark can ship instances between JVMs if the RDD of
    // pairs is ever cached or shuffled
    private static class PredictionAndActual implements Serializable {
        private static final long serialVersionUID = 1L;

        final double prediction;
        final double actual;

        PredictionAndActual(double prediction, double actual) {
            this.prediction = prediction;
            this.actual = actual;
        }
    }
}