diff --git a/python/bindings.cpp b/python/bindings.cpp index 673d22bf49719dbe970bd8bc9502072cc98d62b8..d868cc70cd8448dbd8afeeebb17e0c889cd4055c 100644 --- a/python/bindings.cpp +++ b/python/bindings.cpp @@ -39,6 +39,7 @@ PYBIND11_MODULE(pygin, m) { .def("mass", [](const Dist& dist) { return toString(dist.mass()); }) .def("normalize", &Dist::normalize) .def("marginal", [](const Dist& dist, const string& x) {return dist.marginal(x);}) + .def("marginal", [](const Dist& dist, const std::vector<string>& vars) {return dist.marginal(vars);}) .def("is_zero", &Dist::isZero) .def("is_trivial", &Dist::isTrivial) .def("update", &Dist::update) diff --git a/src/Dist.cpp b/src/Dist.cpp index e3dc358a7727ed8726c39a9ad6b5845500181902..92a5bd61ca163ff226b2e2f0c3a8892e0a1de41d 100644 --- a/src/Dist.cpp +++ b/src/Dist.cpp @@ -158,12 +158,14 @@ namespace prodigy { return Dist{res}; } - Dist Dist::marginal(std::initializer_list<std::string> vars) const { + Dist Dist::marginal(const std::vector<std::string>& vars) const { for (const auto &v : vars) { EXPECTS(!isKnownAsParam(v), "given symbol is considered a parameter"); } ex res{_gf}; + // go over all variables v we currently know for (const auto &v : _vars) { + // if v is not in the given list vars then project it away if (std::find(vars.begin(), vars.end(), v.first) == vars.end()) { res = res.subs(v.second == 1); } diff --git a/src/Dist.h b/src/Dist.h index 5460d0301b729ad9025b7459cfc0a07ab1864f2a..dce2a6fdb7c1cb30c7db395d8a9c0ebf0b3ce55b 100644 --- a/src/Dist.h +++ b/src/Dist.h @@ -144,7 +144,8 @@ namespace prodigy { * returns marginal distribution in given variable(s) */ Dist marginal(const std::string& x) const; - Dist marginal(std::initializer_list<std::string> vars) const; + Dist marginal(const std::vector<std::string>& vars) const; + Dist marginal(std::initializer_list<std::string> vars) const { return this->marginal(std::vector<std::string>{vars}); } /* * attempts to prove that *this is the zero distribution (mass 0)