Approximate inference in aGrUM (pyAgrum)

Creative Commons License

There are several approximate inference algorithms for BNs in aGrUM (pyAgrum). They share the same API as exact inference.
- Loopy Belief Propagation: LBP applies the message-passing method that is exact when the BN is a tree, even when the BN is not a tree. LBP is a peculiar kind of inference: the algorithm may not converge, and even if it converges, it may converge to a distribution other than the exact posterior. However, LBP is fast and usually gives reasonably good results.
- Sampling inference: sampling inference uses sampling to compute the posterior. Sampling may be (very) slow, but these algorithms converge to the exact distribution. aGrUM implements:
  - Monte Carlo sampling,
  - Weighted sampling,
  - Importance sampling,
  - Gibbs sampling.
- Finally, aGrUM proposes the so-called ‘loopy version’ of the sampling algorithms: the idea is to use LBP as a Dirichlet prior for the sampling algorithm. A loopy version of each sampling algorithm is proposed.
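
To make the difference between sampling schemes concrete, here is a minimal pure-Python sketch (it does not use pyAgrum) on a hypothetical two-node BN A -> B with made-up CPTs. It computes P(A=1 | B=1) exactly with Bayes' rule, then estimates it with forward sampling that rejects samples inconsistent with the evidence (the textbook Monte Carlo scheme) and with likelihood weighting (the textbook weighted scheme). We assume, without checking aGrUM's sources, that aGrUM's Monte Carlo and weighted samplings follow these textbook definitions; all names below are illustrative.

```python
import random

# Hypothetical two-node BN A -> B with binary variables (CPTs made up for the example).
P_A = {1: 0.3, 0: 0.7}                     # P(A=a)
P_B_GIVEN_A = {1: {1: 0.9, 0: 0.1},        # P(B=b | A=1)
               0: {1: 0.2, 0: 0.8}}        # P(B=b | A=0)


def exact_posterior_a(b):
    """P(A=1 | B=b) computed by enumeration (Bayes' rule)."""
    joint = {a: P_A[a] * P_B_GIVEN_A[a][b] for a in (0, 1)}
    return joint[1] / (joint[0] + joint[1])


def draw(p_one, rng):
    """Draw a binary value that equals 1 with probability p_one."""
    return 1 if rng.random() < p_one else 0


def monte_carlo_posterior_a(b, n, rng):
    """Forward sampling; samples inconsistent with the evidence B=b are rejected."""
    kept = hits = 0
    for _ in range(n):
        a = draw(P_A[1], rng)
        if draw(P_B_GIVEN_A[a][1], rng) == b:
            kept += 1
            hits += a
    return hits / kept


def weighted_posterior_a(b, n, rng):
    """Likelihood weighting: sample A only, clamp B=b, weight by P(B=b | A=a)."""
    total = weighted_hits = 0.0
    for _ in range(n):
        a = draw(P_A[1], rng)
        weight = P_B_GIVEN_A[a][b]
        total += weight
        weighted_hits += weight * a
    return weighted_hits / total


rng = random.Random(0)          # fixed seed for reproducibility
exact = exact_posterior_a(1)
mc = monte_carlo_posterior_a(1, 100_000, rng)
lw = weighted_posterior_a(1, 100_000, rng)
print(f"exact={exact:.4f}  monte carlo={mc:.4f}  weighted={lw:.4f}")
```

With 100,000 samples both estimates typically land within about 0.01 of the exact value 0.27/0.41 ≈ 0.659. Rejection sampling throws away every sample where B=0, which is why weighting usually copes better when the evidence is unlikely.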

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

def unsharpen(bn):
    """
    Move the parameters of the BN a bit farther from 0 and 1
    (translate every CPT by a constant, then renormalize it)
    """
    for nod in bn.nodes():
        bn.cpt(nod).translate(bn.maxParam() / 10).normalizeAsCPT()

def compareInference(ie, ie2, ax=None):
    """
    Compare 2 inferences by plotting all the points (posterior(ie), posterior(ie2)).
    Relies on the global bn loaded in the next cell.
    """
    exact = []
    appro = []
    errmax = 0
    for node in bn.nodes():
        # potentials as lists
        exact += ie.posterior(node).tolist()
        appro += ie2.posterior(node).tolist()
        errmax = max(errmax,
                     (ie.posterior(node) - ie2.posterior(node)).abs().max())

    if errmax < 1e-10:
        errmax = 0
    if ax is None:
        ax = plt.gca()  # default axis for plt

    ax.plot(exact, appro, 'ro')
    ax.set_title("{} vs {}\n {}\nMax error {:2.4} in {:2.4} seconds".format(
        type(ie).__name__,   # name of first inference
        type(ie2).__name__,  # name of second inference
        ie2.messageApproximationScheme(),
        errmax,
        ie2.currentTime()))
In [2]:
import pyAgrum as gum
import pyAgrum.lib.notebook as gnb
bn=gum.loadBN("res/alarm.dsl")
unsharpen(bn)

ie=gum.LazyPropagation(bn)
ie.makeInference()
In [3]:
gnb.showBN(bn,size='8')

[Figure: the alarm BN]

First, an exact inference.

In [4]:
gnb.sideBySide(gnb.getJunctionTreeMap(bn),gnb.getInference(bn,size="8")) # using LazyPropagation by default
print(ie.posterior("KINKEDTUBE"))

[Figure: junction tree of the BN (left) and posteriors of all the nodes computed by LazyPropagation (right); inference in 5.67ms]

  KINKEDTUBE       |
TRUE     |FALSE    |
---------|---------|
 0.1167  | 0.8833  |

Gibbs Inference

Gibbs inference with default parameters

Gibbs inference iterations can be stopped:
- by the value of the error (epsilon),
- by the rate of change of epsilon (MinEpsilonRate),
- by the number of iterations (MaxIter),
- by the duration of the algorithm (MaxTime).
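
The four criteria can be pictured as a single check run at the end of each sampling period. The function below is a hypothetical simplification, not aGrUM's code: it assumes that epsilon is the error measured at the end of a period and that the rate is the relative change between the last two epsilons, |eps_t - eps_(t-1)| / eps_t; the order in which the criteria are tested is also an assumption. Where possible, the returned strings mimic the messages printed by messageApproximationScheme() in this notebook.

```python
def stopping_reason(epsilon_history, iteration, elapsed,
                    epsilon=1e-2, min_epsilon_rate=1e-5,
                    max_iter=10_000, max_time=10.0):
    """Return why a sampler would stop, or None to keep sampling.

    Hypothetical simplification (not aGrUM's code):
    - epsilon_history: errors measured at the end of each period so far,
    - iteration: number of samples drawn so far,
    - elapsed: seconds spent so far.
    """
    if epsilon_history:
        eps = epsilon_history[-1]
        if eps < epsilon:                      # error small enough
            return f"epsilon={eps}"
        if len(epsilon_history) >= 2 and eps > 0:
            # assumed definition: relative change between the last two epsilons
            rate = abs(eps - epsilon_history[-2]) / eps
            if rate < min_epsilon_rate:        # error no longer improving
                return f"rate={rate}"
    if iteration >= max_iter:                  # iteration budget exhausted
        return f"max iteration={iteration}"
    if elapsed >= max_time:                    # time budget exhausted
        return f"max time={elapsed}"
    return None


print(stopping_reason([0.5, 0.3], iteration=400, elapsed=0.5))     # -> None
print(stopping_reason([0.5, 0.3], iteration=10_000, elapsed=0.5))  # -> max iteration=10000
```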

In [5]:
ie2=gum.GibbsSampling(bn)
ie2.setEpsilon(1e-2)
gnb.showInference(bn,engine=ie2,size="8")
print(ie2.posterior("KINKEDTUBE"))
print(ie2.messageApproximationScheme())
compareInference(ie,ie2)
[Figure: posteriors of all the nodes computed by GibbsSampling]

  KINKEDTUBE       |
TRUE     |FALSE    |
---------|---------|
 0.1203  | 0.8797  |

stopped with rate=0.00673795
[Figure: LazyPropagation vs GibbsSampling posteriors]

With these parameters (only epsilon was changed from the defaults), the inference was stopped because the rate of change of epsilon became too low.

Changing parameters

In [6]:
ie2=gum.GibbsSampling(bn)
ie2.setMaxIter(1000)
ie2.setEpsilon(5e-3)
ie2.makeInference()

print(ie2.posterior(2))
print(ie2.messageApproximationScheme())

  INTUBATION                 |
NORMAL   |ESOPHAGEA|ONESIDED |
---------|---------|---------|
 0.8709  | 0.0745  | 0.0545  |

stopped with max iteration=1000
In [7]:
compareInference(ie,ie2)
[Figure: LazyPropagation vs GibbsSampling posteriors]
In [8]:
ie2=gum.GibbsSampling(bn)
ie2.setMaxTime(3)
ie2.makeInference()

print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
compareInference(ie,ie2)

  INTUBATION                 |
NORMAL   |ESOPHAGEA|ONESIDED |
---------|---------|---------|
 0.8200  | 0.0400  | 0.1400  |

stopped with epsilon=0.201897
[Figure: LazyPropagation vs GibbsSampling posteriors]

Looking at the convergence

In [9]:
ie2=gum.GibbsSampling(bn)
ie2.setEpsilon(10**-1.8)
ie2.setBurnIn(300)
ie2.setPeriodSize(300)
ie2.setDrawnAtRandom(True)
gnb.animApproximationScheme(ie2)
ie2.makeInference()
[Figure: animation of the convergence of GibbsSampling]
In [10]:
compareInference(ie,ie2)
[Figure: LazyPropagation vs GibbsSampling posteriors]
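
The cells above set a burn-in of 300 for Gibbs sampling. As a rough illustration of what burn-in means, the pure-Python sketch below runs a systematic-scan Gibbs sampler on a hypothetical chain A -> B -> C with evidence C=1: each step resamples A and then B from their distributions given their Markov blankets, and the first burn_in states are discarded before counting. It does not model setDrawnAtRandom or setPeriodSize, and the CPTs are made up for the example.

```python
import random

# Hypothetical binary chain A -> B -> C (CPTs made up for the example); evidence C=1.
P_A1 = 0.3                       # P(A=1)
P_B1_GIVEN_A = {1: 0.8, 0: 0.1}  # P(B=1 | A=a)
P_C1_GIVEN_B = {1: 0.9, 0: 0.2}  # P(C=1 | B=b)


def prob(value, p_one):
    """P(X=value) for a binary X with P(X=1) = p_one."""
    return p_one if value == 1 else 1.0 - p_one


def exact_posterior_a(c=1):
    """P(A=1 | C=c) by summing B out of the joint distribution."""
    joint = {a: prob(a, P_A1) * sum(prob(b, P_B1_GIVEN_A[a]) * prob(c, P_C1_GIVEN_B[b])
                                    for b in (0, 1))
             for a in (0, 1)}
    return joint[1] / (joint[0] + joint[1])


def draw_from(w1, w0, rng):
    """Draw 1 with probability w1 / (w1 + w0) (unnormalized weights)."""
    return 1 if rng.random() < w1 / (w1 + w0) else 0


def gibbs_posterior_a(c=1, n_samples=50_000, burn_in=300, seed=0):
    """Systematic-scan Gibbs estimate of P(A=1 | C=c), discarding burn_in states."""
    rng = random.Random(seed)
    a, b = rng.randint(0, 1), rng.randint(0, 1)    # arbitrary initial state
    hits = 0
    for step in range(burn_in + n_samples):
        # A given its Markov blanket {B}: proportional to P(A) P(B=b | A)
        a = draw_from(prob(1, P_A1) * prob(b, P_B1_GIVEN_A[1]),
                      prob(0, P_A1) * prob(b, P_B1_GIVEN_A[0]), rng)
        # B given its Markov blanket {A, C}: proportional to P(B | A=a) P(C=c | B)
        b = draw_from(prob(1, P_B1_GIVEN_A[a]) * prob(c, P_C1_GIVEN_B[1]),
                      prob(0, P_B1_GIVEN_A[a]) * prob(c, P_C1_GIVEN_B[0]), rng)
        if step >= burn_in:                         # burn-in states are not counted
            hits += a
    return hits / n_samples


exact = exact_posterior_a()
approx = gibbs_posterior_a()
print(f"exact={exact:.4f}  gibbs={approx:.4f}")
```

Gibbs sampling relies on the chain being able to reach every state consistent with the evidence; strictly positive CPTs, as in this toy chain, guarantee that.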

Importance Sampling

In [11]:
ie4=gum.ImportanceSampling(bn)
ie4.setEpsilon(10**-1.8)
ie4.setMaxTime(10) #10 seconds for inference
ie4.setPeriodSize(300)
ie4.makeInference()
compareInference(ie,ie4)
[Figure: LazyPropagation vs ImportanceSampling posteriors]

Loopy Gibbs Sampling

Every sampling inference has a ‘hybrid’ version, which first runs loopy belief propagation and then uses its result as a prior for the probability estimations by sampling.
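
One way to read "using LBP as a Dirichlet prior" is sketched below: the LBP posterior of a node is turned into pseudo-counts, and the sampled counts are added to them. This is a hypothetical illustration; the function name and the prior_weight parameter (the total number of pseudo-counts given to the LBP result) are assumptions, not aGrUM's API, and the actual weighting used by the Loopy* samplers is not described in this notebook.

```python
def loopy_prior_estimate(lbp_posterior, counts, prior_weight):
    """Posterior estimate mixing an LBP result (as Dirichlet pseudo-counts) with samples.

    Hypothetical sketch, not aGrUM's API:
    - lbp_posterior: probabilities of each state, as returned by an LBP run,
    - counts: number of samples that fell into each state,
    - prior_weight: total number of pseudo-counts given to the LBP result (assumed).
    """
    total = prior_weight + sum(counts)
    return [(prior_weight * q + k) / total for q, k in zip(lbp_posterior, counts)]


lbp = [0.2, 0.8]                                     # made-up LBP posterior
print(loopy_prior_estimate(lbp, [0, 0], 100))        # no samples yet: the LBP result
print(loopy_prior_estimate(lbp, [30, 70], 100))      # halfway between prior and samples
print(loopy_prior_estimate(lbp, [3000, 7000], 100))  # samples dominate
```

With few samples the estimate stays close to the LBP result; as the counts grow, the sampled frequencies dominate, so the prior only speeds up the early phase and does not change the limit.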

In [12]:
ie3=gum.LoopyGibbsSampling(bn)

ie3.setEpsilon(10**-1.8)
ie3.setMaxTime(10) #10 seconds for inference
ie3.setPeriodSize(300)
ie3.makeInference()
compareInference(ie,ie3)

[Figure: LazyPropagation vs LoopyGibbsSampling posteriors]

Comparison of approximate inference

These computations may take a while.

In [13]:
def compareAllInference(bn,evs={},epsilon=10**-1.6,epsilonRate=1e-8,maxTime=20):
    ies=[gum.LazyPropagation(bn),
         gum.LoopyBeliefPropagation(bn),
         gum.GibbsSampling(bn),
         gum.LoopyGibbsSampling(bn),
         gum.WeightedSampling(bn),
         gum.LoopyWeightedSampling(bn),
         gum.ImportanceSampling(bn),
         gum.LoopyImportanceSampling(bn)]

    # burn in for Gibbs samplings
    for i in [2,3]:
        ies[i].setBurnIn(300)
        ies[i].setDrawnAtRandom(True)

    for i in range(2,len(ies)):
        ies[i].setEpsilon(epsilon)
        ies[i].setMinEpsilonRate(epsilonRate)
        ies[i].setPeriodSize(300)
        ies[i].setMaxTime(maxTime)

    for i in range(len(ies)):
        ies[i].setEvidence(evs)
        ies[i].makeInference()

    fig, axes = plt.subplots(1,len(ies)-1,figsize=(35, 3),num='gpplot')
    for i in range(len(ies)-1):
        compareInference(ies[0],ies[i+1],axes[i])

Inference stopped by epsilon

In [14]:
compareAllInference(bn,epsilon=1e-1)
[Figure: LazyPropagation vs each of the 7 other inference engines]
In [ ]:
compareAllInference(bn,epsilon=1e-2)

Inference stopped by time

In [ ]:
compareAllInference(bn,maxTime=1,epsilon=1e-8)
In [ ]:
compareAllInference(bn,maxTime=2,epsilon=1e-8)

Inference with Evidence

In [ ]:
funny={'BP':1,'PCWP':2,'EXPCO2':0,'HISTORY':0}
compareAllInference(bn,maxTime=1,evs=funny,epsilon=1e-8)
In [ ]:
compareAllInference(bn,maxTime=4,evs=funny,epsilon=1e-8)
In [ ]:
compareAllInference(bn,maxTime=10,evs=funny,epsilon=1e-8)