public final class GBTClassificationModel extends PredictionModel<Vector,GBTClassificationModel> implements scala.Serializable
Gradient-Boosted Trees (GBTs) model for classification. It supports binary labels, as well as both continuous and categorical features.
Note: Multiclass labels are not currently supported.
param: _trees Decision trees in the ensemble.
param: _treeWeights Weights for the decision trees in the ensemble.

| Constructor and Description |
|---|
| GBTClassificationModel(java.lang.String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights): Construct a GBTClassificationModel. |

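
The two constructor arrays are parallel: weight i scales the output of tree i. For binary classification, a common boosting rule is to add up weight * tree output over all trees and return 1.0 when that sum (the margin) is positive, else 0.0. This sketch assumes that rule; we believe Spark 1.x's predict thresholds the weighted sum at zero in this way, but check the source for your version. GbtPredictSketch, Tree, margin and predict are hypothetical names, and the trees are plain lambdas standing in for DecisionTreeRegressionModel.

```java
/**
 * Hypothetical sketch of binary GBT prediction: a weighted sum of
 * regression-tree outputs, thresholded at zero. Not Spark code.
 */
public class GbtPredictSketch {

    /** Stand-in for a fitted regression tree: maps features to a real value. */
    public interface Tree {
        double predict(double[] features);
    }

    /** Raw margin: sum over i of weights[i] * trees[i](features). */
    public static double margin(Tree[] trees, double[] weights, double[] features) {
        if (trees.length != weights.length) {
            throw new IllegalArgumentException("trees and weights must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < trees.length; i++) {
            sum += weights[i] * trees[i].predict(features);
        }
        return sum;
    }

    /** Binary label: 1.0 if the margin is strictly positive, otherwise 0.0. */
    public static double predict(Tree[] trees, double[] weights, double[] features) {
        return margin(trees, weights, features) > 0.0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        // Two toy "trees": a stump on feature 0 and a constant correction.
        Tree[] trees = {
            f -> f[0] > 0.5 ? 1.0 : -1.0,
            f -> -0.2
        };
        double[] weights = {1.0, 0.5};
        System.out.println(predict(trees, weights, new double[] {0.9})); // prints 1.0 (margin about 0.9)
        System.out.println(predict(trees, weights, new double[] {0.1})); // prints 0.0 (margin about -1.1)
    }
}
```

Keeping the raw margin separate from the thresholded label makes it easy to inspect how far an input is from the decision boundary; only the sign of the margin decides the label.
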
| Modifier and Type | Method and Description |
|---|---|
| GBTClassificationModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
| Param<java.lang.String> | featuresCol(): Param for features column name. |
| static GBTClassificationModel | fromOld(GradientBoostedTreesModel oldModel, GBTClassifier parent, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeatures, int numFeatures): (private[ml]) Convert a model from the old API. |
| java.lang.String | getFeaturesCol() |
| java.lang.String | getLabelCol() |
| java.lang.String | getPredictionCol() |
| Param<java.lang.String> | labelCol(): Param for label column name. |
| int | numFeatures(): Returns the number of features the model was trained on. |
| protected double | predict(Vector features): Predict label for the given features. |
| Param<java.lang.String> | predictionCol(): Param for prediction column name. |
| java.lang.String | toString() |
| protected DataFrame | transformImpl(DataFrame dataset) |
| org.apache.spark.ml.tree.DecisionTreeModel[] | trees() |
| double[] | treeWeights() |
| java.lang.String | uid(): An immutable unique ID for the object and its derivatives. |
| StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType): Validates and transforms the input schema with the provided param map. |

Methods inherited from class PredictionModel: featuresDataType, setFeaturesCol, setPredictionCol, transform, transformSchema
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning

public GBTClassificationModel(java.lang.String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights)
Parameters:
uid - (undocumented)
_trees - Decision trees in the ensemble.
_treeWeights - Weights for the decision trees in the ensemble.

public static GBTClassificationModel fromOld(GradientBoostedTreesModel oldModel, GBTClassifier parent, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeatures, int numFeatures)
(private[ml]) Convert a model from the old API.
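
The GBTClassificationModel constructor receives the ensemble as two parallel arrays, so their lengths have to agree. The sketch below is a hypothetical, self-contained stand-in (ToyGbtEnsemble and its Tree interface are not Spark classes) showing the kind of argument checks and defensive copies such a constructor would reasonably perform; whether GBTClassificationModel performs exactly these checks is an assumption here.

```java
/**
 * Hypothetical stand-in for a tree ensemble stored as parallel arrays
 * (trees[i] is paired with treeWeights[i]). Not Spark code.
 */
public final class ToyGbtEnsemble {

    /** Stand-in for a fitted regression tree. */
    public interface Tree {
        double predict(double[] features);
    }

    private final String uid;
    private final Tree[] trees;
    private final double[] treeWeights;

    public ToyGbtEnsemble(String uid, Tree[] trees, double[] treeWeights) {
        if (trees.length == 0) {
            throw new IllegalArgumentException("the ensemble needs at least one tree");
        }
        if (trees.length != treeWeights.length) {
            throw new IllegalArgumentException("trees and treeWeights have different lengths: "
                + trees.length + " vs " + treeWeights.length);
        }
        this.uid = uid;
        // Defensive copies so later changes to the caller's arrays cannot alter the model.
        this.trees = trees.clone();
        this.treeWeights = treeWeights.clone();
    }

    public String uid() { return uid; }

    public int numTrees() { return trees.length; }

    /** Returns a copy, so callers cannot mutate the model's weights. */
    public double[] treeWeights() { return treeWeights.clone(); }

    public static void main(String[] args) {
        Tree[] trees = { f -> f[0], f -> 1.0 };
        ToyGbtEnsemble ok = new ToyGbtEnsemble("toy_1", trees, new double[] {1.0, 0.1});
        System.out.println(ok.numTrees()); // prints 2
        try {
            new ToyGbtEnsemble("toy_2", trees, new double[] {1.0});
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // trees and treeWeights have different lengths: 2 vs 1
        }
    }
}
```
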
public java.lang.String uid()
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable

public int numFeatures()
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,GBTClassificationModel>

public org.apache.spark.ml.tree.DecisionTreeModel[] trees()
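
uid() returns an ID that is fixed for the model and shared by objects derived from it (copy is documented to keep the same UID). A common way to mint such IDs is a readable prefix plus a random suffix. This sketch assumes that scheme: UidSketch, randomUid and the "gbtc" prefix are illustrative only. Spark's own helper for this, Identifiable.randomUID, builds IDs in a similar prefix_suffix form, but verify the exact format against your Spark version.

```java
import java.util.UUID;

/** Hypothetical helper that builds stage IDs of the form prefix_suffix. Not Spark code. */
public final class UidSketch {

    private UidSketch() {}

    /** Returns prefix + "_" + the last 12 hex digits of a random UUID. */
    public static String randomUid(String prefix) {
        String uuid = UUID.randomUUID().toString(); // 8-4-4-4-12 hex groups, 36 chars total
        return prefix + "_" + uuid.substring(uuid.length() - 12);
    }

    public static void main(String[] args) {
        String first = randomUid("gbtc");
        String second = randomUid("gbtc");
        System.out.println(first);                // gbtc_ followed by 12 random hex digits
        System.out.println(first.equals(second)); // false (with overwhelming probability)
    }
}
```

The ID is generated once and then only copied, never regenerated, which is what lets a copied model still be matched to the stage that produced it.
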
public double[] treeWeights()
protected DataFrame transformImpl(DataFrame dataset)
Overrides: transformImpl in class PredictionModel<Vector,GBTClassificationModel>

protected double predict(Vector features)
Predict label for the given features. This internal method is used to implement transform() and output predictionCol.
Specified by: predict in class PredictionModel<Vector,GBTClassificationModel>
Parameters:
features - (undocumented)

public GBTClassificationModel copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
Specified by: copy in interface Params
Specified by: copy in class Model<GBTClassificationModel>
Parameters:
extra - (undocumented)
See Also: defaultCopy()

public java.lang.String toString()
Specified by: toString in interface Identifiable
Overrides: toString in class java.lang.Object

public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether the schema is being validated during fitting (training)
featuresDataType - SQL DataType for FeaturesType; for example, VectorUDT for vector features.

public Param<java.lang.String> labelCol()
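
copy(ParamMap extra) is documented to return a copy with the same UID and some extra params. The sketch below illustrates that contract with a plain Map<String, Object> in place of Spark's ParamMap (a simplification: real params are keyed by Param objects, not strings). CopySketch is a hypothetical class, and letting entries in extra override existing values is an assumption consistent with the description.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical sketch of the copy(extra) contract: the copy keeps the same
 * UID, and values in extra override the current parameter values. Not Spark code.
 */
public final class CopySketch {
    private final String uid;
    private final Map<String, Object> params;

    public CopySketch(String uid, Map<String, Object> params) {
        this.uid = uid;
        this.params = new HashMap<>(params); // private copy: later changes to the caller's map do not leak in
    }

    public String uid() { return uid; }

    public Object get(String name) { return params.get(name); }

    /** Returns a new instance with the same UID; entries in extra win over existing ones. */
    public CopySketch copy(Map<String, Object> extra) {
        Map<String, Object> merged = new HashMap<>(params);
        merged.putAll(extra);
        return new CopySketch(uid, merged);
    }

    public static void main(String[] args) {
        Map<String, Object> base = new HashMap<>();
        base.put("predictionCol", "prediction");
        CopySketch original = new CopySketch("gbtc_demo", base);

        Map<String, Object> extra = new HashMap<>();
        extra.put("predictionCol", "gbtPrediction");
        CopySketch copy = original.copy(extra);

        System.out.println(copy.uid());                    // gbtc_demo
        System.out.println(copy.get("predictionCol"));     // gbtPrediction
        System.out.println(original.get("predictionCol")); // prediction (original is unchanged)
    }
}
```
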
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()
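
validateAndTransformSchema works with the column names exposed by featuresCol(), labelCol() and predictionCol(). The sketch below models a schema as an ordered map from column name to a type name and assumes these rules: the features column must exist with the expected type, the label column is checked only when fitting (required to be "double" here, an assumption), and the prediction column is appended as a double column that must not already exist. SchemaSketch and the type strings are hypothetical; real Spark code operates on StructType and DataType.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical sketch of validating an input schema and appending a prediction column. Not Spark code. */
public final class SchemaSketch {

    private SchemaSketch() {}

    /** Returns a new schema with the prediction column appended; throws if the input is invalid. */
    public static Map<String, String> validateAndTransform(
            Map<String, String> schema, boolean fitting,
            String featuresCol, String featuresType,
            String labelCol, String predictionCol) {
        String actual = schema.get(featuresCol);
        if (!featuresType.equals(actual)) {
            throw new IllegalArgumentException(
                "Column " + featuresCol + " must be of type " + featuresType + " but was " + actual);
        }
        // The label is only needed while training, so only check it when fitting.
        if (fitting && !"double".equals(schema.get(labelCol))) {
            throw new IllegalArgumentException("Label column " + labelCol + " must be of type double");
        }
        if (schema.containsKey(predictionCol)) {
            throw new IllegalArgumentException("Column " + predictionCol + " already exists");
        }
        Map<String, String> out = new LinkedHashMap<>(schema); // keeps the input column order
        out.put(predictionCol, "double");
        return out;
    }

    public static void main(String[] args) {
        Map<String, String> in = new LinkedHashMap<>();
        in.put("features", "vector");
        Map<String, String> out =
            validateAndTransform(in, false, "features", "vector", "label", "prediction");
        System.out.println(out); // {features=vector, prediction=double}
    }
}
```

Returning a new map rather than mutating the input mirrors the "transform the schema" wording: the caller's schema is left untouched.
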