[SPARK-34591][ML] Add decision tree pruning as a parameter #55763
Open
WeichenXu123 wants to merge 26 commits into
### What changes were proposed in this pull request?
This PR disables a feature created in SPARK-3159 where LearningNodes are
merged after a RF model is trained.
### Why are the changes needed?
2 Reasons:
1. In addition to basic classification, another use case for decision trees is the
probabilities associated with predictions. Once pruned, these predictions are lost,
which makes the trees/predictions challenging to work with, if not unusable.
2. It is not in line with the default behavior in sklearn. In sklearn, the trees
are left unpruned by default.
### Does this PR introduce _any_ user-facing change?
No, it's dev-only.
### How was this patch tested?
Locally ran `./build/mvn -pl mllib package` and verified tests passed
Additionally, running through git workflow as described here:
https://spark.apache.org/developer-tools.html#github-workflow-tests
This PR disables a feature created in SPARK-3159 where LearningNodes are merged after a RF model is trained.
2 Reasons:
1. In addition to basic classification, another use case for decision trees is the probabilities associated with predictions.
Once pruned, these predictions are lost, which makes the trees/predictions challenging to work with, if not unusable.
2. It is not in line with the default behavior in sklearn. In sklearn, the trees are left unpruned by default.
Please see Jira ticket for more explanation.
No, it's dev-only.
I modified the two tests introduced with this change to verify positive/negative use of the feature. I also added assertions for the default behavior.
Locally ran `./build/mvn -pl mllib package` and verified tests passed
Additionally, running through git workflow as described here:
https://spark.apache.org/developer-tools.html#github-workflow-tests
Locally ran `./dev/scalafmt` which resulted in some minor cosmetic changes
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
zhengruifeng
previously approved these changes
May 8, 2026
Pull request overview
Adds a new pruneTree parameter to Spark ML tree-based classification estimators to control post-training pruning (merging redundant subtrees), wiring the flag through the Scala training pipeline and exposing it in PySpark.
Changes:
- Introduce `pruneTree` as an ML param (default `true`) and propagate it into the underlying old-API `Strategy` used by training.
- Expose `pruneTree` in PySpark `DecisionTreeClassifier`/`RandomForestClassifier` (constructor, `setParams`, setter/getter).
- Update Scala unit tests to toggle pruning via `Strategy.pruneTree` (instead of a testing-only `prune` argument).
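The flow in the list above (user-facing param, then `Strategy`, then the step that materializes the final tree) can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: `Strategy`, `materialize`, and `DecisionTreeClassifierSketch` below are hypothetical stand-ins written for this example, not the Spark or PySpark classes, and the tree is reduced to nested tuples of predicted classes.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# A tree is either a leaf (the predicted class as an int)
# or a (left, right) pair of subtrees.
Tree = Union[int, Tuple["Tree", "Tree"]]


@dataclass
class Strategy:
    """Hypothetical stand-in for the old-API Strategy; only the new flag is modelled."""
    prune_tree: bool = True  # mirrors the default of true described above


def materialize(tree: Tree, strategy: Strategy) -> Tree:
    """Build the final tree, merging same-prediction sibling leaves only if pruning is on."""
    if isinstance(tree, int):
        return tree
    left = materialize(tree[0], strategy)
    right = materialize(tree[1], strategy)
    if strategy.prune_tree and isinstance(left, int) and left == right:
        return left  # both children are leaves with the same prediction
    return (left, right)


class DecisionTreeClassifierSketch:
    """Hypothetical estimator showing how a user-facing param reaches the Strategy."""

    def __init__(self, prune_tree: bool = True):
        self.prune_tree = prune_tree

    def to_strategy(self) -> Strategy:
        return Strategy(prune_tree=self.prune_tree)


# The left subtree splits into two leaves that both predict class 1.
raw = ((1, 1), 0)
pruned = materialize(raw, DecisionTreeClassifierSketch().to_strategy())
unpruned = materialize(raw, DecisionTreeClassifierSketch(prune_tree=False).to_strategy())
```

With the default, the redundant split collapses into a single leaf; with `prune_tree=False` the tree keeps the shape it was trained with.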
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| python/pyspark/ml/tree.py | Adds pruneTree Param + getter in shared tree classifier param mixin. |
| python/pyspark/ml/classification.py | Wires pruneTree through PySpark DT/RF classifier defaults, signatures, and setters. |
| mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | Updates pruning-related tests to use Strategy.pruneTree. |
| mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala | Adds pruneTree to the old-API Strategy configuration. |
| mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | Adds the ML pruneTree param to TreeClassifierParams. |
| mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | Removes test-only prune arg; uses strategy.pruneTree for model materialization. |
| mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | Exposes setPruneTree and propagates param to old Strategy. |
| mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | Exposes setPruneTree and propagates param to old Strategy. |
### What changes were proposed in this pull request?
This PR adds a parameter to enable or disable a feature where LearningNodes are merged after a RF model is trained.
This PR takes over #32813
### Why are the changes needed?
2 Reasons:
1. In addition to basic classification, another use case for decision trees is the probabilities associated with predictions. Once pruned, these probabilities are lost, which makes the trees/predictions challenging to work with, if not unusable.
2. It is not in line with the default behavior in sklearn. In sklearn, the trees are left unpruned by default.

Please see Jira ticket for more explanation.
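
The first reason can be made concrete with a small sketch. It assumes that pruning merges two sibling leaves when they predict the same class, which is a simplification; the `Node` class and `prune` function are hypothetical and written only for this illustration, not Spark's `LearningNode` code.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical tree node holding per-class training counts."""
    counts: List[float]
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    @property
    def is_leaf(self) -> bool:
        return self.left is None and self.right is None

    @property
    def prediction(self) -> int:
        # Index of the class with the largest count.
        return max(range(len(self.counts)), key=self.counts.__getitem__)

    @property
    def probabilities(self) -> List[float]:
        total = sum(self.counts)
        return [c / total for c in self.counts]


def prune(node: Node) -> Node:
    """Bottom-up: collapse a split whose two children are leaves predicting the same class."""
    if node.is_leaf:
        return node
    left, right = prune(node.left), prune(node.right)
    if left.is_leaf and right.is_leaf and left.prediction == right.prediction:
        return Node(counts=node.counts)  # children are dropped, only parent counts remain
    return Node(counts=node.counts, left=left, right=right)


# Both leaves predict class 0, but with different confidence.
root = Node(
    counts=[15.0, 5.0],
    left=Node(counts=[9.0, 1.0]),   # class 0 with probability 0.9
    right=Node(counts=[6.0, 4.0]),  # class 0 with probability 0.6
)
pruned = prune(root)
```

In this toy tree the predicted class is the same before and after pruning, but the two distinct leaf probabilities are replaced by the parent's single estimate of 0.75, so any downstream use of per-leaf probabilities sees less detail.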
### Does this PR introduce _any_ user-facing change?
New params: adds a parameter `pruneTree` that is exposed to the tree-based classifiers. Will add tests here to ensure the parameter is exposed correctly.
### How was this patch tested?
Unit tests.