Approximate inference in aGrUM (pyAgrum)

Creative Commons License

There are several approximate inference algorithms for BNs in aGrUM (pyAgrum). They share the same API as exact inference.

- Loopy Belief Propagation: LBP is an approximate inference that applies the message-passing computations that are exact when the BN is a tree, even if the BN is not a tree. LBP is a special case of inference: the algorithm may not converge, and even if it converges, it may converge to something other than the exact posterior. LBP is however fast and usually gives reasonably good results.
- Sampling inference: sampling inference uses sampling to compute the posterior. The sampling may be (very) slow, but these algorithms converge to the exact distribution as the number of samples grows. aGrUM implements:
  - Monte Carlo sampling,
  - Weighted sampling,
  - Importance sampling,
  - Gibbs sampling.
- Finally, aGrUM proposes the so-called ‘loopy version’ of the sampling algorithms: the idea is to use LBP as a Dirichlet prior for the sampling algorithm. A loopy version of each sampling algorithm is provided.
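As a minimal sketch of the sampling idea (not aGrUM's implementation, which is written in C++), the following pure-Python code applies Monte Carlo forward sampling, without evidence, to a hypothetical two-node BN `Rain -> WetGrass` with made-up CPTs. All names here (`P_RAIN`, `monte_carlo_p_wet`, ...) are illustrative assumptions, not pyAgrum API. The estimate gets closer to the exact value as the number of samples grows.

```python
import random

# Hypothetical two-node BN Rain -> WetGrass (binary, states 0/1); made-up CPTs.
P_RAIN = [0.8, 0.2]                 # P(Rain)
P_WET_GIVEN_RAIN = [[0.9, 0.1],     # P(WetGrass | Rain = 0)
                    [0.2, 0.8]]     # P(WetGrass | Rain = 1)


def draw(dist, rng):
    """Draw a state index from a discrete distribution (list of probabilities)."""
    return rng.choices(range(len(dist)), weights=dist)[0]


def exact_p_wet():
    """Exact P(WetGrass = 1), obtained by summing Rain out."""
    return sum(P_RAIN[r] * P_WET_GIVEN_RAIN[r][1] for r in (0, 1))


def monte_carlo_p_wet(n_samples, seed=0):
    """Estimate P(WetGrass = 1) by forward (ancestral) sampling."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(n_samples):
        rain = draw(P_RAIN, rng)                 # parents are sampled first...
        wet = draw(P_WET_GIVEN_RAIN[rain], rng)  # ...then children given their parents
        hits += wet
    return hits / n_samples


if __name__ == "__main__":
    print("exact:", exact_p_wet())
    for n in (100, 10_000, 100_000):
        print(f"{n:>7} samples:", monte_carlo_p_wet(n))
```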

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

def unsharpen(bn):
    """
    Push the parameters of the BN a bit farther from 0 and 1
    (every CPT entry is shifted by maxParam/10, then each CPT is renormalized).
    """
    for nod in bn.nodes():
        bn.cpt(nod).translate(bn.maxParam() / 10).normalizeAsCPT()

def compareInference(ie, ie2, ax=None):
    """
    Compare 2 inferences by plotting all the points (posterior(ie), posterior(ie2)).
    Note: iterates over the nodes of the global `bn`.
    """
    exact = []
    appro = []
    errmax = 0.0  # float, so that the '{:2.4}' format below always works
    for node in bn.nodes():
        # posteriors as flat lists of probabilities
        exact += ie.posterior(node).tolist()
        appro += ie2.posterior(node).tolist()
        errmax = max(errmax,
                     (ie.posterior(node) - ie2.posterior(node)).abs().max())

    if errmax < 1e-10:
        errmax = 0.0
    if ax is None:
        ax = plt.gca()  # default axis for plt

    ax.plot(exact, appro, 'ro')
    ax.set_title("{} vs {}\n {}\nMax error {:2.4} in {:2.4} seconds".format(
        str(type(ie)).split(".")[2].split("_")[0][0:-2],   # name of first inference
        str(type(ie2)).split(".")[2].split("_")[0][0:-2],  # name of second inference
        ie2.messageApproximationScheme(),
        errmax,
        ie2.currentTime())
    )
In [2]:
import pyAgrum as gum
import pyAgrum.lib.notebook as gnb
bn=gum.loadBN("res/alarm.dsl")
unsharpen(bn)

ie=gum.LazyPropagation(bn)
ie.makeInference()
In [3]:
gnb.showBN(bn,size='8')

../_images/notebooks_44-Inference_ApproximateInference_5_0.svg

First, an exact inference.

In [4]:
gnb.sideBySide(gnb.getJunctionTreeMap(bn),gnb.getInference(bn,size="8")) # using LazyPropagation by default
print(ie.posterior("KINKEDTUBE"))

[Figure: junction tree of the alarm BN (left) and the BN annotated with the posteriors computed by LazyPropagation, inference in 4.60ms (right)]

  KINKEDTUBE       |
TRUE     |FALSE    |
---------|---------|
 0.1167  | 0.8833  |

Gibbs Inference

Gibbs inference with default parameters

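Gibbs inference iterations can be stopped:

- by the value of the error (epsilon, set with setEpsilon),
- by the rate of change of epsilon (MinEpsilonRate, set with setMinEpsilonRate),
- by the number of iterations (MaxIteration, set with setMaxIter),
- by the duration of the algorithm (MaxTime, set with setMaxTime, in seconds).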
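The four criteria above can be illustrated with a generic loop. This is a hypothetical pure-Python sketch, not aGrUM's code: it assumes that epsilon is the absolute change of a scalar estimate between two checks, that the rate is the relative change of epsilon between two checks, and that convergence is only tested every `period_size` iterations (our reading of `setPeriodSize`). The function and parameter names are made up for the illustration.

```python
import time


def run_approximation(step, epsilon=1e-2, min_epsilon_rate=1e-5,
                      max_iter=10_000, max_time=5.0, period_size=100):
    """Hypothetical anytime loop with four stopping rules.

    step() performs one iteration and returns the current scalar estimate.
    Every `period_size` iterations, epsilon is taken as the absolute change
    of the estimate since the previous check (an assumption, not aGrUM's
    definition), and the rate as the relative change of epsilon.
    Returns (reason, iterations_done, last_estimate).
    """
    start = time.perf_counter()
    estimate = None
    last_estimate = None  # estimate at the previous check
    last_eps = None       # epsilon at the previous check
    for it in range(1, max_iter + 1):
        estimate = step()
        if time.perf_counter() - start > max_time:
            return "time", it, estimate
        if it % period_size != 0:
            continue  # convergence is only tested once per period
        if last_estimate is not None:
            eps = abs(estimate - last_estimate)
            if eps < epsilon:
                return "epsilon", it, estimate
            if last_eps is not None and eps > 0:
                rate = abs(eps - last_eps) / eps
                if rate < min_epsilon_rate:
                    return "rate", it, estimate
            last_eps = eps
        last_estimate = estimate
    return "max_iter", max_iter, estimate


def make_counter_step(f):
    """Build a step() that returns f(k) at its k-th call (k = 1, 2, ...)."""
    k = 0

    def step():
        nonlocal k
        k += 1
        return f(k)

    return step


if __name__ == "__main__":
    print(run_approximation(make_counter_step(lambda k: 1 / k),
                            epsilon=1e-3, period_size=10))
```

Returning the stopping reason makes it possible to report why the loop ended, in the spirit of what `messageApproximationScheme()` prints in the cells of this notebook.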

In [5]:
ie2=gum.GibbsSampling(bn)
ie2.setEpsilon(1e-2)
gnb.showInference(bn,engine=ie2,size="8")
print(ie2.posterior("KINKEDTUBE"))
print(ie2.messageApproximationScheme())
compareInference(ie,ie2)
../_images/notebooks_44-Inference_ApproximateInference_9_0.svg

  KINKEDTUBE       |
TRUE     |FALSE    |
---------|---------|
 0.1184  | 0.8816  |

stopped with rate=0.00673795
../_images/notebooks_44-Inference_ApproximateInference_9_2.svg

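Apart from epsilon (set to 1e-2 above), all parameters keep their default values; this inference was stopped by the MinEpsilonRate criterion, i.e. because the rate of change of epsilon became low enough.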

Changing parameters

In [6]:
ie2=gum.GibbsSampling(bn)
ie2.setMaxIter(1000)
ie2.setEpsilon(5e-3)
ie2.makeInference()

print(ie2.posterior(2))
print(ie2.messageApproximationScheme())

  INTUBATION                 |
NORMAL   |ESOPHAGEA|ONESIDED |
---------|---------|---------|
 0.8091  | 0.1027  | 0.0882  |

stopped with max iteration=1000
In [7]:
compareInference(ie,ie2)
../_images/notebooks_44-Inference_ApproximateInference_13_0.svg
In [8]:
ie2=gum.GibbsSampling(bn)
ie2.setMaxTime(3)
ie2.makeInference()

print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
compareInference(ie,ie2)

  INTUBATION                 |
NORMAL   |ESOPHAGEA|ONESIDED |
---------|---------|---------|
 0.5333  | 0.2500  | 0.2167  |

stopped with epsilon=0.201897
../_images/notebooks_44-Inference_ApproximateInference_14_1.svg

Looking at the convergence

In [9]:
ie2=gum.GibbsSampling(bn)
ie2.setEpsilon(10**-1.8)
ie2.setBurnIn(300)
ie2.setPeriodSize(300)
ie2.setDrawnAtRandom(True)
gnb.animApproximationScheme(ie2)
ie2.makeInference()
../_images/notebooks_44-Inference_ApproximateInference_16_0.svg
In [10]:
compareInference(ie,ie2)
../_images/notebooks_44-Inference_ApproximateInference_17_0.svg
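The cell above calls `setBurnIn(300)` and `setDrawnAtRandom(True)` without explaining them. Our understanding is that burn-in is the number of initial Gibbs transitions discarded before estimates are collected, and that drawing at random makes each transition resample a randomly chosen variable instead of sweeping the variables in a fixed order. The pure-Python sketch below illustrates both ideas on a hypothetical chain `A -> B -> C` with evidence `C = 1`; the CPTs and names are made up, and this is not aGrUM's implementation.

```python
import random

# Hypothetical chain BN A -> B -> C, all binary; evidence C = 1. Made-up CPTs.
P_A = [0.7, 0.3]
P_B_GIVEN_A = [[0.9, 0.1], [0.3, 0.7]]
P_C_GIVEN_B = [[0.8, 0.2], [0.1, 0.9]]
C_OBS = 1


def exact_p_a1():
    """P(A = 1 | C = 1) by brute-force enumeration, used as the reference."""
    joint = [P_A[a] * sum(P_B_GIVEN_A[a][b] * P_C_GIVEN_B[b][C_OBS] for b in (0, 1))
             for a in (0, 1)]
    return joint[1] / sum(joint)


def sample_binary(w0, w1, rng):
    """Draw 0/1 with probabilities proportional to the unnormalised weights w0, w1."""
    return 1 if rng.random() * (w0 + w1) < w1 else 0


def gibbs_p_a1(n_samples, burn_in=300, drawn_at_random=True, seed=0):
    """Estimate P(A = 1 | C = 1) with Gibbs sampling.

    burn_in: number of initial transitions discarded before counting.
    drawn_at_random: if True, each transition resamples one non-evidence
    variable chosen uniformly at random; otherwise it sweeps A then B.
    """
    rng = random.Random(seed)
    state = {"A": rng.randint(0, 1), "B": rng.randint(0, 1)}  # random start

    def resample(var):
        if var == "A":  # P(A | B) is proportional to P(A) P(B | A)
            w = [P_A[a] * P_B_GIVEN_A[a][state["B"]] for a in (0, 1)]
        else:           # P(B | A, C) is proportional to P(B | A) P(C | B)
            w = [P_B_GIVEN_A[state["A"]][b] * P_C_GIVEN_B[b][C_OBS] for b in (0, 1)]
        state[var] = sample_binary(w[0], w[1], rng)

    def transition():
        if drawn_at_random:
            resample(rng.choice(("A", "B")))
        else:
            resample("A")
            resample("B")

    for _ in range(burn_in):  # let the chain forget its random start
        transition()
    hits = 0
    for _ in range(n_samples):
        transition()
        hits += state["A"]
    return hits / n_samples


if __name__ == "__main__":
    print("exact:", exact_p_a1())
    print("gibbs:", gibbs_p_a1(100_000))
```

Successive Gibbs states are correlated, so a Gibbs estimate usually needs more iterations than independent samples to reach the same precision.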

Importance Sampling

In [11]:
ie4=gum.ImportanceSampling(bn)
ie4.setEpsilon(10**-1.8)
ie4.setMaxTime(10) #10 seconds for inference
ie4.setPeriodSize(300)
ie4.makeInference()
compareInference(ie,ie4)
../_images/notebooks_44-Inference_ApproximateInference_19_0.svg
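Importance sampling draws values from a proposal distribution and corrects for the mismatch with importance weights. The sketch below estimates P(Rain = 1 | WetGrass = 1) on a hypothetical two-node BN with a uniform proposal. The proposal is an assumption chosen for simplicity (the notebook does not say which proposal aGrUM's ImportanceSampling uses), and the weights are self-normalised; all names are illustrative.

```python
import random

# Hypothetical two-node BN Rain -> WetGrass (binary); evidence WetGrass = 1.
P_RAIN = [0.8, 0.2]
P_WET_GIVEN_RAIN = [[0.9, 0.1], [0.2, 0.8]]
WET_OBS = 1
# Hypothetical proposal for Rain: aGrUM's actual proposal is not described here.
Q_RAIN = [0.5, 0.5]


def exact_p_rain_given_wet():
    """Exact P(Rain = 1 | WetGrass = 1) by Bayes' rule."""
    joint = [P_RAIN[r] * P_WET_GIVEN_RAIN[r][WET_OBS] for r in (0, 1)]
    return joint[1] / sum(joint)


def importance_p_rain_given_wet(n_samples, seed=0):
    """Self-normalised importance sampling estimate of P(Rain = 1 | WetGrass = 1)."""
    rng = random.Random(seed)
    num = den = 0.0
    for _ in range(n_samples):
        rain = 0 if rng.random() < Q_RAIN[0] else 1  # sample from the proposal q
        # weight = unnormalised target P(rain, WetGrass = 1) / q(rain)
        w = P_RAIN[rain] * P_WET_GIVEN_RAIN[rain][WET_OBS] / Q_RAIN[rain]
        num += w * rain
        den += w
    return num / den


if __name__ == "__main__":
    print("exact     :", exact_p_rain_given_wet())
    print("importance:", importance_p_rain_given_wet(100_000))
```

The closer the proposal is to the true posterior, the less the weights vary and the faster the estimate stabilises.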

Loopy Gibbs Sampling

Every sampling inference has a ‘hybrid’ (loopy) version, which first runs loopy belief propagation and then uses its result as a prior for the probability estimations by sampling.
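Using LBP as a Dirichlet prior can be pictured as adding pseudo-counts proportional to the LBP posterior to the sample counts. The function below is a minimal sketch of that combination for one variable; `prior_weight` (the equivalent sample size of the prior) is a hypothetical parameter, and how aGrUM actually weights the LBP prior is not described here.

```python
def dirichlet_posterior_mean(counts, prior, prior_weight):
    """Combine sample counts with a prior distribution (Dirichlet pseudo-counts).

    counts       : number of samples that fell in each state of a variable.
    prior        : probability vector used as prior, e.g. an LBP posterior.
    prior_weight : hypothetical equivalent sample size of the prior (> 0).
    Returns the posterior mean of a Dirichlet(prior_weight * prior)
    updated with the counts.
    """
    if len(counts) != len(prior):
        raise ValueError("counts and prior must have the same length")
    total = sum(counts) + prior_weight
    return [(c + prior_weight * p) / total for c, p in zip(counts, prior)]


if __name__ == "__main__":
    lbp = [0.12, 0.88]  # hypothetical LBP posterior for a binary variable
    for counts in ([3, 7], [117, 883], [11670, 88330]):
        print(counts, dirichlet_posterior_mean(counts, lbp, prior_weight=50))
```

With few samples the estimate stays close to the LBP prior; as the counts grow, the sampled frequencies dominate and the influence of the prior vanishes.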

In [12]:
ie3=gum.LoopyGibbsSampling(bn)

ie3.setEpsilon(10**-1.8)
ie3.setMaxTime(10) #10 seconds for inference
ie3.setPeriodSize(300)
ie3.makeInference()
compareInference(ie,ie3)

../_images/notebooks_44-Inference_ApproximateInference_21_0.svg

Comparison of approximate inference algorithms

These computations may take a while.

In [13]:
def compareAllInference(bn, evs=None, epsilon=10**-1.6, epsilonRate=1e-8, maxTime=20):
    if evs is None:
        evs = {}  # no evidence by default (avoids a mutable default argument)
    ies = [gum.LazyPropagation(bn),  # exact inference, used as the reference
           gum.LoopyBeliefPropagation(bn),
           gum.GibbsSampling(bn),
           gum.LoopyGibbsSampling(bn),
           gum.WeightedSampling(bn),
           gum.LoopyWeightedSampling(bn),
           gum.ImportanceSampling(bn),
           gum.LoopyImportanceSampling(bn)]

    # burn in for Gibbs samplings
    for i in [2, 3]:
        ies[i].setBurnIn(300)
        ies[i].setDrawnAtRandom(True)

    # stopping criteria for the sampling-based inferences
    for i in range(2, len(ies)):
        ies[i].setEpsilon(epsilon)
        ies[i].setMinEpsilonRate(epsilonRate)
        ies[i].setPeriodSize(300)
        ies[i].setMaxTime(maxTime)

    for ie in ies:
        ie.setEvidence(evs)
        ie.makeInference()

    # one panel per approximate inference, each compared to the exact one
    fig, axes = plt.subplots(1, len(ies) - 1, figsize=(35, 3), num='gpplot')
    for i in range(len(ies) - 1):
        compareInference(ies[0], ies[i + 1], axes[i])

Inference stopped by epsilon

In [14]:
compareAllInference(bn,epsilon=1e-1)
../_images/notebooks_44-Inference_ApproximateInference_25_0.svg
In [15]:
compareAllInference(bn,epsilon=1e-2)
../_images/notebooks_44-Inference_ApproximateInference_26_0.svg

Inference stopped by time

In [16]:
compareAllInference(bn,maxTime=1,epsilon=1e-8)
../_images/notebooks_44-Inference_ApproximateInference_28_0.svg
In [17]:
compareAllInference(bn,maxTime=2,epsilon=1e-8)
../_images/notebooks_44-Inference_ApproximateInference_29_0.svg

Inference with Evidence

In [18]:
funny={'BP':1,'PCWP':2,'EXPCO2':0,'HISTORY':0}
compareAllInference(bn,maxTime=1,evs=funny,epsilon=1e-8)
../_images/notebooks_44-Inference_ApproximateInference_31_0.svg
In [19]:
compareAllInference(bn,maxTime=4,evs=funny,epsilon=1e-8)
../_images/notebooks_44-Inference_ApproximateInference_32_0.svg
In [20]:
compareAllInference(bn,maxTime=10,evs=funny,epsilon=1e-8)
../_images/notebooks_44-Inference_ApproximateInference_33_0.svg
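Weighted sampling (likelihood weighting) handles evidence by clamping the observed nodes instead of sampling them: non-evidence nodes are sampled from their CPTs in topological order, and each sample is weighted by the likelihood of the evidence. The pure-Python sketch below applies this to a hypothetical chain `A -> B -> C` with evidence `C = 1`; we assume, without checking aGrUM's sources, that `gum.WeightedSampling` follows this scheme, and all names and CPTs are made up. When the evidence is improbable, most of the total weight tends to concentrate on few samples, so such estimates usually need more samples to stabilise.

```python
import random

# Hypothetical chain BN A -> B -> C (binary); evidence C = 1. Made-up CPTs.
P_A = [0.7, 0.3]
P_B_GIVEN_A = [[0.9, 0.1], [0.3, 0.7]]
P_C_GIVEN_B = [[0.8, 0.2], [0.1, 0.9]]
C_OBS = 1


def draw(dist, rng):
    """Draw a state index from a discrete distribution (list of probabilities)."""
    return rng.choices(range(len(dist)), weights=dist)[0]


def exact_p_a1():
    """P(A = 1 | C = 1) by brute-force enumeration, used as the reference."""
    joint = [P_A[a] * sum(P_B_GIVEN_A[a][b] * P_C_GIVEN_B[b][C_OBS] for b in (0, 1))
             for a in (0, 1)]
    return joint[1] / sum(joint)


def likelihood_weighting_p_a1(n_samples, seed=0):
    """Estimate P(A = 1 | C = 1) by likelihood weighting."""
    rng = random.Random(seed)
    num = den = 0.0
    for _ in range(n_samples):
        a = draw(P_A, rng)             # non-evidence nodes: sampled from their CPTs
        b = draw(P_B_GIVEN_A[a], rng)
        w = P_C_GIVEN_B[b][C_OBS]      # evidence node: clamped, contributes its likelihood
        num += w * a
        den += w
    return num / den


if __name__ == "__main__":
    print("exact              :", exact_p_a1())
    print("likelihood weighted:", likelihood_weighting_p_a1(100_000))
```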