diff --git a/sql/provsql.sql b/sql/provsql.sql index acaaaf0..8c0b2e8 100644 --- a/sql/provsql.sql +++ b/sql/provsql.sql @@ -628,6 +628,12 @@ CREATE OR REPLACE FUNCTION probability_evaluate( RETURNS DOUBLE PRECISION AS 'provsql','probability_evaluate' LANGUAGE C; +CREATE OR REPLACE FUNCTION shapley( + token UUID, + variable UUID) + RETURNS DOUBLE PRECISION AS + 'provsql','shapley' LANGUAGE C; + CREATE OR REPLACE FUNCTION view_circuit( token UUID, token2desc regclass, diff --git a/src/BooleanCircuit.cpp b/src/BooleanCircuit.cpp index fd9d2cc..d9c770b 100644 --- a/src/BooleanCircuit.cpp +++ b/src/BooleanCircuit.cpp @@ -4,6 +4,7 @@ extern "C" { #include #include +#include "provsql_shmem.h" } #include @@ -42,7 +43,7 @@ gate_t BooleanCircuit::setGate(BooleanGate type) gate_t BooleanCircuit::setGate(const uuid &u, BooleanGate type) { - auto id = Circuit::setGate(u, type); + auto id = Circuit::setGate(u, type); if(type == BooleanGate::IN) { setProb(id,1.); inputs.insert(id); @@ -79,30 +80,30 @@ std::string BooleanCircuit::toString(gate_t g) const std::string result; switch(getGateType(g)) { - case BooleanGate::IN: - if(getProb(g)==0.) { - return "⊥"; - } else if(getProb(g)==1.) { - return "⊤"; - } else { - return to_string(g)+"["+std::to_string(getProb(g))+"]"; - } - case BooleanGate::MULIN: - return "{" + to_string(*getWires(g).begin()) + "=" + std::to_string(getInfo(g)) + "}[" + std::to_string(getProb(g)) + "]"; - case BooleanGate::NOT: - op="¬"; - break; - case BooleanGate::UNDETERMINED: - op="?"; - break; - case BooleanGate::AND: - op="∧"; - break; - case BooleanGate::OR: - op="∨"; - break; - case BooleanGate::MULVAR: - ; // already dealt with in MULIN + case BooleanGate::IN: + if(getProb(g)==0.) { + return "⊥"; + } else if(getProb(g)==1.) { + return "⊤"; + } else { + return to_string(g)+"["+std::to_string(getProb(g))+"]"; + } + case BooleanGate::MULIN: + return "{" + to_string(*getWires(g).begin()) + "=" + std::to_string(getInfo(g)) + "}[" + std::to_string(getProb(g)) + "]"; + case BooleanGate::NOT: + op="¬"; + break; + case BooleanGate::UNDETERMINED: + op="?"; + break; + case BooleanGate::AND: + op="∧"; + break; + case BooleanGate::OR: + op="∨"; + break; + case BooleanGate::MULVAR: + ; // already dealt with in MULIN } if(getWires(g).empty()) { @@ -129,21 +130,21 @@ bool BooleanCircuit::evaluate(gate_t g, const std::unordered_set &sample bool disjunction=false; switch(getGateType(g)) { - case BooleanGate::IN: - return sampled.find(g)!=sampled.end(); - case BooleanGate::MULIN: - case BooleanGate::MULVAR: - throw CircuitException("Monte-Carlo sampling not implemented on multivalued inputs"); - case BooleanGate::NOT: - return !evaluate(*(getWires(g).begin()), sampled); - case BooleanGate::AND: - disjunction = false; - break; - case BooleanGate::OR: - disjunction = true; - break; - case BooleanGate::UNDETERMINED: - throw CircuitException("Incorrect gate type"); + case BooleanGate::IN: + return sampled.find(g)!=sampled.end(); + case BooleanGate::MULIN: + case BooleanGate::MULVAR: + throw CircuitException("Monte-Carlo sampling not implemented on multivalued inputs"); + case BooleanGate::NOT: + return !evaluate(*(getWires(g).begin()), sampled); + case BooleanGate::AND: + disjunction = false; + break; + case BooleanGate::OR: + disjunction = true; + break; + case BooleanGate::UNDETERMINED: + throw CircuitException("Incorrect gate type"); } for(auto s: getWires(g)) { @@ -174,7 +175,7 @@ double BooleanCircuit::monteCarlo(gate_t g, unsigned samples) const if(evaluate(g, sampled)) ++success; - + if(provsql_interrupted) throw CircuitException("Interrupted after "+std::to_string(i+1)+" samples"); } @@ -183,7 +184,7 @@ double BooleanCircuit::monteCarlo(gate_t g, unsigned samples) const } double BooleanCircuit::possibleWorlds(gate_t g) const -{ +{ if(inputs.size()>=8*sizeof(unsigned long long)) throw CircuitException("Too many possible worlds to iterate over"); @@ -207,7 +208,7 @@ double BooleanCircuit::possibleWorlds(gate_t g) const if(evaluate(g, s)) totalp+=p; - + if(provsql_interrupted) throw CircuitException("Interrupted"); } @@ -216,50 +217,50 @@ double BooleanCircuit::possibleWorlds(gate_t g) const } std::string BooleanCircuit::Tseytin(gate_t g, bool display_prob=false) const { - std::vector> clauses; - + std::vector > clauses; + // Tseytin transformation for(gate_t i{0}; i(i)+1}; - std::vector c = {id}; - for(auto s: getWires(i)) { - clauses.push_back({-id, static_cast(s)+1}); - c.push_back(-static_cast(s)-1); - } - clauses.push_back(c); - break; - } + case BooleanGate::AND: + { + int id{static_cast(i)+1}; + std::vector c = {id}; + for(auto s: getWires(i)) { + clauses.push_back({-id, static_cast(s)+1}); + c.push_back(-static_cast(s)-1); + } + clauses.push_back(c); + break; + } - case BooleanGate::OR: - { - int id{static_cast(i)+1}; - std::vector c = {-id}; - for(auto s: getWires(i)) { - clauses.push_back({id, -static_cast(s)-1}); - c.push_back(static_cast(s)+1); - } - clauses.push_back(c); - } - break; + case BooleanGate::OR: + { + int id{static_cast(i)+1}; + std::vector c = {-id}; + for(auto s: getWires(i)) { + clauses.push_back({id, -static_cast(s)-1}); + c.push_back(static_cast(s)+1); + } + clauses.push_back(c); + } + break; - case BooleanGate::NOT: - { - int id=static_cast(i)+1; - auto s=*getWires(i).begin(); - clauses.push_back({-id,-static_cast(s)-1}); - clauses.push_back({id,static_cast(s)+1}); - break; - } + case BooleanGate::NOT: + { + int id=static_cast(i)+1; + auto s=*getWires(i).begin(); + clauses.push_back({-id,-static_cast(s)-1}); + clauses.push_back({id,static_cast(s)+1}); + break; + } - case BooleanGate::MULIN: - throw CircuitException("Multivalued inputs should have been removed by then."); - case BooleanGate::MULVAR: - case BooleanGate::IN: - case BooleanGate::UNDETERMINED: - ; + case BooleanGate::MULIN: + throw CircuitException("Multivalued inputs should have been removed by then."); + case BooleanGate::MULVAR: + case BooleanGate::IN: + case BooleanGate::UNDETERMINED: + ; } } clauses.push_back({(int)g+1}); @@ -274,7 +275,7 @@ std::string BooleanCircuit::Tseytin(gate_t g, bool display_prob=false) const { ofs << "p cnf " << gates.size() << " " << clauses.size() << "\n"; - for(unsigned i=0;i> nnf >> nb_nodes >> nb_edges >> nb_variables; - + if(nb_variables!=gates.size()) throw CircuitException("Unreadable d-DNNF (wrong number of variables: " + std::to_string(nb_variables) +" vs " + std::to_string(gates.size()) + ")"); - + getline(ifs,line); } @@ -363,7 +364,7 @@ double BooleanCircuit::compilation(gate_t g, std::string compiler) const { unsigned i=0; do { std::stringstream ss(line); - + std::string c; ss >> c; @@ -432,7 +433,7 @@ double BooleanCircuit::compilation(gate_t g, std::string compiler) const { dnnf.addWire(dnnf.getGate(c), and_gate); dnnf.addWire(and_gate, id2); for(auto leaf : decisions) { - gate_t leaf_gate; + gate_t leaf_gate; if(leaf<0) { leaf_gate = dnnf.setGate("i"+std::to_string(leaf), BooleanGate::IN, 1-prob[-leaf-1]); } else { @@ -463,14 +464,14 @@ double BooleanCircuit::WeightMC(gate_t g, std::string opt) const { std::string filename=BooleanCircuit::Tseytin(g, true); //opt of the form 'delta;epsilon' - std::stringstream ssopt(opt); + std::stringstream ssopt(opt); std::string delta_s, epsilon_s; getline(ssopt, delta_s, ';'); getline(ssopt, epsilon_s, ';'); double delta = 0; - try { - delta=stod(delta_s); + try { + delta=stod(delta_s); } catch (std::invalid_argument &e) { delta=0; } @@ -504,7 +505,7 @@ double BooleanCircuit::WeightMC(gate_t g, std::string opt) const { std::stringstream ss(prev_line); std::string result; ss >> result >> result >> result >> result >> result; - + std::istringstream iss(result); std::string val, exp; getline(iss, val, 'x'); @@ -526,74 +527,74 @@ double BooleanCircuit::WeightMC(gate_t g, std::string opt) const { } double BooleanCircuit::independentEvaluationInternal( - gate_t g, std::set &seen) const + gate_t g, std::set &seen) const { double result=1.; switch(getGateType(g)) { - case BooleanGate::AND: - for(const auto &c: getWires(g)) { - result*=independentEvaluationInternal(c, seen); - } - break; + case BooleanGate::AND: + for(const auto &c: getWires(g)) { + result*=independentEvaluationInternal(c, seen); + } + break; - case BooleanGate::OR: - { - // We collect probability among each group of children, where we - // group MULIN gates with the same key var together - std::map groups; - std::set local_mulins; - std::set> mulin_seen; - - for(const auto &c: getWires(g)) { - auto group = c; - if(getGateType(c) == BooleanGate::MULIN) { - group = *getWires(c).begin(); - if(local_mulins.find(g)==local_mulins.end()) { - if(seen.find(g)!=seen.end()) - throw CircuitException("Not an independent circuit"); - else - seen.insert(g); - } - auto p = std::make_pair(group, getInfo(c)); - if(mulin_seen.find(p)==mulin_seen.end()) { - groups[group] += getProb(c); - mulin_seen.insert(p); - } - } else - groups[group] = independentEvaluationInternal(c, seen); + case BooleanGate::OR: + { + // We collect probability among each group of children, where we + // group MULIN gates with the same key var together + std::map groups; + std::set local_mulins; + std::set > mulin_seen; + + for(const auto &c: getWires(g)) { + auto group = c; + if(getGateType(c) == BooleanGate::MULIN) { + group = *getWires(c).begin(); + if(local_mulins.find(g)==local_mulins.end()) { + if(seen.find(g)!=seen.end()) + throw CircuitException("Not an independent circuit"); + else + seen.insert(g); + } + auto p = std::make_pair(group, getInfo(c)); + if(mulin_seen.find(p)==mulin_seen.end()) { + groups[group] += getProb(c); + mulin_seen.insert(p); } + } else + groups[group] = independentEvaluationInternal(c, seen); + } - for(const auto [k, v]: groups) - result *= 1-v; - result = 1-result; - } - break; + for(const auto [k, v]: groups) + result *= 1-v; + result = 1-result; + } + break; - case BooleanGate::NOT: - result=1-independentEvaluationInternal(*getWires(g).begin(), seen); - break; + case BooleanGate::NOT: + result=1-independentEvaluationInternal(*getWires(g).begin(), seen); + break; - case BooleanGate::IN: - if(seen.find(g)!=seen.end()) - throw CircuitException("Not an independent circuit"); - seen.insert(g); - result=getProb(g); - break; - - case BooleanGate::MULIN: - { - auto child = *getWires(g).begin(); - if(seen.find(child)!=seen.end()) - throw CircuitException("Not an independent circuit"); - seen.insert(child); - result=getProb(g); - } - break; + case BooleanGate::IN: + if(seen.find(g)!=seen.end()) + throw CircuitException("Not an independent circuit"); + seen.insert(g); + result=getProb(g); + break; - case BooleanGate::UNDETERMINED: - case BooleanGate::MULVAR: - throw CircuitException("Bad gate"); + case BooleanGate::MULIN: + { + auto child = *getWires(g).begin(); + if(seen.find(child)!=seen.end()) + throw CircuitException("Not an independent circuit"); + seen.insert(child); + result=getProb(g); + } + break; + + case BooleanGate::UNDETERMINED: + case BooleanGate::MULVAR: + throw CircuitException("Bad gate"); } return result; @@ -621,11 +622,11 @@ unsigned BooleanCircuit::getInfo(gate_t g) const } void BooleanCircuit::rewriteMultivaluedGatesRec( - const std::vector &muls, - const std::vector &cumulated_probs, - unsigned start, - unsigned end, - std::vector &prefix) + const std::vector &muls, + const std::vector &cumulated_probs, + unsigned start, + unsigned end, + std::vector &prefix) { if(start==end) { getWires(muls[start]) = prefix; @@ -634,9 +635,9 @@ void BooleanCircuit::rewriteMultivaluedGatesRec( unsigned mid = (start+end)/2; auto g = setGate( - BooleanGate::IN, - (cumulated_probs[mid+1] - cumulated_probs[start]) / - (cumulated_probs[end] - cumulated_probs[start])); + BooleanGate::IN, + (cumulated_probs[mid+1] - cumulated_probs[start]) / + (cumulated_probs[end] - cumulated_probs[start])); auto not_g = setGate(BooleanGate::NOT); getWires(not_g).push_back(g); @@ -658,7 +659,7 @@ static constexpr bool almost_equals(double a, double b) void BooleanCircuit::rewriteMultivaluedGates() { - std::map> var2mulinput; + std::map > var2mulinput; for(auto mul: mulinputs) { var2mulinput[*getWires(mul).begin()].push_back(mul); } @@ -669,14 +670,14 @@ void BooleanCircuit::rewriteMultivaluedGates() const unsigned n = muls.size(); std::vector cumulated_probs(n); double cumulated_prob=0.; - + for(unsigned i=0; i::type>(muls[i])] = BooleanGate::AND; getWires(muls[i]).clear(); } - + std::vector prefix; prefix.reserve(static_cast(log(n)/log(2)+2)); if(!almost_equals(cumulated_probs[n-1],1.)) { @@ -685,3 +686,96 @@ void BooleanCircuit::rewriteMultivaluedGates() rewriteMultivaluedGatesRec(muls, cumulated_probs, 0, n-1, prefix); } } + +BooleanCircuit::BooleanCircuit(pg_uuid_t token) +{ + std::set to_process, processed; + to_process.insert(token); + + BooleanCircuit c; + + LWLockAcquire(provsql_shared_state->lock, LW_SHARED); + while(!to_process.empty()) { + pg_uuid_t uuid = *to_process.begin(); + to_process.erase(to_process.begin()); + processed.insert(uuid); + std::string f{uuid2string(uuid)}; + + bool found; + provsqlHashEntry *entry = reinterpret_cast(hash_search(provsql_hash, &uuid, HASH_FIND, &found)); + + if(!found) + c.setGate(f, BooleanGate::MULVAR); + else { + gate_t id; + + switch(entry->type) { + case gate_input: + if(isnan(entry->prob)) { + LWLockRelease(provsql_shared_state->lock); + elog(ERROR, "Missing probability for input token"); + } + id = c.setGate(f, BooleanGate::IN, entry->prob); + break; + + case gate_mulinput: + if(isnan(entry->prob)) { + LWLockRelease(provsql_shared_state->lock); + elog(ERROR, "Missing probability for input token"); + } + id = c.setGate(f, BooleanGate::MULIN, entry->prob); + c.addWire( + id, + c.getGate(uuid2string(provsql_shared_state->wires[entry->children_idx]))); + c.setInfo(id, entry->info1); + break; + + case gate_times: + case gate_project: + case gate_eq: + case gate_monus: + case gate_one: + id = c.setGate(f, BooleanGate::AND); + break; + + case gate_plus: + case gate_zero: + id = c.setGate(f, BooleanGate::OR); + break; + + default: + elog(ERROR, "Wrong type of gate in circuit"); + } + + if(entry->nb_children > 0) { + if(entry->type == gate_monus) { + auto id_not = c.setGate(BooleanGate::NOT); + auto child1 = provsql_shared_state->wires[entry->children_idx]; + auto child2 = provsql_shared_state->wires[entry->children_idx+1]; + c.addWire( + id, + c.getGate(uuid2string(child1))); + c.addWire(id, id_not); + c.addWire( + id_not, + c.getGate(uuid2string(child2))); + if(processed.find(child1)==processed.end()) + to_process.insert(child1); + if(processed.find(child2)==processed.end()) + to_process.insert(child2); + } else { + for(unsigned i=0; inb_children; ++i) { + auto child = provsql_shared_state->wires[entry->children_idx+i]; + + c.addWire( + id, + c.getGate(uuid2string(child))); + if(processed.find(child)==processed.end()) + to_process.insert(child); + } + } + } + } + } + LWLockRelease(provsql_shared_state->lock); +} diff --git a/src/BooleanCircuit.h b/src/BooleanCircuit.h index f7a9fc5..26da2ea 100644 --- a/src/BooleanCircuit.h +++ b/src/BooleanCircuit.h @@ -8,48 +8,56 @@ #include #include "Circuit.hpp" +#include "provsql_utils_cpp.h" enum class BooleanGate { UNDETERMINED, AND, OR, NOT, IN, MULIN, MULVAR }; class BooleanCircuit : public Circuit { - private: - bool evaluate(gate_t g, const std::unordered_set &sampled) const; - std::string Tseytin(gate_t g, bool display_prob) const; - double independentEvaluationInternal(gate_t g, std::set &seen) const; - void rewriteMultivaluedGatesRec( - const std::vector &muls, - const std::vector &cumulated_probs, - unsigned start, - unsigned end, - std::vector &prefix); - - protected: - std::set inputs; - std::set mulinputs; - std::vector prob; - std::map info; - - public: - gate_t addGate() override; - gate_t setGate(BooleanGate t) override; - gate_t setGate(const uuid &u, BooleanGate t) override; - gate_t setGate(BooleanGate t, double p); - gate_t setGate(const uuid &u, BooleanGate t, double p); - void setProb(gate_t g, double p) { prob[static_cast::type>(g)]=p; } - double getProb(gate_t g) const { return prob[static_cast::type>(g)]; } - void setInfo(gate_t g, unsigned info); - unsigned getInfo(gate_t g) const; - - double possibleWorlds(gate_t g) const; - double compilation(gate_t g, std::string compiler) const; - double monteCarlo(gate_t g, unsigned samples) const; - double WeightMC(gate_t g, std::string opt) const; - double independentEvaluation(gate_t g) const; - void rewriteMultivaluedGates(); - - virtual std::string toString(gate_t g) const override; - - friend class dDNNFTreeDecompositionBuilder; +private: +bool evaluate(gate_t g, const std::unordered_set &sampled) const; +std::string Tseytin(gate_t g, bool display_prob) const; +double independentEvaluationInternal(gate_t g, std::set &seen) const; +void rewriteMultivaluedGatesRec( + const std::vector &muls, + const std::vector &cumulated_probs, + unsigned start, + unsigned end, + std::vector &prefix); + +protected: +std::set inputs; +std::set mulinputs; +std::vector prob; +std::map info; + +public: +BooleanCircuit() { +} +explicit BooleanCircuit(pg_uuid_t token); +gate_t addGate() override; +gate_t setGate(BooleanGate t) override; +gate_t setGate(const uuid &u, BooleanGate t) override; +gate_t setGate(BooleanGate t, double p); +gate_t setGate(const uuid &u, BooleanGate t, double p); +void setProb(gate_t g, double p) { + prob[static_cast::type>(g)]=p; +} +double getProb(gate_t g) const { + return prob[static_cast::type>(g)]; +} +void setInfo(gate_t g, unsigned info); +unsigned getInfo(gate_t g) const; + +double possibleWorlds(gate_t g) const; +double compilation(gate_t g, std::string compiler) const; +double monteCarlo(gate_t g, unsigned samples) const; +double WeightMC(gate_t g, std::string opt) const; +double independentEvaluation(gate_t g) const; +void rewriteMultivaluedGates(); + +virtual std::string toString(gate_t g) const override; + +friend class dDNNFTreeDecompositionBuilder; }; #endif /* BOOLEAN_CIRCUIT_H */ diff --git a/src/probability_evaluate.cpp b/src/probability_evaluate.cpp index 4e06ba9..1fb3a11 100644 --- a/src/probability_evaluate.cpp +++ b/src/probability_evaluate.cpp @@ -6,8 +6,8 @@ extern "C" { #include "executor/spi.h" #include "provsql_shmem.h" #include "provsql_utils.h" - - PG_FUNCTION_INFO_V1(probability_evaluate); + +PG_FUNCTION_INFO_V1(probability_evaluate); } #include @@ -33,95 +33,7 @@ bool operator<(const pg_uuid_t a, const pg_uuid_t b) static Datum probability_evaluate_internal (pg_uuid_t token, const string &method, const string &args) { - std::set to_process, processed; - to_process.insert(token); - - BooleanCircuit c; - - LWLockAcquire(provsql_shared_state->lock, LW_SHARED); - while(!to_process.empty()) { - pg_uuid_t uuid = *to_process.begin(); - to_process.erase(to_process.begin()); - processed.insert(uuid); - std::string f{uuid2string(uuid)}; - - bool found; - provsqlHashEntry *entry = (provsqlHashEntry *) hash_search(provsql_hash, &uuid, HASH_FIND, &found); - - gate_t id; - - if(!found) - id = c.setGate(f, BooleanGate::MULVAR); - else { - switch(entry->type) { - case gate_input: - if(isnan(entry->prob)) { - LWLockRelease(provsql_shared_state->lock); - elog(ERROR, "Missing probability for input token"); - } - id = c.setGate(f, BooleanGate::IN, entry->prob); - break; - - case gate_mulinput: - if(isnan(entry->prob)) { - LWLockRelease(provsql_shared_state->lock); - elog(ERROR, "Missing probability for input token"); - } - id = c.setGate(f, BooleanGate::MULIN, entry->prob); - c.addWire( - id, - c.getGate(uuid2string(provsql_shared_state->wires[entry->children_idx]))); - c.setInfo(id, entry->info1); - break; - - case gate_times: - case gate_project: - case gate_eq: - case gate_monus: - case gate_one: - id = c.setGate(f, BooleanGate::AND); - break; - - case gate_plus: - case gate_zero: - id = c.setGate(f, BooleanGate::OR); - break; - - default: - elog(ERROR, "Wrong type of gate in circuit"); - } - - if(entry->nb_children > 0) { - if(entry->type == gate_monus) { - auto id_not = c.setGate(BooleanGate::NOT); - auto child1 = provsql_shared_state->wires[entry->children_idx]; - auto child2 = provsql_shared_state->wires[entry->children_idx+1]; - c.addWire( - id, - c.getGate(uuid2string(child1))); - c.addWire(id, id_not); - c.addWire( - id_not, - c.getGate(uuid2string(child2))); - if(processed.find(child1)==processed.end()) - to_process.insert(child1); - if(processed.find(child2)==processed.end()) - to_process.insert(child2); - } else { - for(unsigned i=0;inb_children;++i) { - auto child = provsql_shared_state->wires[entry->children_idx+i]; - - c.addWire( - id, - c.getGate(uuid2string(child))); - if(processed.find(child)==processed.end()) - to_process.insert(child); - } - } - } - } - } - LWLockRelease(provsql_shared_state->lock); + BooleanCircuit c(token); double result; auto gate = c.getGate(uuid2string(token)); @@ -144,8 +56,8 @@ static Datum probability_evaluate_internal // Default evaluation, use independent, tree-decomposition, and // compilation in order until one works try { - result = c.independentEvaluation(gate); - processed = true; + result = c.independentEvaluation(gate); + processed = true; } catch(CircuitException &) {} } @@ -159,12 +71,12 @@ static Datum probability_evaluate_internal try { samples = stoi(args); - } catch(std::invalid_argument &e) { + } catch(const std::invalid_argument &e) { } if(samples<=0) elog(ERROR, "Invalid number of samples: '%s'", args.c_str()); - + result = c.monteCarlo(gate, samples); } else if(method=="possible-worlds") { if(!args.empty()) @@ -211,7 +123,7 @@ static Datum probability_evaluate_internal provsql_interrupted = false; signal (SIGINT, prev_sigint_handler); - + // Avoid rounding errors that make probability outside of [0,1] if(result>1.) result=1.; @@ -227,7 +139,7 @@ Datum probability_evaluate(PG_FUNCTION_ARGS) Datum token = PG_GETARG_DATUM(0); string method; string args; - + if(PG_ARGISNULL(0)) PG_RETURN_NULL(); diff --git a/src/shapley.cpp b/src/shapley.cpp new file mode 100644 index 0000000..43743f3 --- /dev/null +++ b/src/shapley.cpp @@ -0,0 +1,83 @@ +extern "C" { +#include "postgres.h" +#include "fmgr.h" +#include "catalog/pg_type.h" +#include "utils/uuid.h" +#include "executor/spi.h" +#include "provsql_shmem.h" +#include "provsql_utils.h" + +PG_FUNCTION_INFO_V1(shapley); +} + +#include "BooleanCircuit.h" +#include "provsql_utils_cpp.h" +#include "dDNNFTreeDecompositionBuilder.h" + +using namespace std; + +static void provsql_sigint_handler (int) +{ + provsql_interrupted = true; +} + +static bool operator<(const pg_uuid_t a, const pg_uuid_t b) +{ + return memcmp(&a, &b, sizeof(pg_uuid_t))<0; +} + +static Datum shapley_internal + (pg_uuid_t token, pg_uuid_t variable) +{ + BooleanCircuit c(token); + auto gate = c.getGate(uuid2string(token)); + + provsql_interrupted = false; + void (*prev_sigint_handler)(int); + prev_sigint_handler = signal(SIGINT, provsql_sigint_handler); + + double result=0.; + + try { + TreeDecomposition td(c); + auto dnnf{ + dDNNFTreeDecompositionBuilder{ + c, uuid2string(token), td}.build() + }; + result = dnnf.shapley(dnnf.getGate("root"),"token"); + } catch(TreeDecompositionException &) { + provsql_interrupted = false; + signal (SIGINT, prev_sigint_handler); + elog(ERROR, "Treewidth greater than %u", TreeDecomposition::MAX_TREEWIDTH); + } + + provsql_interrupted = false; + signal (SIGINT, prev_sigint_handler); + + // Avoid rounding errors that make expected Shapley value outside of [-1,1] + if(result>1.) + result=1.; + else if(result<-1.) + result=-1.; + + PG_RETURN_FLOAT8(result); +} + +Datum shapley(PG_FUNCTION_ARGS) +{ + try { + Datum token = PG_GETARG_DATUM(0); + Datum variable = PG_GETARG_DATUM(1) + + if(PG_ARGISNULL(0) || PG_ARGISNULL(1)) + PG_RETURN_NULL(); + + return shapley_internal(*DatumGetUUIDP(token), *DatumGetUUIDP(variable)); + } catch(const std::exception &e) { + elog(ERROR, "shapley: %s", e.what()); + } catch(...) { + elog(ERROR, "shapley: Unknown exception"); + } + + PG_RETURN_NULL(); +}