View Javadoc
1   package au.gov.amsa.spark.ais;
2   
3   import org.apache.spark.api.java.JavaSparkContext;
4   import org.apache.spark.mllib.linalg.DenseVector;
5   import org.apache.spark.mllib.linalg.Vector;
6   import org.apache.spark.mllib.tree.model.DecisionTreeModel;
7   
8   public class AnchoredPredictor {
9   
10      private DecisionTreeModel model;
11  
12      public AnchoredPredictor(JavaSparkContext sc) {
13          String dataPath = AnchoredPredictor.class.getResource("/anchoredOrMooredModel").toString();
14          model = DecisionTreeModel.load(sc.sc(), dataPath);
15      }
16  
17      public static enum Status {
18          OTHER, MOORED, ANCHORED;
19      }
20  
21      public Status predict(double lat, double lon, double speedKnots, double courseMinusHeading,
22              double preEffectiveSpeedKnots, double preError, double postEffectiveSpeedKnots,
23              double postError) {
24          Vector features = new DenseVector(new double[] { lat, lon, speedKnots, courseMinusHeading,
25                  preEffectiveSpeedKnots, preError, postEffectiveSpeedKnots, postError });
26          double prediction = model.predict(features);
27  
28          if (is(prediction, 1))
29              return Status.MOORED;
30          else if (is(prediction, 2))
31              return Status.ANCHORED;
32          else
33              return Status.OTHER;
34      }
35  
36      private static boolean is(double a, double b) {
37          return Math.abs(a - b) < 0.0001;
38      }
39  }