Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/interpret_community/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ class LightGBMParams(object):
"""Provide constants for LightGBM."""

CATEGORICAL_FEATURE = 'categorical_feature'
N_JOBS = 'n_jobs'
ALL = [CATEGORICAL_FEATURE, N_JOBS]


class LinearExplainableModelParams(object):
    """Constant parameter names accepted by the LinearExplainableModel surrogate."""

    # Keyword argument toggling sparse-input handling on the linear surrogate.
    SPARSE_DATA = 'sparse_data'
    # Every parameter name recognized for LinearExplainableModel, used for validation.
    ALL = [SPARSE_DATA]


class ShapValuesOutput(str, Enum):
Expand Down
56 changes: 48 additions & 8 deletions python/interpret_community/mimic/mimic_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..common.blackbox_explainer import BlackBoxExplainer

from .model_distill import _model_distill
from .models import LGBMExplainableModel
from .models import LGBMExplainableModel, LinearExplainableModel
from ..explanation.explanation import _create_local_explanation, _create_global_explanation, \
_aggregate_global_from_local_explanation, _aggregate_streamed_local_explanations, \
_create_raw_feats_global_explanation, _create_raw_feats_local_explanation, \
Expand All @@ -30,7 +30,7 @@
from ..dataset.dataset_wrapper import DatasetWrapper
from ..common.constants import ExplainParams, ExplainType, ModelTask, \
ShapValuesOutput, MimicSerializationConstants, ExplainableModelType, \
LightGBMParams, Defaults, Extension, ResetIndex
LightGBMParams, Defaults, Extension, ResetIndex, LinearExplainableModelParams
import logging
import json

Expand Down Expand Up @@ -236,6 +236,8 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
"""
if transformations is not None and explain_subset is not None:
raise ValueError("explain_subset not supported with transformations")
self._validate_explainable_model_args(explainable_model=explainable_model,
explainable_model_args=explainable_model_args)
self.reset_index = reset_index
self._datamapper = None
if transformations is not None:
Expand All @@ -250,8 +252,7 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
wrapped_model, eval_ml_domain = _wrap_model(model, initialization_examples, model_task, is_function)
super(MimicExplainer, self).__init__(wrapped_model, is_function=is_function,
model_task=eval_ml_domain, **kwargs)
if explainable_model_args is None:
explainable_model_args = {}

if categorical_features is None:
categorical_features = []
self._logger.debug('Initializing MimicExplainer')
Expand Down Expand Up @@ -288,7 +289,6 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
# Index the categorical string columns for training data
self._column_indexer = initialization_examples.string_index(columns=categorical_features)
self._one_hot_encoder = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this removed?

             explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features 

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this was moved to line 347, which is fine. My only slight concern is that we are now performing the same checks in multiple places:

is_tree_model = explainable_model.explainable_model_type == ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE
        if is_tree_model and self._supports_categoricals(explainable_model):

but it's not expensive so I think it's ok

explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features
else:
# One-hot-encode categoricals for models that don't support categoricals natively
self._column_indexer = initialization_examples.string_index(columns=categorical_features)
Expand All @@ -304,15 +304,55 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
if isinstance(training_data, DenseData):
training_data = training_data.data

explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag
if self._supports_shap_values_output(explainable_model):
explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output
explainable_model_args = self._supplement_explainable_model_args(
explainable_model=explainable_model,
explainable_model_args=explainable_model_args,
categorical_features=categorical_features,
shap_values_output=shap_values_output)
self.surrogate_model = _model_distill(self.function, explainable_model, training_data,
original_training_data, explainable_model_args)
self._method = self.surrogate_model._method
self._original_eval_examples = None
self._allow_all_transformations = allow_all_transformations

def _validate_explainable_model_args(self, explainable_model, explainable_model_args):
if explainable_model_args is None:
return

if explainable_model == LGBMExplainableModel:
for linear_param in LinearExplainableModelParams.ALL:
if linear_param in explainable_model_args:
raise Exception(linear_param +
" found in params for LightGBM explainable model")

if explainable_model == LinearExplainableModel:
for lightgbm_param in LightGBMParams.ALL:
if lightgbm_param in explainable_model_args:
raise Exception(lightgbm_param +
" found in params for Linear explainable model")

all_supported_explainable_model_args = [LightGBMParams.ALL, LinearExplainableModelParams.ALL]
for explainable_model_arg in explainable_model_args:
if explainable_model_arg not in all_supported_explainable_model_args:
raise Exception(
"Found unsupported explainable model argument " + explainable_model_arg)

def _supplement_explainable_model_args(self, explainable_model, explainable_model_args,
                                       categorical_features, shap_values_output):
    """Fill in surrogate-model arguments derived from explainer state.

    Adds the categorical-feature list for tree surrogates that support
    categoricals natively, always sets the classification flag from
    predict_proba_flag, and sets the shap-values output mode when the
    surrogate supports it. Mutates and returns the supplied dict, or a
    fresh dict when explainable_model_args is None.
    """
    args = {} if explainable_model_args is None else explainable_model_args

    is_tree_model = (explainable_model.explainable_model_type ==
                     ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE)
    if is_tree_model and self._supports_categoricals(explainable_model):
        args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features

    args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag
    if self._supports_shap_values_output(explainable_model):
        args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output
    return args

def _supports_categoricals(self, explainable_model):
    """Return True when the surrogate class handles categorical features natively.

    Currently only the LightGBM-based surrogate (or a subclass of it) does.
    """
    is_lgbm_surrogate = issubclass(explainable_model, LGBMExplainableModel)
    return is_lgbm_surrogate

Expand Down
26 changes: 25 additions & 1 deletion test/test_mimic_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sys import platform
from interpret_community.common.constants import ShapValuesOutput, ModelTask
from interpret_community.common.constants import ShapValuesOutput, ModelTask, \
LinearExplainableModelParams, LightGBMParams
from interpret_community.mimic.models.lightgbm_model import LGBMExplainableModel
from interpret_community.mimic.models.linear_model import LinearExplainableModel
from common_utils import create_sklearn_svm_classifier, create_sklearn_linear_regressor, \
Expand Down Expand Up @@ -521,6 +522,29 @@ def test_dense_wide_data(self, mimic_explainer):
global_explanation = explainer.explain_global(df_X)
assert global_explanation.method == LIGHTGBM_METHOD

@pytest.mark.parametrize("error_config",
                         [(LGBMExplainableModel, {LinearExplainableModelParams.SPARSE_DATA: True}),
                          (LinearExplainableModel, {LightGBMParams.N_JOBS: -1}),
                          (LinearExplainableModel, {LightGBMParams.CATEGORICAL_FEATURE: []}),
                          (LGBMExplainableModel, {"unsupported": True}),
                          (LinearExplainableModel, {"unsupported": True})])
def test_validate_explainable_model_args(self, error_config, mimic_explainer):
    """Verify mismatched or unknown surrogate-model args raise at construction time."""
    num_features = 100
    num_rows = 1000
    test_size = 0.2
    X, y = make_regression(n_samples=num_rows, n_features=num_features)
    # Only the training split is used; the test never evaluates on held-out data.
    x_train, _, y_train, _ = train_test_split(X, y, test_size=test_size, random_state=42)

    model = LinearRegression(normalize=True)
    model.fit(x_train, y_train)

    explainable_model, explainable_model_args = error_config
    with pytest.raises(Exception):
        mimic_explainer(model, x_train, explainable_model,
                        explainable_model_args=explainable_model_args,
                        augment_data=False)

@property
def iris_overall_expected_features(self):
return [['petal length', 'petal width', 'sepal width', 'sepal length'],
Expand Down