Skip to content

Commit

Permalink
Shapley all vars function
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreSenellart committed Nov 10, 2023
1 parent 5e6cfcd commit 7cc2d33
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 17 deletions.
18 changes: 14 additions & 4 deletions sql/provsql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ CREATE OR REPLACE FUNCTION provenance_evaluate(
monus_function regproc = NULL,
delta_function regproc = NULL)
RETURNS anyelement AS
'provsql','provenance_evaluate' LANGUAGE C;
'provsql','provenance_evaluate' LANGUAGE C STABLE;

CREATE OR REPLACE FUNCTION aggregation_evaluate(
token UUID,
Expand All @@ -621,22 +621,32 @@ CREATE OR REPLACE FUNCTION aggregation_evaluate(
monus_function regproc = NULL,
delta_function regproc = NULL)
RETURNS anyelement AS
'provsql','aggregation_evaluate' LANGUAGE C;
'provsql','aggregation_evaluate' LANGUAGE C STABLE;

-- probability_evaluate(token, method, arguments)
-- SQL declaration binding to the C symbol 'probability_evaluate' of the
-- 'provsql' library. Takes a UUID token plus an optional method name and
-- optional method arguments (both default to NULL) and returns a
-- DOUBLE PRECISION value.
-- NOTE(review): the two LANGUAGE lines below appear to be the removed and
-- added sides of the diff; the newer one adds the STABLE volatility marker.
CREATE OR REPLACE FUNCTION probability_evaluate(
token UUID,
method text = NULL,
arguments text = NULL)
RETURNS DOUBLE PRECISION AS
'provsql','probability_evaluate' LANGUAGE C;
'provsql','probability_evaluate' LANGUAGE C STABLE;

-- shapley(token, variable, method, arguments)
-- SQL declaration binding to the C symbol 'shapley' of the 'provsql'
-- library. Takes a UUID token, the UUID of one variable, and an optional
-- method name and method arguments (both default to NULL); returns a
-- DOUBLE PRECISION value for that single variable.
-- NOTE(review): the two LANGUAGE lines below appear to be the removed and
-- added sides of the diff; the newer one adds the STABLE volatility marker.
CREATE OR REPLACE FUNCTION shapley(
token UUID,
variable UUID,
method text = NULL,
arguments text = NULL)
RETURNS DOUBLE PRECISION AS
'provsql','shapley' LANGUAGE C;
'provsql','shapley' LANGUAGE C STABLE;

-- shapley_all_vars(token, method, arguments)
-- Set-returning counterpart of shapley(): instead of taking one variable,
-- it returns a set of (variable, shapley) rows for the given token.
-- method and arguments are optional and default to NULL.
-- Bound to the C symbol 'shapley_all_vars' of the 'provsql' library.
CREATE OR REPLACE FUNCTION shapley_all_vars(
  token UUID,
  method text = NULL,
  arguments text = NULL,
  OUT variable UUID,
  OUT shapley DOUBLE PRECISION
)
RETURNS SETOF record
AS 'provsql', 'shapley_all_vars'
LANGUAGE C
STABLE;

CREATE OR REPLACE FUNCTION view_circuit(
token UUID,
Expand Down
3 changes: 3 additions & 0 deletions src/BooleanCircuit.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ gate_t setGate(BooleanGate t) override;
gate_t setGate(const uuid &u, BooleanGate t) override;
gate_t setGate(BooleanGate t, double p);
gate_t setGate(const uuid &u, BooleanGate t, double p);
/// Read-only access to the circuit's `inputs` set of gates (no copy is made).
const std::set<gate_t> &getInputs() const { return inputs; }
void setProb(gate_t g, double p) {
if(!probabilistic && p!=1.)
probabilistic=true;
Expand Down
6 changes: 6 additions & 0 deletions src/dDNNF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ double dDNNF::shapley(gate_t var) const {

result *= getProb(var);

// Avoid rounding errors that make expected Shapley value outside of [-1,1]
if(result>1.)
result=1.;
else if(result<-1.)
result=-1.;

return result;
}

Expand Down
71 changes: 58 additions & 13 deletions src/shapley.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ extern "C" {
#include "provsql_utils.h"

PG_FUNCTION_INFO_V1(shapley);
PG_FUNCTION_INFO_V1(shapley_all_vars);
}

#include <exception>

#include "BooleanCircuit.h"
Expand All @@ -33,21 +34,8 @@ static double shapley_internal

auto var_gate=dd.getGate(uuid2string(variable));

/*
std::string filename("/tmp/export.dd");
std::ofstream o(filename.c_str());
o << dd.exportCircuit(root);
o.close();
*/

double result = dd.shapley(var_gate);

// Avoid rounding errors that make expected Shapley value outside of [-1,1]
if(result>1.)
result=1.;
else if(result<-1.)
result=-1.;

return result;
}

Expand Down Expand Up @@ -81,3 +69,60 @@ Datum shapley(PG_FUNCTION_ARGS)

PG_RETURN_NULL();
}

/**
 * SQL-callable set-returning function: for the circuit rooted at a token,
 * emits one (variable UUID, value of dDNNF::shapley()) row per input gate
 * of that circuit.
 *
 * Arguments, as read below:
 *   0: token  - UUID of the root gate; a NULL token yields an empty result
 *   1: method - text forwarded to BooleanCircuit::makeDD(); empty if NULL
 *   2: args   - text forwarded to BooleanCircuit::makeDD(); empty if NULL
 *
 * Rows are delivered in materialize mode through a tuplestore attached to
 * the caller's ReturnSetInfo, so the Datum returned by the function itself
 * is a NULL placeholder.
 */
Datum shapley_all_vars(PG_FUNCTION_ARGS)
{
  ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;

  // Only materialize mode is implemented; refuse any calling context that
  // cannot accept it instead of dereferencing an unusable rsinfo.
  if(rsinfo == NULL || !(rsinfo->allowedModes & SFRM_Materialize))
    elog(ERROR, "shapley_all_vars: set-valued function called in context that cannot accept a set");

  // The tuplestore has to be created in the per-query memory context.
  MemoryContext per_query_ctx = rsinfo->econtext->ecxt_per_query_memory;
  MemoryContext oldcontext = MemoryContextSwitchTo(per_query_ctx);

  TupleDesc tupdesc = rsinfo->expectedDesc;
  Tuplestorestate *tupstore = tuplestore_begin_heap(
    (rsinfo->allowedModes & SFRM_Materialize_Random) != 0, false, work_mem);

  rsinfo->returnMode = SFRM_Materialize;
  rsinfo->setResult = tupstore;

  // Optional text argument -> std::string (empty when the argument is NULL).
  auto text_arg = [fcinfo](int argno) -> std::string {
    if(PG_ARGISNULL(argno))
      return std::string();
    text *t = PG_GETARG_TEXT_P(argno);
    return std::string(VARDATA(t), VARSIZE(t)-VARHDRSZ);
  };

  try {
    if(!PG_ARGISNULL(0)) {
      pg_uuid_t token = *DatumGetUUIDP(PG_GETARG_DATUM(0));

      const std::string method = text_arg(1);
      const std::string args = text_arg(2);

      BooleanCircuit c = createBooleanCircuit(token);

      // Compile once, then reuse the same structure for every variable.
      dDNNF dd = c.makeDD(c.getGate(uuid2string(token)), method, args);
      dd.makeSmooth();
      dd.makeGatesBinary(BooleanGate::AND);

      for(const auto &v_circuit_gate: c.getInputs()) {
        const auto var_uuid_string = c.getUUID(v_circuit_gate);
        const auto var_gate = dd.getGate(var_uuid_string);

        // A stack copy is enough: tuplestore_putvalues() copies the datum
        // into the stored tuple, so nothing is left behind in the per-query
        // context for each variable.
        pg_uuid_t var_uuid = string2uuid(var_uuid_string);

        const double result = dd.shapley(var_gate);

        Datum values[2] = {
          UUIDPGetDatum(&var_uuid), Float8GetDatum(result)
        };
        // One flag per column (sizeof(values) would count bytes, not columns).
        bool nulls[2] = {false, false};

        tuplestore_putvalues(tupstore, tupdesc, values, nulls);
      }
    }
  } catch(const std::exception &e) {
    // A C++ exception must not unwind into PostgreSQL's C frames; turn it
    // into a regular PostgreSQL error instead.
    elog(ERROR, "shapley_all_vars: %s", e.what());
  } catch(...) {
    elog(ERROR, "shapley_all_vars: Unknown exception");
  }

  tuplestore_donestoring(tupstore);
  MemoryContextSwitchTo(oldcontext);

  PG_RETURN_NULL();
}
16 changes: 16 additions & 0 deletions test/expected/shapley.out
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,19 @@
Nancy | Paris | 0.048
(7 rows)

remove_provenance
-------------------

(1 row)

city | round
----------+-------
Paris | 0.060
Paris | 0.140
Paris | 0.210
Berlin | 0.120
Berlin | 0.420
New York | 0.080
New York | 0.180
(7 rows)

18 changes: 18 additions & 0 deletions test/sql/shapley.sql
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,21 @@ DROP TABLE shapley_result;
DO $$ BEGIN
PERFORM set_prob(provenance(), id*1./10) FROM personnel;
END $$;

-- Regression test for shapley_all_vars().
-- Step 1: materialize, together with its provenance() value, the result of
-- "all distinct cities EXCEPT cities that have at least one pair of
-- personnel rows with p1.id < p2.id".
CREATE TABLE shapley_result1 AS
SELECT city, provenance() FROM (
(SELECT DISTINCT city FROM personnel)
EXCEPT
(SELECT p1.city
FROM personnel p1, personnel p2
WHERE p1.city = p2.city AND p1.id < p2.id
GROUP BY p1.city
ORDER BY p1.city)
) t;
-- NOTE(review): presumably stops provenance tracking on shapley_result1 so
-- that its stored provenance column is read as plain data below - confirm.
SELECT remove_provenance('shapley_result1');
-- Step 2: expand every stored provenance token into one row per
-- (variable, shapley) pair returned by shapley_all_vars().
CREATE TABLE shapley_result2 AS
SELECT * FROM shapley_result1, shapley_all_vars(provenance);

-- Step 3: print the values rounded to 3 decimals, then clean up.
-- NOTE(review): there is no ORDER BY, so the row order in the expected
-- output relies on the physical order of shapley_result2.
SELECT city, ROUND(shapley::numeric,3) FROM shapley_result2;
DROP TABLE shapley_result1;
DROP TABLE shapley_result2;

0 comments on commit 7cc2d33

Please sign in to comment.