diff --git a/parsons/auth0/auth0.py b/parsons/auth0/auth0.py index 592686c3e2..45d4e6d2c2 100644 --- a/parsons/auth0/auth0.py +++ b/parsons/auth0/auth0.py @@ -1,9 +1,14 @@ +import gzip import json +import logging +import time import requests from parsons.etl.table import Table from parsons.utilities import check_env +logger = logging.getLogger(__name__) + class Auth0(object): """ @@ -80,6 +85,7 @@ def upsert_user( family_name=None, app_metadata={}, user_metadata={}, + connection="Username-Password-Authentication", ): """ Upsert Auth0 users by email. @@ -100,18 +106,21 @@ def upsert_user( `Returns:` Requests Response object """ - payload = json.dumps( - { - "email": email.lower(), - "given_name": given_name, - "family_name": family_name, - "username": username, - "connection": "Username-Password-Authentication", - "app_metadata": app_metadata, - "blocked": False, - "user_metadata": user_metadata, - } - ) + + obj = { + "email": email.lower(), + "username": username, + "connection": connection, + "app_metadata": app_metadata, + "blocked": False, + "user_metadata": user_metadata, + } + if given_name is not None: + obj["given_name"] = given_name + if family_name is not None: + obj["family_name"] = family_name + payload = json.dumps(obj) + existing = self.get_users_by_email(email.lower()) if existing.num_rows > 0: a0id = existing[0]["user_id"] @@ -127,3 +136,106 @@ def upsert_user( if ret.status_code != 200: raise ValueError(f"Invalid response {ret.json()}") return ret + + def block_user(self, user_id, connection="Username-Password-Authentication"): + """ + Blocks Auth0 users by email - setting the "blocked" attribute on Auth0's API. + + `Args:` + user_id: str + Auth0 user id + connection: optional str + Name of auth0 connection (default to Username-Password-Authentication) + `Returns:` + Requests Response object + """ + payload = json.dumps({"connection": connection, "blocked": True}) + ret = requests.patch( + f"{self.base_url}/api/v2/users/{user_id}", + headers=self.headers, + data=payload, + ) + if ret.status_code != 200: + raise ValueError(f"Invalid response {ret.json()}") + return ret + + def retrieve_all_users(self, connection="Username-Password-Authentication"): + """ + Retrieves all Auth0 users using the batch jobs endpoint. + + `Args:` + connection: optional str + Name of auth0 connection (default to Username-Password-Authentication) + `Returns:` + Requests Response object + """ + connection_id = self.get_connection_id(connection) + url = f"{self.base_url}/api/v2/jobs/users-exports" + + headers = self.headers + + fields = [ + {"name": n} + for n in ["user_id", "username", "email", "user_metadata", "app_metadata"] + ] + # Start the users-export job + response = requests.post( + url, + headers=headers, + json={"connection_id": connection_id, "format": "json", "fields": fields}, + ) + job_id = response.json().get("id") + + if job_id: + # Check job status until complete + while True: + status_response = requests.get( + f"{self.base_url}/api/v2/jobs/{job_id}", headers=headers + ) + status_data = status_response.json() + if status_response.status_code == 429: + time.sleep(10) + + elif status_response.status_code != 200: + break + elif status_data.get("status") == "completed": + download_url = status_data.get("location") + break + elif status_data.get("status") == "failed": + logger.error("Retrieve members job failed to complete.") + return None + + # Download the users-export file + users_response = requests.get(download_url) + + decompressed_data = gzip.decompress(users_response.content).decode("utf-8") + users_data = [] + for d in decompressed_data.split("\n"): + if d: + users_data.append(json.loads(d)) + + return Table(users_data) + + logger.error("Retrieve members job creation failed") + return None + + def get_connection_id(self, connection_name): + """ + Retrieves an Auth0 connection_id corresponding to a specific connection name + + `Args:` + connection_name: str + Name of auth0 connection + `Returns:` + Connection ID string + """ + url = f"{self.base_url}/api/v2/connections" + + response = requests.get(url, headers=self.headers) + connections = response.json() + + for connection in connections: + if connection["name"] == connection_name: + return connection["id"] + + return None diff --git a/test/test_auth0.py b/test/test_auth0.py index f15627f9c6..3f54bf7979 100644 --- a/test/test_auth0.py +++ b/test/test_auth0.py @@ -1,3 +1,5 @@ +import gzip +import json import unittest import unittest.mock from test.utils import assert_matching_tables @@ -39,6 +41,35 @@ def test_get_users_by_email(self, m): self.auth0.get_users_by_email(email), Table(mock_users), True ) + @requests_mock.Mocker() + def test_retrieve_all_users(self, m): + mock_users = [{"email": "fake3mail@fakedomain.com", "id": 2}] + + fake_job_id = 1234567 + m.post( + f"{self.auth0.base_url}/api/v2/jobs/users-exports", + json={"id": fake_job_id}, + ) + test_url = f"{self.auth0.base_url}/test.json.gz" + m.get( + f"{self.auth0.base_url}/api/v2/jobs/{fake_job_id}", + json={ + "status": "completed", + "location": test_url, + }, + ) + m.get( + test_url, + content=gzip.compress(bytes(json.dumps(mock_users), encoding="utf-8")), + ) + + connections = [{"id": 1234, "name": "Username-Password-Authentication"}] + m.get(f"{self.auth0.base_url}/api/v2/connections", json=connections) + data = self.auth0.retrieve_all_users() + print(data) + + assert_matching_tables(self.auth0.retrieve_all_users(), Table(mock_users), True) + @requests_mock.Mocker() def test_upsert_user(self, m): user = self.fake_upsert_person @@ -60,3 +91,13 @@ def test_upsert_user(self, m): {}, ) self.assertEqual(ret.status_code, 200) + + @requests_mock.Mocker() + def test_block_user(self, m): + user = self.fake_upsert_person + user["blocked"] = True + mock_resp = unittest.mock.MagicMock() + mock_resp.status_code = 200 + m.patch(f"{self.auth0.base_url}/api/v2/users/{user['user_id']}", [mock_resp]) + ret = self.auth0.block_user(user["user_id"]) + self.assertEqual(ret.status_code, 200)