public class GradientBoostedTreesModel extends java.lang.Object implements Saveable
param: algo algorithm for the ensemble model, either Classification or Regression
param: trees tree ensembles
param: treeWeights tree ensemble weights
Constructor and Description |
---|
GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights) |
Modifier and Type | Method and Description |
---|---|
scala.Enumeration.Value | algo() |
protected scala.Enumeration.Value | combiningStrategy() |
static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> | computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss) :: DeveloperApi :: Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting. |
double[] | evaluateEachIteration(RDD<LabeledPoint> data, Loss loss) Method to compute error or loss for every iteration of gradient boosting. |
protected java.lang.String | formatVersion() Current version of model save/load format. |
static GradientBoostedTreesModel | load(SparkContext sc, java.lang.String path) |
int | numTrees() Get number of trees in ensemble. |
JavaRDD<java.lang.Double> | predict(JavaRDD<Vector> features) Java-friendly version of TreeEnsembleModel.predict(org.apache.spark.mllib.linalg.Vector). |
RDD<java.lang.Object> | predict(RDD<Vector> features) Predict values for the given data set. |
double | predict(Vector features) Predict values for a single data point using the trained model. |
void | save(SparkContext sc, java.lang.String path) Save this model to the given path. |
java.lang.String | toDebugString() Print the full model to a string. |
java.lang.String | toString() Print a summary of the model. |
int | totalNumNodes() Get total number of nodes, summed over all trees in the ensemble. |
DecisionTreeModel[] | trees() |
double[] | treeWeights() |
static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> | updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss) :: DeveloperApi :: Update a zipped predictionError RDD (as obtained with computeInitialPredictionAndError). |
public GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights)
public static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss)
data
- training data.
initTreeWeight
- learning rate assigned to the first tree.
initTree
- first DecisionTreeModel.
loss
- evaluation metric.
public static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss)
data
- training data.
predictionAndError
- predictionError RDD.
treeWeight
- learning rate.
tree
- tree used to update the prediction and error.
loss
- evaluation metric.
public static GradientBoostedTreesModel load(SparkContext sc, java.lang.String path)
sc
- Spark context used for loading model files.
path
- Path specifying the directory to which the model was saved.
public scala.Enumeration.Value algo()
public DecisionTreeModel[] trees()
public double[] treeWeights()
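The bookkeeping performed by computeInitialPredictionAndError and updatePredictionError (documented above) can be illustrated for a single data point. This is only a sketch under stated assumptions, not Spark's implementation: the class `PredictionErrorSketch` and its methods are hypothetical, a `DoubleUnaryOperator` over one scalar feature stands in for a DecisionTreeModel, and squared error stands in for the Loss argument. The real methods work on an RDD<LabeledPoint> and return an RDD of (prediction, error) pairs.

```java
import java.util.function.DoubleUnaryOperator;

/** Hypothetical single-point stand-in for the RDD-based helpers; not Spark code. */
final class PredictionErrorSketch {

    /** Squared error for one point (an assumed loss for this sketch). */
    static double squaredError(double prediction, double label) {
        double diff = label - prediction;
        return diff * diff;
    }

    /** First iteration: the prediction is the weighted output of the first tree. */
    static double[] initial(double feature, double label,
                            DoubleUnaryOperator initTree, double initTreeWeight) {
        double pred = initTreeWeight * initTree.applyAsDouble(feature);
        return new double[] {pred, squaredError(pred, label)};
    }

    /** Later iterations: add the weighted new tree to the running prediction. */
    static double[] update(double feature, double label, double[] predictionAndError,
                           DoubleUnaryOperator tree, double treeWeight) {
        double pred = predictionAndError[0] + treeWeight * tree.applyAsDouble(feature);
        return new double[] {pred, squaredError(pred, label)};
    }
}
```

Each returned array plays the role of one (prediction, error) tuple, so the output of `initial` can be fed to `update` once per additional tree.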
public void save(SparkContext sc, java.lang.String path)
Description copied from interface: Saveable
Save this model to the given path.
This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet formatted data to path/data/
The model may be loaded using Loader.load.
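The on-disk layout described above can be sketched as plain string handling. The class `SavedModelLayout` and its methods are hypothetical helpers written for illustration and are not part of Spark; they only derive the two locations named in the description from the `path` argument.

```java
/** Hypothetical helper, not part of Spark: derives the two save locations. */
final class SavedModelLayout {

    /** Joins a base path and a child directory name with a single separator. */
    static String child(String path, String name) {
        // Drop one trailing slash so the joined result has exactly one separator.
        String base = path.endsWith("/") ? path.substring(0, path.length() - 1) : path;
        return base + "/" + name + "/";
    }

    /** Location of the human-readable (JSON) model metadata. */
    static String metadataDir(String path) {
        return child(path, "metadata");
    }

    /** Location of the Parquet formatted model data. */
    static String dataDir(String path) {
        return child(path, "data");
    }
}
```

Passing the same `path` to load that was given to save is what lets both directories be found again.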
public double[] evaluateEachIteration(RDD<LabeledPoint> data, Loss loss)
data
- RDD of LabeledPoint
loss
- evaluation metric.
protected java.lang.String formatVersion()
Description copied from interface: Saveable
Current version of model save/load format.
Specified by: formatVersion in interface Saveable
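For evaluateEachIteration above, a small stand-alone sketch may help show what "error or loss for every iteration" means. It assumes, as I understand the method, that entry i of the result is the loss of the ensemble made of the first i+1 trees, averaged over the data. The class `EvaluateEachIterationSketch` is hypothetical, per-tree predictions are passed in as precomputed arrays instead of DecisionTreeModel objects, and squared error is an assumed stand-in for the Loss argument.

```java
/** Hypothetical stand-in for the RDD-based method; not Spark code. */
final class EvaluateEachIterationSketch {

    /**
     * treePredictions[i][j] is the prediction of tree i for data point j.
     * Returns, for each i, the mean squared error of the first i+1 trees.
     */
    static double[] evaluate(double[][] treePredictions, double[] treeWeights, double[] labels) {
        int numTrees = treePredictions.length;
        double[] running = new double[labels.length]; // cumulative prediction per point
        double[] lossPerIteration = new double[numTrees];
        for (int i = 0; i < numTrees; i++) {
            double total = 0.0;
            for (int j = 0; j < labels.length; j++) {
                // Add the weighted contribution of tree i, then score the point.
                running[j] += treeWeights[i] * treePredictions[i][j];
                double diff = labels[j] - running[j];
                total += diff * diff;
            }
            lossPerIteration[i] = total / labels.length;
        }
        return lossPerIteration;
    }
}
```

A curve like this is typically inspected to decide how many trees are worth keeping, since the loss on held-out data can stop improving before the last iteration.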
protected scala.Enumeration.Value combiningStrategy()
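combiningStrategy() carries no description on this page. My understanding, which should be treated as an assumption and not as documented behavior, is that a gradient-boosted ensemble combines its trees by a weighted sum, and that binary classification thresholds that sum at zero. The sketch below uses a hypothetical class `SumCombiningSketch` and takes per-tree predictions as plain arrays instead of evaluating DecisionTreeModel objects.

```java
/** Hypothetical illustration of sum-based combining; not Spark code. */
final class SumCombiningSketch {

    /** Weighted sum of the per-tree predictions for one data point. */
    static double weightedSum(double[] treePredictions, double[] treeWeights) {
        double sum = 0.0;
        for (int i = 0; i < treePredictions.length; i++) {
            sum += treeWeights[i] * treePredictions[i];
        }
        return sum;
    }

    /** Regression: the weighted sum is the prediction itself. */
    static double predictRegression(double[] treePredictions, double[] treeWeights) {
        return weightedSum(treePredictions, treeWeights);
    }

    /** Binary classification (assumed rule): label 1.0 if the sum is positive, else 0.0. */
    static double predictBinaryClass(double[] treePredictions, double[] treeWeights) {
        return weightedSum(treePredictions, treeWeights) > 0.0 ? 1.0 : 0.0;
    }
}
```

In this picture trees() and treeWeights() supply the two arrays, with the tree predictions computed per data point.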
public double predict(Vector features)
features
- array representing a single data point
public RDD<java.lang.Object> predict(RDD<Vector> features)
features
- RDD representing data points to be predicted
public JavaRDD<java.lang.Double> predict(JavaRDD<Vector> features)
Java-friendly version of TreeEnsembleModel.predict(org.apache.spark.mllib.linalg.Vector).
features
- (undocumented)
public java.lang.String toString()
Overrides: toString in class java.lang.Object
public java.lang.String toDebugString()
public int numTrees()
public int totalNumNodes()