Surrogate Model

An explainability approach that builds a transparent model from the predictions of an actual model.

View this page for more information on surrogate model feature importance values. For an example of logging surrogate model feature importance values, check out the Arize Surrogate Model Feature Importance tutorial.

Surrogate explainability is based on training a surrogate model to mimic a black box model. The surrogate is an interpretable model trained to approximate the predictions of the black box model as closely as possible. SHAP values can then be generated from the surrogate model when the black box model itself is not available, so the feature importance values come from the interpretable surrogate rather than from the original model.
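The idea can be illustrated with a toy, standard-library-only sketch. It is not the Arize implementation and does not compute SHAP values: the `black_box` function, the depth-1 tree (a "stump"), and the variance-explained importance score are all assumptions chosen to keep the example small. The surrogate sees only the features and the black box's predictions, never the black box's internals.

```python
import random


def black_box(row):
    # Hypothetical opaque model: strong step on feature 0, weak slope on
    # feature 1, no dependence on feature 2.
    return 3.0 * (row[0] > 0.5) + 0.5 * row[1]


def sse(values):
    # Sum of squared errors around the mean (0.0 for an empty list).
    if not values:
        return 0.0
    m = sum(values) / len(values)
    return sum((v - m) ** 2 for v in values)


def fit_stump_surrogate(rows, preds):
    """Fit a depth-1 regression tree to the black box's predictions.

    Returns the best split and, per feature, the fraction of prediction
    variance explained by that feature's best single split (a crude
    stand-in for a feature importance value).
    """
    total = sse(preds)
    importances, best, best_gain = [], None, 0.0
    for j in range(len(rows[0])):
        feature_gain = 0.0
        for t in sorted({r[j] for r in rows}):
            left = [p for r, p in zip(rows, preds) if r[j] <= t]
            right = [p for r, p in zip(rows, preds) if r[j] > t]
            gain = total - sse(left) - sse(right)
            feature_gain = max(feature_gain, gain)
            if gain > best_gain:  # only true when both sides are non-empty
                best_gain = gain
                best = {
                    "feature": j,
                    "threshold": t,
                    "left": sum(left) / len(left),
                    "right": sum(right) / len(right),
                }
        importances.append(feature_gain / total)
    return best, importances


def predict(stump, row):
    # Surrogate prediction: the mean black-box output on the matching side.
    side = "left" if row[stump["feature"]] <= stump["threshold"] else "right"
    return stump[side]


random.seed(0)
rows = [[random.random() for _ in range(3)] for _ in range(200)]
preds = [black_box(r) for r in rows]  # the only signal the surrogate sees

stump, importances = fit_stump_surrogate(rows, preds)

# Fidelity: R^2 of the surrogate against the black box's own predictions.
residual = sum((p - predict(stump, r)) ** 2 for r, p in zip(rows, preds))
fidelity = 1.0 - residual / sse(preds)
```

Here the surrogate splits on feature 0, the feature the hypothetical black box depends on most, and `fidelity` measures how closely the surrogate tracks the black box. A low fidelity would mean that importance values read off the surrogate say little about the original model.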

The Arize Python SDK (specifically the pandas logger) lets the user pass a flag with their request to produce SHAP values using the surrogate explainability approach. When the flag is enabled, a tree-based surrogate model is trained on the dataset's features and predictions, and SHAP values are generated from the surrogate model before the combined dataset is sent to the Arize platform. The pandas logger can compute surrogate models for regression and binary classification models. For binary classification, the prediction score should have values between 0 and 1.
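Because binary classification scores are expected to lie between 0 and 1, it can help to check them before enabling the flag. The helper below is a hypothetical pre-flight check, not part of the Arize SDK; it assumes the scores have already been pulled out of the dataframe into a plain list.

```python
def invalid_scores(scores):
    """Return (index, value) pairs for scores that are missing or outside [0, 1].

    Hypothetical helper for illustration; not an Arize SDK function.
    """
    bad = []
    for i, s in enumerate(scores):
        # `s != s` is True only for NaN.
        if s is None or s != s or not 0.0 <= s <= 1.0:
            bad.append((i, s))
    return bad


# Usage: an empty result means every score is a valid probability-like value.
problems = invalid_scores([0.2, 1.3, 0.9, -0.1])
```

Each returned pair gives the position and the offending value, so the rows can be fixed or dropped before logging.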

In this Colab we show how to use the surrogate_explainability flag to generate feature importance values from a surrogate model using only the prediction outputs from a black box model, without calculating the feature importance values yourself. Both classification and regression examples are provided, and feature importance values are sent to Arize using the pandas logger. The library used to create the surrogate model currently supports only regression and classification models.

Code Example

!pip install -q 'arize[MimicExplainer]'

from arize.pandas.logger import Schema
from arize.utils.types import Environments, ModelTypes

# Define a Schema() object for Arize to pick up data from the correct columns for logging
schema = Schema(
    prediction_id_column_name="prediction_id",
    ...
    feature_column_names=feature_cols,
)

# Log the dataframe with the schema mapping.
# arize_client is assumed to be an already-initialized arize.pandas.logger.Client.
response = arize_client.log(
    dataframe=test_dataframe,
    schema=schema,
    model_id="surrogate_model_example",
    model_version="v1",
    model_type=ModelTypes.SCORE_CATEGORICAL,
    environment=Environments.PRODUCTION,
    surrogate_explainability=True,  # enable surrogate explainability
)

Make sure to install the dependencies for the Surrogate Explainer!

Questions? Email us at support@arize.com or Slack us in the #arize-support channel.

Copyright © 2023 Arize AI, Inc