Explaining a model


In [1]:
import time

from pyAgrum.lib.bn2graph import BN2dot
import numpy as np
import pandas as pd

import pyAgrum as gum
import pyAgrum.lib.notebook as gnb
import pyAgrum.lib.explain as expl

import matplotlib.pyplot as plt

Building the model

We build a simple Bayesian network for this example.

In [2]:
template=gum.fastBN("X1->X2->Y;X3->Z->Y;X0->Z;X1->Z;X2->R[5];Z->R;X1->Y")
data_path = "res/shap/Data_6var_direct_indirect.csv"

#gum.generateSample(template,1000,data_path)

learner = gum.BNLearner(data_path,template)
bn = learner.learnParameters(template.dag())
bn
Out[2]:
(rendering of the learned Bayesian network, with arcs X1->X2, X2->Y, X1->Y, X0->Z, X1->Z, X3->Z, Z->Y, Z->R and X2->R)

1-independence list (w.r.t. the class Y)

Given a model, it may be interesting to investigate the conditional independences of the class Y that this very model induces.

In [3]:
# this function explores all the CIs between pairs of variables implied by the model and computes their p-values w.r.t. a csv file.
expl.independenceListForPairs(bn,data_path)
Out[3]:
{('R', 'X0', ('X1', 'Z')): 0.7083382647903902,
 ('R', 'X1', ('X2', 'Z')): 0.4693848625409949,
 ('R', 'X3', ('X1', 'Z')): 0.4128522974536623,
 ('R', 'Y', ('X2', 'Z')): 0.8684231094674687,
 ('X0', 'X1', ()): 0.723302358657366,
 ('X0', 'X2', ()): 0.9801394906304377,
 ('X0', 'X3', ()): 0.7676868597218647,
 ('X0', 'Y', ('X1', 'Z')): 0.5816487109659612,
 ('X1', 'X3', ()): 0.5216508257424717,
 ('X2', 'X3', ()): 0.9837021981131505,
 ('X2', 'Z', ('X1',)): 0.6638491605436834,
 ('X3', 'Y', ('X1', 'Z')): 0.8774081450472305}
../_images/notebooks_95-Tools_Explain_7_1.svg

… with respect to a specific target.

In [4]:
expl.independenceListForPairs(bn,data_path,target="Y")
Out[4]:
{('Y', 'R', ('X2', 'Z')): 0.8684231094674687,
 ('Y', 'X0', ('X1', 'Z')): 0.5816487109659612,
 ('Y', 'X3', ('X1', 'Z')): 0.8774081450472305}
../_images/notebooks_95-Tools_Explain_9_1.svg
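
One can post-process these p-values like any Python dictionary. As an illustration (the 0.05 threshold below is an arbitrary choice, not something prescribed by pyAgrum), a minimal sketch listing the independences of the model that the data seems to reject:

In [ ]:
# flag the conditional independences of the model whose p-value falls below
# an (arbitrarily chosen) significance level
alpha = 0.05
rejected = {indep: p
            for indep, p in expl.independenceListForPairs(bn, data_path).items()
            if p < alpha}
print(rejected if rejected else "no independence rejected at level 0.05")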

2-ShapValues

In [5]:
print(expl.ShapValues.__doc__)

  The ShapValue class implements the calculation of Shap values in Bayesian networks.

  The main implementation is based on Conditional Shap values [3]_, but the Interventional calculation method proposed in [2]_ is also present. In addition, a new causal method, based on [1]_, is implemented which is well suited for Bayesian networks.

.. [1] Heskes, T., Sijben, E., Bucur, I., & Claassen, T. (2020). Causal Shapley Values: Exploiting Causal Knowledge. 34th Conference on Neural Information Processing Systems. Vancouver, Canada.

.. [2] Janzing, D., Minorics, L., & Blöbaum, P. (2019). Feature relevance quantification in explainable AI: A causality problem. arXiv: Machine Learning. Retrieved 6 24, 2021, from https://arxiv.org/abs/1910.13413

.. [3] Lundberg, S. M., & Lee, S.-I. (2017). A Unified Approach to Interpreting Model Predictions. 31st Conference on Neural Information Processing Systems. Long Beach, CA, USA.

The ShapValue class implements the calculation of Shap values in Bayesian networks. It is necessary to specify a target and to provide a Bayesian network whose parameters are known and will be used later in the different calculation methods.

In [6]:
gumshap = expl.ShapValues(bn, 'Y')

Compute Conditional Shap Values in the Bayesian Network

A dataset (as a pandas.DataFrame) must be provided so that the Bayesian network can learn its parameters and then predict.

The conditional method computes the conditional shap values using the Bayesian network. It returns two plots and a dictionary: the first plot shows the distribution of the shap values for each variable, the second ranks the variables by importance.

In [7]:
train = pd.read_csv(data_path).sample(frac=1.)
In [8]:
t_start = time.time()
resultat = gumshap.conditional(train,
                               plot=True,plot_importance=True,percentage=False)
print(f'Run Time : {time.time()-t_start} sec')
../_images/notebooks_95-Tools_Explain_17_0.svg
Run Time : 16.09646701812744 sec
In [9]:
t_start = time.time()
resultat = gumshap.conditional(train,
                               plot=False,plot_importance=True,percentage=False)
print(f'Run Time : {time.time()-t_start} sec')
../_images/notebooks_95-Tools_Explain_18_0.svg
Run Time : 15.890217065811157 sec
In [10]:
result = gumshap.conditional(train,
                             plot=True,plot_importance=False,percentage=False)
#result is a Dict[str,float] of the different Shapley values for all nodes.
../_images/notebooks_95-Tools_Explain_19_0.svg

The result is returned as a dictionary whose keys are the names of the features; the associated value is the absolute value of the mean of the computed shap values.

In [11]:
t_start = time.time()
resultat = gumshap.conditional(train,
                               plot=False,plot_importance=False,percentage=False)
print(f'Run Time : {time.time()-t_start} sec')
resultat
Run Time : 15.818578004837036 sec
Out[11]:
{'R': np.float64(0.054456334441524),
 'X0': np.float64(0.06176712200000176),
 'Z': np.float64(0.5464180054433385),
 'X2': np.float64(0.32716064437520065),
 'X1': np.float64(0.2533375405370652),
 'X3': np.float64(0.10465402104047898)}
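
Since this is a plain dictionary, the importance ranking produced by plot_importance can also be rebuilt by hand; a minimal sketch:

In [ ]:
# rank the features by decreasing importance (absolute mean shap value)
for name, value in sorted(resultat.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:>3} : {value:.4f}")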

Causal Shap Values

This method is similar to the previous one, except for the computation formula. It computes the causal shap values as described in the paper by Heskes et al., Causal Shapley Values: Exploiting Causal Knowledge to Explain Individual Predictions of Complex Models.

In [12]:
t_start = time.time()
causal = gumshap.causal(train,
                        plot=True, plot_importance=True, percentage=False)
print(f'Run Time : {time.time()-t_start} sec')
../_images/notebooks_95-Tools_Explain_24_0.svg
Run Time : 12.43918490409851 sec

As you can see, since \(R\) is not among the ‘causes’ of \(Y\), its causal importance is null.
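
To double-check this on the structure itself, one can verify that \(R\) is not an ancestor of \(Y\) in the DAG. The sketch below rebuilds the set of ancestors by walking up the parent sets, using only bn.parents and bn.idFromName:

In [ ]:
# collect the names of all ancestors of a node by walking up the parent sets
def ancestor_names(bn, name):
    to_visit = set(bn.parents(bn.idFromName(name)))
    seen = set()
    while to_visit:
        node = to_visit.pop()
        if node not in seen:
            seen.add(node)
            to_visit |= set(bn.parents(node))
    return {bn.variable(node).name() for node in seen}

print(ancestor_names(bn, "Y"))  # 'R' should not appear in this set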

Marginal Shap Values

Similarly, one can also compute the marginal Shap values.

In [13]:
t_start = time.time()
marginal = gumshap.marginal(train, sample_size=10,
                            plot=True,plot_importance=True,percentage=False)
print(f'Run Time : {time.time()-t_start} sec')
print(marginal)
../_images/notebooks_95-Tools_Explain_26_0.svg
Run Time : 96.8886923789978 sec
{'R': np.float64(0.0), 'X0': np.float64(0.0), 'Z': np.float64(0.7163284011241826), 'X2': np.float64(0.38495483799667063), 'X1': np.float64(0.3822200556800647), 'X3': np.float64(0.0)}

As you can see, since \(R\), \(X0\) and \(X3\) are not in the Markov Blanket of \(Y\), their marginal importances are null.
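
Here again the claim can be checked against the graph: the Markov blanket of \(Y\) consists of its parents, its children and its children's other parents. A minimal sketch rebuilding it with bn.parents and bn.children only:

In [ ]:
# rebuild the Markov blanket of Y: parents, children, and the children's other parents
y = bn.idFromName("Y")
blanket = set(bn.parents(y)) | set(bn.children(y))
for child in bn.children(y):
    blanket |= set(bn.parents(child))
blanket.discard(y)
print({bn.variable(node).name() for node in blanket})  # R, X0 and X3 are expected to be absent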

Saving the graph

You can specify a filename if you prefer to save this figure instead of showing it:

In [17]:
t_start = time.time()
causal2 = gumshap.causal(train,
                         plot=True,plot_importance=True,percentage=False,
                         filename="out/marginal.pdf")
print(f'Run Time : {time.time()-t_start} sec')
print(causal2)
Run Time : 12.455586910247803 sec
{'R': np.float64(1.5312548682736943e-17), 'X0': np.float64(0.05709885754818158), 'Z': np.float64(0.5925926625367108), 'X2': np.float64(0.3411206770245262), 'X1': np.float64(0.25248101589161415), 'X3': np.float64(0.07711539074829295)}

Visualizing shap values directly on a BN

This function returns a coloured graph that makes it easier to see which variables are important and where they are located in the graph.

In [ ]:
import pyAgrum.lib.notebook as gnb

# colour the graph according to the causal shap values computed above
g = expl.showShapValues(causal)
gnb.showGraph(g)

Visualizing information

Finally, another view consists in showing the entropy of each node and the mutual information on each arc.

In [ ]:
expl.showInformation(bn)