Skip to content

[SPARK-34591][ML] Add decision tree pruning as a parameter#55763

Open
WeichenXu123 wants to merge 26 commits into
apache:masterfrom
WeichenXu123:SPARK-34591
Open

[SPARK-34591][ML] Add decision tree pruning as a parameter#55763
WeichenXu123 wants to merge 26 commits into
apache:masterfrom
WeichenXu123:SPARK-34591

Conversation

@WeichenXu123
Copy link
Copy Markdown
Contributor

What changes were proposed in this pull request?

This PR adds a parameter to enable/disable a featuer where LearningNodes are merged after a RF model is trained.

This PR takes over #32813

Why are the changes needed?

2 Reasons:

  1. In addition to basic classification, another use case for decision trees are the probabilities associated with predictions.
    Once pruned, these predictions are lost and it makes the trees/predictions challenging to work with if not unusable.

  2. It is not in line with the default behavior in sklearn. In sklearn, the trees are left unpruned by default.

Please see Jira ticket for more explanation.

Does this PR introduce any user-facing change?

New params:
adds a parameter pruneTree that is exposed to the Tree based classifiers. Will add tests here to ensure parameter is exposed correctly.

How was this patch tested?

Unit tests.

bribiescas-carlos and others added 21 commits June 8, 2021 12:10
    ### What changes were proposed in this pull request?

    This PR disables a feature created in SPARK-3159 where LearningNodes are
    merged after a RF model is trained.

    ### Why are the changes needed?

    2 Reasons:

    1. In addition to basic classification, another use case for decision trees are the
    probabilities associated with predictions.  Once pruned, these predictions are lost
    and it makes the trees/predictions challenging to work with if not unusable.

    2. It is not in line with the default behavior in sklearn.  In sklearn, the trees
    are left unpruned by default.

    ### Does this PR introduce _any_ user-facing change?

    No, it's dev-only.

    ### How was this patch tested?
    Locally ran `./build/mvn -pl mllib package` and verified tests passed
    Additionally, running through git workflow as described here:
    	https://spark.apache.org/developer-tools.html#github-workflow-tests
This PR disables a feature created in SPARK-3159 where LearningNodes are merged after a RF model is trained.

2 Reasons:

1. In addition to basic classification, another use case for decision trees are the probabilities associated with predictions.
 Once pruned, these predictions are lost and it makes the trees/predictions challenging to work with if not unusable.

2. It is not in line with the default behavior in sklearn.  In sklearn, the trees are left unpruned by default.

Please see Jira ticket for more explanation.

No, it's dev-only.

I modified the two tests introduced with this change to verify postive/negative use of feature.  I also added assertions for default behavior

Locally ran `./build/mvn -pl mllib package` and verified tests passed
Additionally, running through git workflow as described here:
    	https://spark.apache.org/developer-tools.html#github-workflow-tests
…are merged after a RF model is trained.

2 Reasons:

1. In addition to basic classification, another use case for decision trees are the probabilities associated with predictions.
 Once pruned, these predictions are lost and it makes the trees/predictions challenging to work with if not unusable.

2. It is not in line with the default behavior in sklearn.  In sklearn, the trees are left unpruned by default.

Please see Jira ticket for more explanation.

No, it's dev-only.

I modified the two tests introduced with this change to verify postive/negative use of feature.  I also added assertions for default behavior

Locally ran `./build/mvn -pl mllib package` and verified tests passed
Locally ran `./dev/scalafmt` which resulted in some minor cosmetic changes

Additionally, running through git workflow as described here:
    	https://spark.apache.org/developer-tools.html#github-workflow-tests
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
zhengruifeng
zhengruifeng previously approved these changes May 8, 2026
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new pruneTree parameter to Spark ML tree-based classification estimators to control post-training pruning (merging redundant subtrees), wiring the flag through the Scala training pipeline and exposing it in PySpark.

Changes:

  • Introduce pruneTree as a ML param (default true) and propagate it into the underlying old-API Strategy used by training.
  • Expose pruneTree in PySpark DecisionTreeClassifier / RandomForestClassifier (constructor, setParams, setter/getter).
  • Update Scala unit tests to toggle pruning via Strategy.pruneTree (instead of a testing-only prune argument).

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
python/pyspark/ml/tree.py Adds pruneTree Param + getter in shared tree classifier param mixin.
python/pyspark/ml/classification.py Wires pruneTree through PySpark DT/RF classifier defaults, signatures, and setters.
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala Updates pruning-related tests to use Strategy.pruneTree.
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala Adds pruneTree to the old-API Strategy configuration.
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala Adds the ML pruneTree param to TreeClassifierParams.
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala Removes test-only prune arg; uses strategy.pruneTree for model materialization.
mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala Exposes setPruneTree and propagates param to old Strategy.
mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala Exposes setPruneTree and propagates param to old Strategy.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala Outdated
Comment thread mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala Outdated
Comment thread python/pyspark/ml/tree.py Outdated
Comment thread python/pyspark/ml/tree.py
Comment thread python/pyspark/ml/classification.py
@zhengruifeng zhengruifeng dismissed their stale review May 8, 2026 11:32

dismiss

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants