# Approximate inference in aGrUM (pyAgrum)

There are several approximate inference algorithms for BNs in aGrUM (pyAgrum). They share the same API as exact inference.

- Loopy Belief Propagation: LBP is an approximate inference that applies the message-passing computations that are exact when the BN is a tree, even if the BN is not a tree. LBP is an unusual kind of inference: the algorithm may not converge, and even when it converges, it may converge to something other than the exact posterior. LBP is nevertheless fast and usually gives reasonably good results.
- Sampling inference: sampling inference uses sampling to compute the posterior. Sampling may be (very) slow, but these algorithms converge to the exact distribution. aGrUM implements:
  - Monte Carlo sampling,
  - Weighted sampling,
  - Importance sampling,
  - Gibbs sampling.
- Finally, aGrUM proposes the so-called ‘loopy version’ of the sampling algorithms: the idea is to use LBP as a Dirichlet prior for the sampling algorithm. A loopy version of each sampling algorithm is provided.
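The next example makes concrete the claim that sampling converges to the exact distribution. It is a minimal, self-contained sketch of weighted sampling (likelihood weighting), written in plain Python rather than with pyAgrum. It uses a hypothetical two-node BN `Rain -> WetGrass`; the network, its probabilities and the function names are illustrative assumptions, and aGrUM's `WeightedSampling` is not claimed to be implemented this way.

```python
import random

# Hypothetical two-node BN: Rain -> WetGrass (both binary).
P_RAIN = 0.2                          # P(Rain = 1)
P_WET_GIVEN_RAIN = {1: 0.9, 0: 0.1}   # P(WetGrass = 1 | Rain)


def exact_rain_given_wet():
    """Exact P(Rain = 1 | WetGrass = 1) by Bayes' rule."""
    num = P_RAIN * P_WET_GIVEN_RAIN[1]
    den = num + (1 - P_RAIN) * P_WET_GIVEN_RAIN[0]
    return num / den


def weighted_rain_given_wet(n_samples, seed=0):
    """Likelihood weighting: sample the non-evidence node from its prior,
    then weight each sample by the likelihood of the evidence WetGrass = 1."""
    rng = random.Random(seed)
    weighted_rain = 0.0
    total_weight = 0.0
    for _ in range(n_samples):
        rain = 1 if rng.random() < P_RAIN else 0
        weight = P_WET_GIVEN_RAIN[rain]
        weighted_rain += weight * rain
        total_weight += weight
    return weighted_rain / total_weight


if __name__ == "__main__":
    print("exact:", exact_rain_given_wet())
    for n in (100, 10_000, 100_000):
        print(n, weighted_rain_given_wet(n))
```

The exact posterior here is 0.18 / 0.26 ≈ 0.692. As the number of samples grows, the weighted estimate approaches it, though it is never guaranteed to match it for a finite run.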

In [1]:

```python
%matplotlib inline
import matplotlib.pyplot as plt

def unsharpen(bn):
    """
    Move the parameters of the BN a bit further away from 0 and 1
    """
    for nod in bn.nodes():
        bn.cpt(nod).translate(bn.maxParam() / 10).normalizeAsCPT()

def compareInference(ie, ie2, ax=None):
    """
    Compare 2 inferences by plotting all the points (posterior(ie), posterior(ie2))
    """
    exact = []
    appro = []
    errmax = 0.0  # float, so that the "{:2.4}" format below never receives an int
    for node in ie.BN().nodes():
        # potentials as lists
        exact += ie.posterior(node).tolist()
        appro += ie2.posterior(node).tolist()
        errmax = max(errmax,
                     (ie.posterior(node) - ie2.posterior(node)).abs().max())

    if errmax < 1e-10:
        errmax = 0.0
    if ax is None:
        ax = plt.gca()  # default axis for plt

    ax.plot(exact, appro, 'ro')
    ax.set_title("{} vs {}\n {}\nMax error {:2.4} in {:2.4} seconds".format(
        type(ie).__name__.split("_")[0],   # name of first inference
        type(ie2).__name__.split("_")[0],  # name of second inference
        ie2.messageApproximationScheme(),
        errmax,
        ie2.currentTime())
    )
```
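`unsharpen` adds the same constant (`bn.maxParam() / 10`) to every entry of each CPT, then renormalizes every conditional distribution, which pulls extreme probabilities away from 0 and 1. The sketch below reproduces that arithmetic in plain Python on a CPT stored as a list of rows, one row per parent configuration. The row layout and the helper name `unsharpen_rows` are assumptions for illustration, not pyAgrum's `Potential` API.

```python
def unsharpen_rows(rows, delta):
    """Add `delta` to every probability, then renormalize each row to sum to 1.

    `rows` is a hypothetical CPT layout: one list of probabilities per parent
    configuration (mirrors translate(delta) followed by normalizeAsCPT()).
    """
    smoothed = []
    for row in rows:
        shifted = [p + delta for p in row]
        total = sum(shifted)
        smoothed.append([p / total for p in shifted])
    return smoothed


if __name__ == "__main__":
    cpt = [[0.0, 1.0], [0.5, 0.5]]
    max_param = max(max(row) for row in cpt)  # analogue of bn.maxParam()
    print(unsharpen_rows(cpt, max_param / 10))
```

With `delta = 0.1`, the deterministic row `[0, 1]` becomes roughly `[0.083, 0.917]`, while the uniform row `[0.5, 0.5]` is unchanged.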

In [2]:

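```python
import pyAgrum as gum
import pyAgrum.lib.notebook as gnb

# The original cell does not show how `bn` is created. Its node names match the
# ALARM network; the file path below is hypothetical, adjust it to your own copy.
bn = gum.loadBN("alarm.dsl")
unsharpen(bn)

ie = gum.LazyPropagation(bn)
ie.makeInference()
```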

In [3]:

```python
gnb.showBN(bn,size='8')
```


## First, an exact inference

In [4]:

```python
gnb.sideBySide(gnb.getJunctionTreeMap(bn),gnb.getInference(bn,size="8")) # using LazyPropagation by default
print(ie.posterior("KINKEDTUBE"))
```

[Figure: junction tree of `bn` (left) and `bn` annotated with the posterior of each node (right); inference in 14.00 ms]

```
KINKEDTUBE       |
TRUE     |FALSE    |
---------|---------|
0.1167  | 0.8833  |
```



## Gibbs inference with default parameters

Gibbs inference can be stopped:

- by the value of the error (epsilon, `setEpsilon`),
- by the rate of change of epsilon (MinEpsilonRate, `setMinEpsilonRate`),
- by the number of iterations (MaxIteration, `setMaxIter`),
- by the duration of the algorithm (MaxTime, `setMaxTime`).
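The interplay of these four criteria can be sketched in plain Python. The loop below estimates a probability by sampling, checks the estimate once per period, and stops as soon as any criterion fires. This is a hypothetical illustration, not aGrUM's code. Here "epsilon" is the absolute change of a single scalar estimate between two checks, and the "rate" is the relative change of epsilon. The exact quantities aGrUM monitors may be defined differently.

```python
import random
import time


def estimate_with_stopping_rules(draw, epsilon=1e-2, min_epsilon_rate=1e-5,
                                 max_iter=1_000_000, max_time=5.0,
                                 period_size=100):
    """Average draw() until one of four stopping rules fires.

    Hypothetical helper: returns (estimate, number_of_samples, reason).
    """
    start = time.monotonic()
    total, n = 0.0, 0
    estimate = None
    previous_eps = None
    while True:
        for _ in range(period_size):  # one period between two checks
            total += draw()
            n += 1
        new_estimate = total / n
        if estimate is not None:
            eps = abs(new_estimate - estimate)  # change since last check
            if eps < epsilon:
                return new_estimate, n, "epsilon"
            if previous_eps is not None and previous_eps > 0:
                rate = abs(eps - previous_eps) / previous_eps
                if rate < min_epsilon_rate:
                    return new_estimate, n, "rate"
            previous_eps = eps
        estimate = new_estimate
        if n >= max_iter:
            return estimate, n, "max iteration"
        if time.monotonic() - start >= max_time:
            return estimate, n, "time"


if __name__ == "__main__":
    rng = random.Random(42)

    def coin():
        """Hypothetical biased coin with P(heads) = 0.3."""
        return 1.0 if rng.random() < 0.3 else 0.0

    print(estimate_with_stopping_rules(coin, epsilon=1e-3))
    # Disable epsilon and rate so that only the iteration limit can stop the loop.
    print(estimate_with_stopping_rules(coin, epsilon=0, min_epsilon_rate=0,
                                       max_iter=1000))
```

Whichever rule fires first wins, which is why the pyAgrum runs report different stopping reasons depending on which limits were set.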

In [5]:

```python
ie2=gum.GibbsSampling(bn)
ie2.setEpsilon(1e-2)
gnb.showInference(bn,engine=ie2,size="8")
print(ie2.posterior("KINKEDTUBE"))
print(ie2.messageApproximationScheme())
compareInference(ie,ie2)
```


```
KINKEDTUBE       |
TRUE     |FALSE    |
---------|---------|
0.1156  | 0.8844  |

stopped with rate=0.00673795
```


With default parameters (only epsilon was set), this inference was stopped because the rate of change of epsilon became too low.

## Changing parameters

In [6]:

```python
ie2=gum.GibbsSampling(bn)
ie2.setMaxIter(1000)
ie2.setEpsilon(5e-3)
ie2.makeInference()

print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
```


```
INTUBATION                 |
NORMAL   |ESOPHAGEA|ONESIDED |
---------|---------|---------|
0.8264  | 0.1136  | 0.0600  |

stopped with max iteration=1000
```

In [7]:

```python
compareInference(ie,ie2)
```

In [8]:

```python
ie2=gum.GibbsSampling(bn)
ie2.setMaxTime(3)
ie2.makeInference()

print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
compareInference(ie,ie2)
```


```
INTUBATION                 |
NORMAL   |ESOPHAGEA|ONESIDED |
---------|---------|---------|
0.9700  | 0.0033  | 0.0267  |

stopped with epsilon=0.201897
```


## Looking at the convergence

In [9]:

```python
ie2=gum.GibbsSampling(bn)
ie2.setEpsilon(10**-1.8)
ie2.setBurnIn(300)
ie2.setPeriodSize(300)
ie2.setDrawnAtRandom(True)
gnb.animApproximationScheme(ie2)
ie2.makeInference()
```
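The `setBurnIn` and `setDrawnAtRandom` calls above can be illustrated with a small Gibbs sampler written from scratch. It runs on a hypothetical chain `Cloudy -> Rain -> Wet` with evidence `Wet = 1`. Burn-in discards the samples drawn before the chain has (hopefully) forgotten its random starting point. With `drawn_at_random=True`, one randomly chosen variable is resampled at each step (random scan); otherwise all variables are resampled in a fixed order (systematic scan). Reading `setDrawnAtRandom` as this random-scan/systematic-scan switch is our assumption, and the network and names are illustrative only.

```python
import random
from itertools import product

# Hypothetical chain Cloudy -> Rain -> Wet (binary), evidence Wet = 1.
P_CLOUDY = 0.5
P_RAIN_GIVEN_CLOUDY = {0: 0.2, 1: 0.8}
P_WET_GIVEN_RAIN = {0: 0.1, 1: 0.9}


def bernoulli(p, value):
    """Probability that a Bernoulli(p) variable takes `value` (0 or 1)."""
    return p if value == 1 else 1 - p


def joint(cloudy, rain, wet):
    """Joint probability P(Cloudy, Rain, Wet) from the chain factorization."""
    return (bernoulli(P_CLOUDY, cloudy)
            * bernoulli(P_RAIN_GIVEN_CLOUDY[cloudy], rain)
            * bernoulli(P_WET_GIVEN_RAIN[rain], wet))


def exact_rain_given_wet():
    """Exact P(Rain = 1 | Wet = 1) by enumeration (0.9 for these numbers)."""
    num = sum(joint(c, 1, 1) for c in (0, 1))
    den = sum(joint(c, r, 1) for c, r in product((0, 1), repeat=2))
    return num / den


def gibbs_rain_given_wet(n_samples, burn_in=300, drawn_at_random=True, seed=0):
    """Gibbs estimate of P(Rain = 1 | Wet = 1)."""
    rng = random.Random(seed)
    state = {"cloudy": rng.randint(0, 1), "rain": rng.randint(0, 1)}
    names = list(state)
    rain_count = 0
    for step in range(burn_in + n_samples):
        # Random scan: one variable chosen at random; systematic scan: all, in order.
        to_update = [rng.choice(names)] if drawn_at_random else names
        for name in to_update:
            # Full conditional of `name` given the other variables and the
            # evidence, obtained by normalizing the joint over its two values.
            w0 = joint(**{**state, name: 0}, wet=1)
            w1 = joint(**{**state, name: 1}, wet=1)
            state[name] = 1 if rng.random() < w1 / (w0 + w1) else 0
        if step >= burn_in:  # samples drawn during burn-in are discarded
            rain_count += state["rain"]
    return rain_count / n_samples


if __name__ == "__main__":
    print("exact:", exact_rain_given_wet())
    print("gibbs, random scan:", gibbs_rain_given_wet(50_000))
    print("gibbs, systematic scan:",
          gibbs_rain_given_wet(50_000, drawn_at_random=False))
```

Both scan orders target the same posterior. They differ only in how quickly successive samples decorrelate, which is what the convergence animation in pyAgrum lets you watch.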

In [10]:

```python
compareInference(ie,ie2)
```


### Importance Sampling

In [11]:

```python
ie4=gum.ImportanceSampling(bn)
ie4.setEpsilon(10**-1.8)
ie4.setMaxTime(10) # stop after at most 10 seconds of inference
ie4.setPeriodSize(300)
ie4.makeInference()
compareInference(ie,ie4)
```


### Loopy Gibbs Sampling

Every sampling inference has a ‘hybrid’ (loopy) version. It first runs loopy belief propagation and then uses its result as a prior for the probabilities estimated by sampling.
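One way to picture "LBP as a prior" is as Dirichlet pseudo-counts. The LBP marginal of a node, scaled by a prior weight, is added to the (possibly weighted) counts collected by sampling before normalizing. The sketch below assumes exactly this form. The helper name and `prior_weight` are hypothetical, and how aGrUM actually weights the LBP prior is not specified here.

```python
def blend_with_prior(prior_probs, counts, prior_weight):
    """Posterior mean of a Dirichlet(prior_weight * prior_probs) prior updated
    with the observed `counts` (one entry per state of the node)."""
    if len(prior_probs) != len(counts):
        raise ValueError("prior_probs and counts must have the same length")
    alphas = [prior_weight * p for p in prior_probs]  # pseudo-counts from LBP
    total = sum(alphas) + sum(counts)
    return [(a + c) / total for a, c in zip(alphas, counts)]


if __name__ == "__main__":
    lbp_marginal = [0.5, 0.5]  # hypothetical LBP result for a binary node
    print(blend_with_prior(lbp_marginal, [90, 10], prior_weight=100))
    print(blend_with_prior(lbp_marginal, [9000, 1000], prior_weight=100))
```

With 100 samples, the prior pulls the estimate from 0.9 down to 0.7. With 10,000 samples, its influence has mostly faded (about 0.896). A good LBP starting point can therefore help early on without preventing convergence to what the samples say.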

In [12]:

```python
ie3=gum.LoopyGibbsSampling(bn)

ie3.setEpsilon(10**-1.8)
ie3.setMaxTime(10) # stop after at most 10 seconds of inference
ie3.setPeriodSize(300)
ie3.makeInference()
compareInference(ie,ie3)
```


### Comparison of approximate inferences

These computations may take a while.

In [13]:

```python
def compareAllInference(bn, evs={}, epsilon=10**-1.6, epsilonRate=1e-8, maxTime=20):
    ies = [gum.LazyPropagation(bn),
           gum.LoopyBeliefPropagation(bn),
           gum.GibbsSampling(bn),
           gum.LoopyGibbsSampling(bn),
           gum.WeightedSampling(bn),
           gum.LoopyWeightedSampling(bn),
           gum.ImportanceSampling(bn),
           gum.LoopyImportanceSampling(bn)]

    # burn-in and random drawing for the two Gibbs samplers (indices 2 and 3)
    for i in [2, 3]:
        ies[i].setBurnIn(300)
        ies[i].setDrawnAtRandom(True)

    # stopping criteria for every sampling algorithm (indices 2 and up)
    for i in range(2, len(ies)):
        ies[i].setEpsilon(epsilon)
        ies[i].setMinEpsilonRate(epsilonRate)
        ies[i].setPeriodSize(300)
        ies[i].setMaxTime(maxTime)

    for i in range(len(ies)):
        ies[i].setEvidence(evs)
        ies[i].makeInference()

    # one panel per approximate inference, each compared with the exact one (ies[0])
    fig, axes = plt.subplots(1, len(ies)-1, figsize=(35, 3), num='gpplot')
    for i in range(len(ies)-1):
        compareInference(ies[0], ies[i+1], axes[i])
```


## Inference stopped by epsilon

In [14]:

```python
compareAllInference(bn,epsilon=1e-1)
```

In [15]:

```python
compareAllInference(bn,epsilon=1e-2)
```


## Inference stopped by time

In [16]:

```python
compareAllInference(bn,maxTime=1,epsilon=1e-8)
```

In [17]:

```python
compareAllInference(bn,maxTime=2,epsilon=1e-8)
```


## Inference with Evidence

In [18]:

```python
funny={'BP':1,'PCWP':2,'EXPCO2':0,'HISTORY':0}
compareAllInference(bn,maxTime=1,evs=funny,epsilon=1e-8)
```

In [19]:

```python
compareAllInference(bn,maxTime=4,evs=funny,epsilon=1e-8)
```

In [20]:

```python
compareAllInference(bn,maxTime=10,evs=funny,epsilon=1e-8)
```
