1 package au.gov.amsa.spark.ais;
2
3 import org.apache.spark.api.java.JavaSparkContext;
4 import org.apache.spark.mllib.linalg.DenseVector;
5 import org.apache.spark.mllib.linalg.Vector;
6 import org.apache.spark.mllib.tree.model.DecisionTreeModel;
7
8 public class AnchoredPredictor {
9
10 private DecisionTreeModel model;
11
12 public AnchoredPredictor(JavaSparkContext sc) {
13 String dataPath = AnchoredPredictor.class.getResource("/anchoredOrMooredModel").toString();
14 model = DecisionTreeModel.load(sc.sc(), dataPath);
15 }
16
17 public static enum Status {
18 OTHER, MOORED, ANCHORED;
19 }
20
21 public Status predict(double lat, double lon, double speedKnots, double courseMinusHeading,
22 double preEffectiveSpeedKnots, double preError, double postEffectiveSpeedKnots,
23 double postError) {
24 Vector features = new DenseVector(new double[] { lat, lon, speedKnots, courseMinusHeading,
25 preEffectiveSpeedKnots, preError, postEffectiveSpeedKnots, postError });
26 double prediction = model.predict(features);
27
28 if (is(prediction, 1))
29 return Status.MOORED;
30 else if (is(prediction, 2))
31 return Status.ANCHORED;
32 else
33 return Status.OTHER;
34 }
35
36 private static boolean is(double a, double b) {
37 return Math.abs(a - b) < 0.0001;
38 }
39 }