Bayesian Parameter Estimation for BNs
Posted on Fri 10 June 2016 in parameter_estimation
Now that ML Parameter Estimation works well, I’ve turned to Bayesian Parameter Estimation (all for discrete variables).
The Bayesian approach is, in practice, very similar to the ML case.
Both involve counting how often each state of a variable occurs in the data, conditional on the state of its parents.
I therefore factored out the `state_count` method and moved it to the `estimators.BaseEstimator` class, from which all estimators inherit.
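To make the counting step concrete, here is a minimal standalone sketch in pandas. `count_states` is a hypothetical helper, not pgmpy's `state_count`, and it assumes that a variable's possible states are exactly those that appear in the data. It uses the toy data from the example at the end of this post.

```python
import pandas as pd


def count_states(data: pd.DataFrame, variable: str, parents: list) -> pd.DataFrame:
    """Table of counts: rows = states of `variable`, columns = parent configurations."""
    # Assumption: the possible states are exactly those that appear in the data.
    var_states = sorted(data[variable].unique())
    if not parents:
        # No parents: a plain frequency count of each state.
        counts = data[variable].value_counts().reindex(var_states, fill_value=0)
        return counts.to_frame(name=variable)
    parent_states = [sorted(data[p].unique()) for p in parents]
    counts = data.groupby([variable] + parents).size()
    # groupby only reports observed combinations; add the unobserved ones with count 0.
    full_index = pd.MultiIndex.from_product([var_states] + parent_states,
                                            names=[variable] + parents)
    counts = counts.reindex(full_index, fill_value=0)
    # Move the parent levels into the columns: one column per parent configuration.
    return counts.unstack(parents)


data = pd.DataFrame(data={'A': [0, 0, 1], 'B': [0, 1, 0], 'C': [1, 1, 0]})
print(count_states(data, 'C', ['A', 'B']))
print(count_states(data, 'A', []))
```

Filling in unobserved parent configurations explicitly matters later: the configuration A=1, B=1 never occurs in this data, but it still needs a column in the CPD.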
MLE is essentially done once the state counts have been taken and normalized.
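A minimal pandas sketch of that normalization step for C given A and B, using the toy data from the example below (an illustration, not pgmpy's code). Note that `pd.crosstab` only produces columns for parent configurations that occur in the data, so A=1, B=1 gets no column: with zero counts, MLE would have to divide by zero there.

```python
import pandas as pd

data = pd.DataFrame(data={'A': [0, 0, 1], 'B': [0, 1, 0], 'C': [1, 1, 0]})

# Counts of C for each observed (A, B) configuration; rows are the states of C.
counts = pd.crosstab(data['C'], [data['A'], data['B']])

# MLE: divide each column by its total so every conditional distribution sums to 1.
mle_cpd = counts / counts.sum(axis=0)
print(mle_cpd)
```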
For Bayesian estimation (with Dirichlet priors), one additionally specifies so-called `pseudo_counts` for each variable state, which encode prior beliefs. The state counts are added to these virtual counts before normalization. My next PR will implement `estimators.BayesianEstimator` to compute the CPD parameters for a `BayesianModel`, given data.
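A minimal sketch of the Dirichlet update, assuming the counts and pseudo counts are given as arrays; `dirichlet_estimate` is a hypothetical helper, not pgmpy's API. The estimate (N_k + α_k) / (N + Σα) is the mean of the Dirichlet posterior. With A's counts [2, 1] from the toy data, pseudo counts [1, 1], [2, 2] and [5, 5] give A(0) = 3/5 = 0.6, 4/7 ≈ 0.571 and 7/13 ≈ 0.538.

```python
import numpy as np
import pandas as pd


def dirichlet_estimate(counts, pseudo_counts):
    """Posterior-mean estimate under a Dirichlet prior: add pseudo counts, then normalize."""
    counts = np.asarray(counts, dtype=float)
    pseudo_counts = np.asarray(pseudo_counts, dtype=float)
    total = counts + pseudo_counts
    return total / total.sum()


data = pd.DataFrame(data={'A': [0, 0, 1], 'B': [0, 1, 0], 'C': [1, 1, 0]})
counts_A = data['A'].value_counts().sort_index().to_numpy()  # counts of A(0), A(1)

for pseudo in ([1, 1], [2, 2], [5, 5]):
    print(pseudo, dirichlet_estimate(counts_A, pseudo).round(6))
```

Larger pseudo counts pull the estimate toward the uniform prior, which is why A(0) moves from 0.6 toward 0.5 as the pseudo counts grow.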
Update: #696 has been merged; it also adds basic support for missing data. Bayesian parameter estimation now works:
```
>>> import pandas as pd
>>> from pgmpy.models import BayesianModel
>>> from pgmpy.estimators import BayesianEstimator
>>> data = pd.DataFrame(data={'A': [0, 0, 1], 'B': [0, 1, 0], 'C': [1, 1, 0]})
>>> model = BayesianModel([('A', 'C'), ('B', 'C')])
>>> estimator = BayesianEstimator(model, data)
>>> print(estimator.estimate_cpd('A', prior_type="dirichlet", pseudo_counts=[1, 1]))
╒══════╤═════╕
│ A(0) │ 0.6 │
├──────┼─────┤
│ A(1) │ 0.4 │
╘══════╧═════╛
>>> print(estimator.estimate_cpd('A', prior_type="dirichlet", pseudo_counts=[2, 2]))
╒══════╤══════════╕
│ A(0) │ 0.571429 │
├──────┼──────────┤
│ A(1) │ 0.428571 │
╘══════╧══════════╛
>>> print(estimator.estimate_cpd('A', prior_type="dirichlet", pseudo_counts=[5, 5]))
╒══════╤══════════╕
│ A(0) │ 0.538462 │
├──────┼──────────┤
│ A(1) │ 0.461538 │
╘══════╧══════════╛
>>> for cpd in estimator.get_parameters(prior_type="BDeu", equivalent_sample_size=10):
...     print(cpd)
...
╒══════╤══════════╕
│ A(0) │ 0.538462 │
├──────┼──────────┤
│ A(1) │ 0.461538 │
╘══════╧══════════╛
╒══════╤═════════════════════╤═════════════════════╤═════════════════════╤══════╕
│ A    │ A(0)                │ A(0)                │ A(1)                │ A(1) │
├──────┼─────────────────────┼─────────────────────┼─────────────────────┼──────┤
│ B    │ B(0)                │ B(1)                │ B(0)                │ B(1) │
├──────┼─────────────────────┼─────────────────────┼─────────────────────┼──────┤
│ C(0) │ 0.35714285714285715 │ 0.35714285714285715 │ 0.6428571428571429  │ 0.5  │
├──────┼─────────────────────┼─────────────────────┼─────────────────────┼──────┤
│ C(1) │ 0.6428571428571429  │ 0.6428571428571429  │ 0.35714285714285715 │ 0.5  │
╘══════╧═════════════════════╧═════════════════════╧═════════════════════╧══════╛
╒══════╤══════════╕
│ B(0) │ 0.538462 │
├──────┼──────────┤
│ B(1) │ 0.461538 │
╘══════╧══════════╛
```
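The `BDeu` numbers in this session can be checked by hand. BDeu is a Dirichlet prior with the same pseudo count N'/(q·r) in every cell, where N' is `equivalent_sample_size`, q is the number of parent configurations and r is the number of states. For C that is 10/(4·2) = 1.25, so a configuration seen once gives (0 + 1.25)/(1 + 2.5) ≈ 0.357, and the never-observed configuration A(1), B(1) stays at 0.5. For A it is 10/2 = 5, the same as `pseudo_counts=[5, 5]`. The following is a standalone sketch, not pgmpy's implementation; `bdeu_cpd` is a hypothetical helper, and it assumes the states are exactly those present in the data.

```python
import itertools

import numpy as np
import pandas as pd


def bdeu_cpd(data, variable, parents, equivalent_sample_size):
    """Estimate P(variable | parents) under a BDeu prior, one column per parent configuration."""
    # Assumption: the possible states are exactly those that appear in the data.
    var_states = sorted(data[variable].unique())
    parent_states = [sorted(data[p].unique()) for p in parents]
    configs = list(itertools.product(*parent_states))  # [()] if there are no parents
    # BDeu: spread the equivalent sample size evenly over all (state, configuration) cells.
    alpha = equivalent_sample_size / (len(configs) * len(var_states))
    columns = []
    for config in configs:
        # Boolean mask selecting the rows that match this parent configuration.
        mask = np.ones(len(data), dtype=bool)
        for parent, state in zip(parents, config):
            mask &= (data[parent] == state).to_numpy()
        counts = data.loc[mask, variable].value_counts().reindex(var_states, fill_value=0)
        # Dirichlet update: add the pseudo count to every cell, then normalize.
        smoothed = counts.to_numpy(dtype=float) + alpha
        columns.append(smoothed / smoothed.sum())
    col_index = (pd.MultiIndex.from_tuples(configs, names=parents) if parents
                 else pd.Index([variable]))
    return pd.DataFrame(np.column_stack(columns),
                        index=pd.Index(var_states, name=variable), columns=col_index)


data = pd.DataFrame(data={'A': [0, 0, 1], 'B': [0, 1, 0], 'C': [1, 1, 0]})
print(bdeu_cpd(data, 'C', ['A', 'B'], equivalent_sample_size=10))
print(bdeu_cpd(data, 'A', [], equivalent_sample_size=10))
```

Because every cell receives the same pseudo count, a parent configuration with no data simply falls back to the uniform distribution instead of being undefined, as it would be under MLE.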