diff --git a/poc/Makefile b/poc/Makefile index 35c52717..58a27886 100644 --- a/poc/Makefile +++ b/poc/Makefile @@ -21,6 +21,11 @@ setup: test: pyfiles sage test_oprf.sage +examples: pyfiles + sage example_oprf.sage + sage example_voprf.sage + sage example_poprf.sage + vectors: pyfiles @echo "Removing vectors folder, if present" @rm -rf vectors @@ -34,4 +39,4 @@ clean: .PHONY: distclean distclean: clean - rm -rf vectors ascii \ No newline at end of file + rm -rf vectors ascii diff --git a/poc/example_oprf.sage b/poc/example_oprf.sage new file mode 100644 index 00000000..fbc41533 --- /dev/null +++ b/poc/example_oprf.sage @@ -0,0 +1,53 @@ +#!/usr/bin/sage +# vim: syntax=python + +"""Exemplifies a run of the OPRF protocol""" + +import os +import sys + +try: + from sagelib.test_drng import TestDRNG + from sagelib.oprf \ + import DeriveKeyPair, SetupOPRFServer, SetupOPRFClient, MODE_OPRF, \ + oprf_ciphersuites, ciphersuite_p256_sha256 + +except ImportError as err: + sys.exit("Try running `make setup && make clean pyfiles`. Full error: " + err) + +to_hex = lambda x: "".join(["{:02x}".format(i) for i in x]) + +if __name__ == "__main__": + # Offline Setup + rng = TestDRNG('prng-seed'.encode('utf-8')) + suite = oprf_ciphersuites[ciphersuite_p256_sha256] + Ns = suite.group.scalar_byte_length() + info = b'info specific for this key' + seed = os.urandom(Ns) + skS, _ = DeriveKeyPair(MODE_OPRF, suite, seed, info) + + client = SetupOPRFClient(suite) + server = SetupOPRFServer(suite, skS) + + # Online Protocol + # + # Client Server(skS) + # ------------------------------------------------------------------- + # blind, blindedElement = Blind(input) + input = b'alice in wonderland' + blind, blinded_element = client.blind(input, rng) + # blindedElement + # ----------> + # + # evaluatedElement = BlindEvaluate(skS, blindedElement) + evaluated_element = server.blind_evaluate(blinded_element, rng) + # + # evaluatedElement + # <---------- + # + # output = Finalize(input, blind, evaluatedElement) + output = client.finalize(input, blind, evaluated_element) + print("mode:", "OPRF") + print("suite:", suite.name) + print("input:", to_hex(input)) + print("output:", to_hex(output)) diff --git a/poc/example_poprf.sage b/poc/example_poprf.sage new file mode 100644 index 00000000..a8f541ec --- /dev/null +++ b/poc/example_poprf.sage @@ -0,0 +1,55 @@ +#!/usr/bin/sage +# vim: syntax=python + +"""Exemplifies a run of the POPRF protocol""" + +import os +import sys + +try: + from sagelib.test_drng import TestDRNG + from sagelib.oprf \ + import DeriveKeyPair, SetupPOPRFServer, SetupPOPRFClient, MODE_POPRF, \ + oprf_ciphersuites, ciphersuite_p256_sha256 + +except ImportError as err: + sys.exit("Try running `make setup && make clean pyfiles`. Full error: " + err) + +to_hex = lambda x: "".join(["{:02x}".format(i) for i in x]) + +if __name__ == "__main__": + # Offline Setup + rng = TestDRNG('prng-seed'.encode('utf-8')) + suite = oprf_ciphersuites[ciphersuite_p256_sha256] + Ns = suite.group.scalar_byte_length() + info = b'info specific for this key' + seed = os.urandom(Ns) + skS, pkS = DeriveKeyPair(MODE_POPRF, suite, seed, info) + + client = SetupPOPRFClient(suite, pkS) + server = SetupPOPRFServer(suite, skS, pkS) + + # Online Protocol + # + # Client(pkS, info) <---- pkS ------ Server(skS, info) + # ------------------------------------------------------------------- + # blind, blindedElement, tweakedKey = Blind(input, info) + input = b'alice in wonderland' + blind, blinded_element, tweaked_key = client.blind(input, info, rng) + # + # blindedElement + # ----------> + # + # evaluatedElement, proof = BlindEvaluate(blindedElement, info) + evaluated_element, proof, _ = server.blind_evaluate(blinded_element, info, rng) + # + # evaluatedElement, proof + # <---------- + # + # output = Finalize(input, blind, evaluatedElement, + # blindedElement, proof, info, tweakedKey) + output = client.finalize(input, blind, evaluated_element, blinded_element, proof, info, tweaked_key) + print("mode:", "POPRF") + print("suite:", suite.name) + print("input:", to_hex(input)) + print("output:", to_hex(output)) diff --git a/poc/example_voprf.sage b/poc/example_voprf.sage new file mode 100644 index 00000000..e3cf56e5 --- /dev/null +++ b/poc/example_voprf.sage @@ -0,0 +1,52 @@ +#!/usr/bin/sage +# vim: syntax=python + +"""Exemplifies a run of the VOPRF protocol""" + +import os +import sys + +try: + from sagelib.test_drng import TestDRNG + from sagelib.oprf \ + import DeriveKeyPair, SetupVOPRFServer, SetupVOPRFClient, MODE_VOPRF, \ + oprf_ciphersuites, ciphersuite_p256_sha256 + +except ImportError as err: + sys.exit("Try running `make setup && make clean pyfiles`. Full error: " + err) + +to_hex = lambda x: "".join(["{:02x}".format(i) for i in x]) + +if __name__ == "__main__": + # Offline Setup + rng = TestDRNG('prng-seed'.encode('utf-8')) + suite = oprf_ciphersuites[ciphersuite_p256_sha256] + Ns = suite.group.scalar_byte_length() + info = b'info specific for this key' + seed = os.urandom(Ns) + skS, pkS = DeriveKeyPair(MODE_VOPRF, suite, seed, info) + + client = SetupVOPRFClient(suite, pkS) + server = SetupVOPRFServer(suite, skS, pkS) + + # Online Protocol + # + # Client(pkS) Server(skS,pkS) + # ------------------------------------------------------------------- + # blind, blindedElement = Blind(input) + input = b'alice in wonderland' + blind, blinded_element = client.blind(input, rng) + # blindedElement + # ----------> + # + # evaluatedElement, proof = BlindEvaluate(blindedElement) + evaluated_element, proof, _ = server.blind_evaluate(blinded_element, rng) + # <---------- + # + # output = Finalize(input, blind, evaluatedElement, + # blindedElement, proof) + output = client.finalize(input, blind, evaluated_element, blinded_element, proof) + print("mode:", "VOPRF") + print("suite:", suite.name) + print("input:", to_hex(input)) + print("output:", to_hex(output)) diff --git a/poc/oprf.sage b/poc/oprf.sage index 1c5c2bb4..fbd51ad3 100644 --- a/poc/oprf.sage +++ b/poc/oprf.sage @@ -41,56 +41,58 @@ class OPRFClientContext(Context): def identifier(self): return self.identifier - def blind(self, x, rng): + def blind(self, input, rng): blind = ZZ(self.suite.group.random_scalar(rng)) - input_element = self.suite.group.hash_to_group(x, self.group_domain_separation_tag()) + input_element = self.suite.group.hash_to_group(input, self.group_domain_separation_tag()) if input_element == self.suite.group.identity(): raise Exception("InvalidInputError") blinded_element = blind * input_element return blind, blinded_element - def unblind(self, blind, evaluated_element, blinded_element, proof): + def unblind(self, blind, evaluated_element): blind_inv = inverse_mod(blind, self.suite.group.order()) N = blind_inv * evaluated_element unblinded_element = self.suite.group.serialize(N) return unblinded_element - def finalize(self, x, blind, evaluated_element, blinded_element, proof, info): - unblinded_element = self.unblind(blind, evaluated_element, blinded_element, proof) - finalize_input = I2OSP(len(x), 2) + x \ + def finalize(self, input, blind, evaluated_element): + unblinded_element = self.unblind(blind, evaluated_element) + finalize_input = I2OSP(len(input), 2) + input \ + I2OSP(len(unblinded_element), 2) + unblinded_element \ + _as_bytes("Finalize") return self.suite.hash(finalize_input) class OPRFServerContext(Context): - def __init__(self, version, mode, suite, skS, pkS): + def __init__(self, version, mode, suite, skS): Context.__init__(self, version, mode, suite) self.skS = skS - self.pkS = pkS def internal_evaluate(self, blinded_element): evaluated_element = self.skS * blinded_element return evaluated_element - def blind_evaluate(self, blinded_element, info, rng): + def blind_evaluate(self, blinded_element, rng): evaluated_element = self.internal_evaluate(blinded_element) - return evaluated_element, None, None + return evaluated_element - def evaluate_without_proof(self, blinded_element, info): + def evaluate_without_proof(self, blinded_element): return self.internal_evaluate(blinded_element) - def evaluate(self, x, info): - input_element = self.suite.group.hash_to_group(x, self.group_domain_separation_tag()) + def evaluate(self, input, expected_output): + input_element = self.suite.group.hash_to_group(input, self.group_domain_separation_tag()) if input_element == self.suite.group.identity(): raise Exception("InvalidInputError") evaluated_element = self.internal_evaluate(input_element) issued_element = self.suite.group.serialize(evaluated_element) - finalize_input = I2OSP(len(x), 2) + x \ + + finalize_input = I2OSP(len(input), 2) + input \ + I2OSP(len(issued_element), 2) + issued_element \ + _as_bytes("Finalize") - return self.suite.hash(finalize_input) + digest = self.suite.hash(finalize_input) + + return (digest == expected_output) class Verifiable(object): def compute_composites_inner(self, k, B, Cs, Ds): @@ -138,12 +140,12 @@ class VOPRFClientContext(OPRFClientContext,Verifiable): self.pkS = pkS def verify_proof(self, A, B, Cs, Ds, proof): - a = self.compute_composites(B, Cs, Ds) + [M, Z] = self.compute_composites(B, Cs, Ds) + c = proof[0] + s = proof[1] - M = a[0] - Z = a[1] - t2 = (proof[1] * A) + (proof[0] * B) - t3 = (proof[1] * M) + (proof[0] * Z) + t2 = (s * A) + (c * B) + t3 = (s * M) + (c * Z) Bm = self.suite.group.serialize(B) a0 = self.suite.group.serialize(M) @@ -158,10 +160,10 @@ class VOPRFClientContext(OPRFClientContext,Verifiable): + I2OSP(len(a3), 2) + a3 \ + _as_bytes("Challenge") - c = self.suite.group.hash_to_scalar(h2s_input, self.scalar_domain_separation_tag()) + expectedC = self.suite.group.hash_to_scalar(h2s_input, self.scalar_domain_separation_tag()) - assert(c == proof[0]) - return c == proof[0] + assert(expectedC == c) + return expectedC == c def unblind(self, blind, evaluated_element, blinded_element, proof): G = self.suite.group.generator() @@ -190,15 +192,16 @@ class VOPRFClientContext(OPRFClientContext,Verifiable): return unblinded_elements - def finalize(self, x, blind, evaluated_element, blinded_element, proof, info): + def finalize(self, input, blind, evaluated_element, blinded_element, proof): unblinded_element = self.unblind(blind, evaluated_element, blinded_element, proof) - finalize_input = I2OSP(len(x), 2) + x \ + finalize_input = I2OSP(len(input), 2) + input \ + I2OSP(len(unblinded_element), 2) + unblinded_element \ + _as_bytes("Finalize") return self.suite.hash(finalize_input) - def finalize_batch(self, xs, blinds, evaluated_elements, blinded_elements, proof, info): + def finalize_batch(self, inputs, blinds, evaluated_elements, blinded_elements, proof): + assert(len(inputs) == len(blinds)) assert(len(blinds) == len(evaluated_elements)) assert(len(evaluated_elements) == len(blinded_elements)) @@ -206,25 +209,24 @@ class VOPRFClientContext(OPRFClientContext,Verifiable): outputs = [] for i, unblinded_element in enumerate(unblinded_elements): - finalize_input = I2OSP(len(xs[i]), 2) + xs[i] \ + finalize_input = I2OSP(len(inputs[i]), 2) + inputs[i] \ + I2OSP(len(unblinded_element), 2) + unblinded_element \ + _as_bytes("Finalize") - digest = self.suite.hash(finalize_input) - outputs.append(digest) + output = self.suite.hash(finalize_input) + outputs.append(output) return outputs class VOPRFServerContext(OPRFServerContext,Verifiable): def __init__(self, version, mode, suite, skS, pkS): - OPRFServerContext.__init__(self, version, mode, suite, skS, pkS) + OPRFServerContext.__init__(self, version, mode, suite, skS) + self.pkS = pkS def generate_proof(self, k, A, B, Cs, Ds, rng): - a = self.compute_composites_fast(k, B, Cs, Ds) + [M, Z] = self.compute_composites_fast(k, B, Cs, Ds) r = ZZ(self.suite.group.random_scalar(rng)) - M = a[0] - Z = a[1] t2 = r * A t3 = r * M @@ -250,15 +252,15 @@ class VOPRFServerContext(OPRFServerContext,Verifiable): evaluated_element = self.skS * blinded_element return evaluated_element - def blind_evaluate(self, blinded_element, info, rng): + def blind_evaluate(self, blinded_element, rng): evaluated_element = self.internal_evaluate(blinded_element) proof, r = self.generate_proof(self.skS, self.suite.group.generator(), self.pkS, [blinded_element], [evaluated_element], rng) return evaluated_element, proof, r - def evaluate_without_proof(self, blinded_element, info): + def evaluate_without_proof(self, blinded_element): return self.internal_evaluate(blinded_element) - def blind_evaluate_batch(self, blinded_elements, info, rng): + def blind_evaluate_batch(self, blinded_elements, rng): evaluated_elements = [] for blinded_element in blinded_elements: evaluated_element = self.skS * blinded_element @@ -272,16 +274,16 @@ class POPRFClientContext(VOPRFClientContext): VOPRFClientContext.__init__(self, version, mode, suite, pkS) self.pkS = pkS - def blind(self, x, info, rng): - context = _as_bytes("Info") + I2OSP(len(info), 2) + info - t = self.suite.group.hash_to_scalar(context, self.scalar_domain_separation_tag()) - G = self.suite.group.generator() - tweaked_key = (G * t) + self.pkS + def blind(self, input, info, rng): + framed_info = _as_bytes("Info") + I2OSP(len(info), 2) + info + m = self.suite.group.hash_to_scalar(framed_info, self.scalar_domain_separation_tag()) + T = m * self.suite.group.generator() + tweaked_key = T + self.pkS if tweaked_key == self.suite.group.identity(): raise Exception("InvalidInputError") blind = ZZ(self.suite.group.random_scalar(rng)) - input_element = self.suite.group.hash_to_group(x, self.group_domain_separation_tag()) + input_element = self.suite.group.hash_to_group(input, self.group_domain_separation_tag()) if input_element == self.suite.group.identity(): raise Exception("InvalidInputError") @@ -315,16 +317,17 @@ class POPRFClientContext(VOPRFClientContext): return unblinded_elements - def finalize(self, x, blind, evaluated_element, blinded_element, proof, info, tweaked_key): + def finalize(self, input, blind, evaluated_element, blinded_element, proof, info, tweaked_key): unblinded_element = self.unblind(blind, evaluated_element, blinded_element, proof, tweaked_key) - finalize_input = I2OSP(len(x), 2) + x \ + finalize_input = I2OSP(len(input), 2) + input \ + I2OSP(len(info), 2) + info \ + I2OSP(len(unblinded_element), 2) + unblinded_element \ + _as_bytes("Finalize") return self.suite.hash(finalize_input) - def finalize_batch(self, xs, blinds, evaluated_elements, blinded_elements, proof, info, tweaked_key): + def finalize_batch(self, inputs, blinds, evaluated_elements, blinded_elements, proof, info, tweaked_key): + assert(len(inputs) == len(blinds)) assert(len(blinds) == len(evaluated_elements)) assert(len(evaluated_elements) == len(blinded_elements)) @@ -332,13 +335,13 @@ class POPRFClientContext(VOPRFClientContext): outputs = [] for i, unblinded_element in enumerate(unblinded_elements): - finalize_input = I2OSP(len(xs[i]), 2) + xs[i] \ + finalize_input = I2OSP(len(inputs[i]), 2) + inputs[i] \ + I2OSP(len(info), 2) + info \ + I2OSP(len(unblinded_element), 2) + unblinded_element \ + _as_bytes("Finalize") - digest = self.suite.hash(finalize_input) - outputs.append(digest) + output = self.suite.hash(finalize_input) + outputs.append(output) return outputs @@ -347,21 +350,21 @@ class POPRFServerContext(VOPRFServerContext): VOPRFServerContext.__init__(self, version, mode, suite, skS, pkS) def internal_evaluate(self, blinded_element, info): - context = _as_bytes("Info") + I2OSP(len(info), 2) + info - t = self.suite.group.hash_to_scalar(context, self.scalar_domain_separation_tag()) - k = self.skS + t - if int(k) == 0: + framed_info = _as_bytes("Info") + I2OSP(len(info), 2) + info + m = self.suite.group.hash_to_scalar(framed_info, self.scalar_domain_separation_tag()) + t = (self.skS + m) % self.suite.group.order() + if int(t) == 0: raise Exception("InverseError") - k_inv = inverse_mod(k, self.suite.group.order()) - evaluated_element = k_inv * blinded_element + t_inv = inverse_mod(t, self.suite.group.order()) + evaluated_element = t_inv * blinded_element - return evaluated_element, k + return evaluated_element, t def blind_evaluate(self, blinded_element, info, rng): - evaluated_element, k = self.internal_evaluate(blinded_element, info) + evaluated_element, t = self.internal_evaluate(blinded_element, info) G = self.suite.group.generator() - U = k * G - proof, r = self.generate_proof(k, G, U, [evaluated_element], [blinded_element], rng) + tweaked_key = t * G + proof, r = self.generate_proof(t, G, tweaked_key, [evaluated_element], [blinded_element], rng) return evaluated_element, proof, r def evaluate_without_proof(self, blinded_element, info): @@ -369,34 +372,36 @@ class POPRFServerContext(VOPRFServerContext): return evaluated_element def blind_evaluate_batch(self, blinded_elements, info, rng): - context = _as_bytes("Info") + I2OSP(len(info), 2) + info - t = self.suite.group.hash_to_scalar(context, self.scalar_domain_separation_tag()) + framed_info = _as_bytes("Info") + I2OSP(len(info), 2) + info + m = self.suite.group.hash_to_scalar(framed_info, self.scalar_domain_separation_tag()) evaluated_elements = [] for blinded_element in blinded_elements: - k = self.skS + t - if int(k) == 0: + t = (self.skS + m) % self.suite.group.order() + if int(t) == 0: raise Exception("InverseError") - k_inv = inverse_mod(k, self.suite.group.order()) - evaluated_element = k_inv * blinded_element + t_inv = inverse_mod(t, self.suite.group.order()) + evaluated_element = t_inv * blinded_element evaluated_elements.append(evaluated_element) G = self.suite.group.generator() - U = k * G - proof, r = self.generate_proof(k, G, U, evaluated_elements, blinded_elements, rng) + tweaked_key = t * G + proof, r = self.generate_proof(t, G, tweaked_key, evaluated_elements, blinded_elements, rng) return evaluated_elements, proof, r - def evaluate(self, x, info): - input_element = self.suite.group.hash_to_group(x, self.group_domain_separation_tag()) + def evaluate(self, input, expected_output, info): + input_element = self.suite.group.hash_to_group(input, self.group_domain_separation_tag()) evaluated_element = self.evaluate_without_proof(input_element, info) issued_element = self.suite.group.serialize(evaluated_element) - finalize_input = I2OSP(len(x), 2) + x \ + finalize_input = I2OSP(len(input), 2) + input \ + I2OSP(len(info), 2) + info \ + I2OSP(len(issued_element), 2) + issued_element \ + _as_bytes("Finalize") - return self.suite.hash(finalize_input) + output = self.suite.hash(finalize_input) + + return (output == expected_output) MODE_OPRF = 0x00 MODE_VOPRF = 0x01 @@ -419,7 +424,7 @@ def DeriveKeyPair(mode, suite, seed, info): return skS, pkS def SetupOPRFServer(suite, skS): - return OPRFServerContext(VERSION, MODE_OPRF, suite, skS, None) + return OPRFServerContext(VERSION, MODE_OPRF, suite, skS) def SetupOPRFClient(suite): return OPRFClientContext(VERSION, MODE_OPRF, suite) @@ -446,7 +451,7 @@ ciphersuite_p521_sha512 = 0x0005 oprf_ciphersuites = { ciphersuite_ristretto255_sha512: Ciphersuite("OPRF(ristretto255, SHA-512)", ciphersuite_ristretto255_sha512, GroupRistretto255(), hashlib.sha512, lambda x : hashlib.sha512(x).digest()), - ciphersuite_decaf448_shake256: Ciphersuite("OPRF(decaf448, SHAKE256)", ciphersuite_decaf448_shake256, GroupDecaf448(), hashlib.shake_256, lambda x : hashlib.shake_256(x).digest(int(64))), + ciphersuite_decaf448_shake256: Ciphersuite("OPRF(decaf448, SHAKE-256)", ciphersuite_decaf448_shake256, GroupDecaf448(), hashlib.shake_256, lambda x : hashlib.shake_256(x).digest(int(64))), ciphersuite_p256_sha256: Ciphersuite("OPRF(P-256, SHA-256)", ciphersuite_p256_sha256, GroupP256(), hashlib.sha256, lambda x : hashlib.sha256(x).digest()), ciphersuite_p384_sha384: Ciphersuite("OPRF(P-384, SHA-384)", ciphersuite_p384_sha384, GroupP384(), hashlib.sha384, lambda x : hashlib.sha384(x).digest()), ciphersuite_p521_sha512: Ciphersuite("OPRF(P-521, SHA-512)", ciphersuite_p521_sha512, GroupP521(), hashlib.sha512, lambda x : hashlib.sha512(x).digest()), diff --git a/poc/test_oprf.sage b/poc/test_oprf.sage index ebc627f3..6da7ff1a 100644 --- a/poc/test_oprf.sage +++ b/poc/test_oprf.sage @@ -66,18 +66,23 @@ class Protocol(object): client = self.client server = self.server - def create_test_vector_for_input(x, info): + def create_test_vector_for_input(input, info): rng = TestDRNG("test vector seed".encode('utf-8')) if self.mode == MODE_POPRF: - blind, blinded_element, tweaked_key = client.blind(x, info, rng) + blind, blinded_element, tweaked_key = client.blind(input, info, rng) evaluated_element, proof, proof_randomness = server.blind_evaluate(blinded_element, info, rng) - output = client.finalize(x, blind, evaluated_element, blinded_element, proof, info, tweaked_key) - else: - blind, blinded_element = client.blind(x, rng) - evaluated_element, proof, proof_randomness = server.blind_evaluate(blinded_element, info, rng) - output = client.finalize(x, blind, evaluated_element, blinded_element, proof, info) - - assert(output == server.evaluate(x, info)) + output = client.finalize(input, blind, evaluated_element, blinded_element, proof, info, tweaked_key) + assert(server.evaluate(input, output, info)) + elif self.mode == MODE_VOPRF: + blind, blinded_element = client.blind(input, rng) + evaluated_element, proof, proof_randomness = server.blind_evaluate(blinded_element, rng) + output = client.finalize(input, blind, evaluated_element, blinded_element, proof) + assert(server.evaluate(input, output)) + elif self.mode == MODE_OPRF: + blind, blinded_element = client.blind(input, rng) + evaluated_element = server.blind_evaluate(blinded_element, rng) + output = client.finalize(input, blind, evaluated_element) + assert(server.evaluate(input, output)) vector = {} vector["Blind"] = to_hex(group.serialize_scalar(blind)) @@ -90,7 +95,7 @@ class Protocol(object): "r": to_hex(group.serialize_scalar(proof_randomness)), } - vector["Input"] = to_hex(x) + vector["Input"] = to_hex(input) if self.mode == MODE_POPRF: vector["Info"] = to_hex(info) vector["Output"] = to_hex(output) @@ -98,30 +103,34 @@ class Protocol(object): return vector - def create_batched_test_vector_for_inputs(xs, info): + def create_batched_test_vector_for_inputs(inputs, info): blinds = [] blinded_elements = [] tweaked_key = None rng = TestDRNG("test vector seed".encode('utf-8')) - for x in xs: - if self.mode == MODE_POPRF: - blind, blinded_element, tweaked_key = client.blind(x, info, rng) - blinds.append(blind) - blinded_elements.append(blinded_element) - else: - blind, blinded_element = client.blind(x, rng) + + if self.mode == MODE_POPRF: + tweaked_key = None + for input in inputs: + blind, blinded_element, tweaked_key = client.blind(input, info, rng) blinds.append(blind) blinded_elements.append(blinded_element) - evaluated_elements, proof, proof_randomness = server.blind_evaluate_batch(blinded_elements, info, rng) + evaluated_elements, proof, proof_randomness = server.blind_evaluate_batch(blinded_elements, info, rng) + outputs = client.finalize_batch(inputs, blinds, evaluated_elements, blinded_elements, proof, info, tweaked_key) + for i, output in enumerate(outputs): + assert(server.evaluate(inputs[i], output, info)) - if self.mode == MODE_POPRF: - outputs = client.finalize_batch(xs, blinds, evaluated_elements, blinded_elements, proof, info, tweaked_key) - else: - outputs = client.finalize_batch(xs, blinds, evaluated_elements, blinded_elements, proof, info) + elif self.mode == MODE_VOPRF: + for input in inputs: + blind, blinded_element = client.blind(input, rng) + blinds.append(blind) + blinded_elements.append(blinded_element) - for i, output in enumerate(outputs): - assert(output == server.evaluate(xs[i], info)) + evaluated_elements, proof, proof_randomness = server.blind_evaluate_batch(blinded_elements, rng) + outputs = client.finalize_batch(inputs, blinds, evaluated_elements, blinded_elements, proof) + for i, output in enumerate(outputs): + assert(server.evaluate(inputs[i], output)) vector = {} vector["Blind"] = ",".join([to_hex(group.serialize_scalar(blind)) for blind in blinds]) @@ -134,15 +143,15 @@ class Protocol(object): "r": to_hex(group.serialize_scalar(proof_randomness)), } - vector["Input"] = to_hex(xs) + vector["Input"] = to_hex(inputs) if self.mode == MODE_POPRF: vector["Info"] = to_hex(info) vector["Output"] = to_hex(outputs) - vector["Batch"] = int(len(xs)) + vector["Batch"] = int(len(inputs)) return vector - vectors = [create_test_vector_for_input(x, self.info) for x in self.inputs] + vectors = [create_test_vector_for_input(input, self.info) for input in self.inputs] if self.mode == MODE_VOPRF or self.mode == MODE_POPRF: vectors.append(create_batched_test_vector_for_inputs(self.inputs, self.info))