From b8b803ac3208f415908e34850a03a87e53fec0a7 Mon Sep 17 00:00:00 2001 From: David Parks Date: Tue, 30 Jul 2024 22:56:01 -0700 Subject: [PATCH] Updates to messaging and shadows to use new authentication (#94) * updates to messaging and shadows to use new authentication * fixed webbrowser requirement * Update __init__.py * Moved authenticate to its own module to avoid circular imports * fixed auth header --------- Co-authored-by: Kateryna Voitiuk --- .gitignore | 3 + pyproject.toml | 2 +- src/braingeneers/iot/__init__.py | 8 +- src/braingeneers/iot/authenticate.py | 130 +++++++++++++++++++++++++++ src/braingeneers/iot/messaging.py | 38 +++++++- src/braingeneers/iot/shadows.py | 98 ++++++++++---------- 6 files changed, 225 insertions(+), 54 deletions(-) create mode 100644 src/braingeneers/iot/authenticate.py diff --git a/.gitignore b/.gitignore index 4c6417f..be6bdd3 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,6 @@ tmp/ **/.DS_Store dist/ **/.vscode/** + +# don't commit the service_account file +src/braingeneers/iot/service_account/config.json diff --git a/pyproject.toml b/pyproject.toml index e5a7474..50dc964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "typing_extensions>=4.6; python_version<'3.11'", 'diskcache', 'pytz', - 'tzlocal' + 'tzlocal', ] [tool.hatch.build.hooks.vcs] diff --git a/src/braingeneers/iot/__init__.py b/src/braingeneers/iot/__init__.py index 1c6f8b9..99a71e4 100644 --- a/src/braingeneers/iot/__init__.py +++ b/src/braingeneers/iot/__init__.py @@ -1,4 +1,4 @@ -import braingeneers -from braingeneers.iot.messaging import * -from braingeneers.iot.device import * -from braingeneers.iot.simple import * +import braingeneers +from braingeneers.iot.messaging import * +from braingeneers.iot.device import * +from braingeneers.iot.simple import * diff --git a/src/braingeneers/iot/authenticate.py b/src/braingeneers/iot/authenticate.py new file mode 100644 index 0000000..d5de1a2 --- /dev/null +++ b/src/braingeneers/iot/authenticate.py @@ -0,0 +1,130 @@ + +import os +import json +import webbrowser +import importlib.resources +import configparser +import datetime +import requests +import argparse + + +def authenticate_and_get_token(): + """ + Directs users to a URL to authenticate and get a JWT token. + Once the token has been obtained manually it will refresh automatically every month. + By default, the token is valid for 4 months from issuance. + Returns token data as a dict containing `access_token` and `expires_at` keys. + """ + PACKAGE_NAME = "braingeneers.iot" + + url = 'https://service-accounts.braingeneers.gi.ucsc.edu/generate_token' + print(f'Please visit the following URL to generate your JWT token: {url}') + webbrowser.open(url) + + token_json = input('Please paste the JSON token issued by the page and press Enter:\n') + try: + token_data = json.loads(token_json) + except json.JSONDecodeError: + raise ValueError('Invalid JSON. Please make sure you have copied the token correctly.') + + config_dir = os.path.join(importlib.resources.files(PACKAGE_NAME), 'service_account') + os.makedirs(config_dir, exist_ok=True) + config_file = os.path.join(config_dir, 'config.json') + + with open(config_file, 'w') as f: + json.dump(token_data, f) + + print('Token has been saved successfully.') + return token_data + + +def update_config_file(file_path, section, key, new_value): + with open(file_path, 'r') as file: + lines = file.readlines() + + with open(file_path, 'w') as file: + section_found = False + for line in lines: + if line.strip() == f'[{section}]': + section_found = True + if section_found and line.strip().startswith(key): + line = f'{key} = {new_value}\n' + section_found = False # Reset the flag + file.write(line) + + +def picroscope_authenticate_and_update_token(credentials_file): + """ + Authentication and update service-account token for legacy picroscope environment. This updates the AWS credentials file + with the JWT token and updates it if it has <3 months before expiration. This function can be run as a cron job. + """ + # Check if the JWT token exists and if it exists in the credentials file if it's expired. + # The credentials file section is [strapi] with `api_key` containing the jwt token, and `api_key_expires` containing + # the expiration date in ISO format. + config_file_path = os.path.expanduser(credentials_file) + + config = configparser.ConfigParser() + with open(config_file_path, 'r') as f: + config.read_string(f.read()) + + assert 'strapi' in config, \ + 'Your AWS credentials file is missing a section [strapi], you may have the wrong version of the credentials file.' + + token_exists = 'api_key' in config['strapi'] + expire_exists = 'api_key_expires' in config['strapi'] + + if expire_exists: + expiration_str = config['strapi']['api_key_expires'] + expiration_str = expiration_str.split(' ')[0] + ' ' + expiration_str.split(' ')[1] # Remove timezone + expiration_date = datetime.datetime.fromisoformat(expiration_str) + days_remaining = (expiration_date - datetime.datetime.now()).days + print('Days remaining for token:', days_remaining) + else: + days_remaining = -1 + + # check if api_key_expires exists, if not, it's expired, else check if it has <90 days remaining on it + manual_refresh = not token_exists \ + or not expire_exists \ + or (datetime.datetime.fromisoformat(config['strapi']['api_key_expires']) - datetime.datetime.now()).days < 0 + auto_refresh = (token_exists and expire_exists) \ + and (datetime.datetime.fromisoformat(config['strapi']['api_key_expires']) - datetime.datetime.now()).days < 90 + + if manual_refresh or auto_refresh: + token_data = authenticate_and_get_token() if manual_refresh else requests.get(url).json() + update_config_file(config_file_path, 'strapi', 'api_key', token_data['access_token']) + update_config_file(config_file_path, 'strapi', 'api_key_expires', token_data['expires_at']) + print(f'JWT token has been updated in {config_file_path}') + else: + print('JWT token is still valid, no action taken.') + + +def parse_args(): + """ + Two commands are available: + + # Authenticate and obtain a JWT service account token for braingeneerspy + python -m braingeneers.iot.messaging authenticate + + # Authenticate and obtain a JWT service account token for picroscope specific environment + python -m braingeneers.iot.messaging authenticate picroscope + """ + parser = argparse.ArgumentParser(description='JWT Service Account Token Management') + parser.add_argument('config', nargs='?', choices=['picroscope'], help='Picroscope specific JWT token configuration.') + parser.add_argument('--credentials', default='~/.aws/credentials', help='Path to the AWS credentials file, only used for picroscope authentication.') + + return parser.parse_args() + + +def main(): + args = parse_args() + + if args.config == 'picroscope': + credentials_file = args.credentials + picroscope_authenticate_and_update_token(credentials_file) + else: + authenticate_and_get_token() + + +if __name__ == '__main__': + main() diff --git a/src/braingeneers/iot/messaging.py b/src/braingeneers/iot/messaging.py index db19cfc..6f38710 100644 --- a/src/braingeneers/iot/messaging.py +++ b/src/braingeneers/iot/messaging.py @@ -12,6 +12,9 @@ import json import braingeneers.iot.shadows as sh import pickle +import importlib +import argparse +import datetime from typing import Callable, Tuple, List, Dict, Union from deprecated import deprecated @@ -162,8 +165,9 @@ def __init__(self, name: str = None, credentials_file: (str, io.IOBase) = None, self._boto_iot_client = None self._boto_iot_data_client = None self._redis_client = None + self._jwt_service_account_token = None - self.shadow_interface = sh.DatabaseInteractor() + self.shadow_interface = sh.DatabaseInteractor(jwt_service_token=self.jwt_service_account_token) self._subscribed_data_streams = set() # keep track of subscribed data streams self._subscribed_message_callback_map = {} # keep track of subscribed message callbacks, key is regex, value is tuple of (callback, topic) @@ -789,6 +793,38 @@ def redis_client(self) -> redis.Redis: return self._redis_client + @property + def jwt_service_account_token(self) -> str: + """ Lazy initialization of the JWT service account token. """ + PACKAGE_NAME = "braingeneers.iot" + config_dir = os.path.join(importlib.resources.files(PACKAGE_NAME), 'service_account') + config_file = os.path.join(config_dir, 'config.json') + + if self._jwt_service_account_token is None: + # Check if the JWT token exists + # This token is required for all operations that require web services. + # The token is a (json) dict of form {'access_token': '----', 'expires_at': '2024-11-07 23:39:42 UTC'} + os.makedirs(config_dir, exist_ok=True) + + # Try to load an existing JWT token locally if it exists + if os.path.exists(config_file): + with open(config_file, 'r') as f: + self._jwt_service_account_token = json.load(f) + + if self._jwt_service_account_token is None: + raise PermissionError('JWT service account token not found, please generate one using: python -m braingeneers.iot.messaging authenticate') + + # Check if the token is still valid, this happens on every access, but takes no action while it's still valid. + # If the token has less than 3 month left, refresh it, default tokens have 30 days at issuance. + expires_at = datetime.datetime.fromisoformat(self._jwt_service_account_token['expires_at'].replace(' UTC', '')) + if (expires_at - datetime.datetime.now()).days < 90: + GENERATE_TOKEN_URL = 'https://service-accounts.braingeneers.gi.ucsc.edu/generate_token' + self._jwt_service_account_token = requests.get(GENERATE_TOKEN_URL).json() + with open(config_file, 'w') as f: + json.dump(self._jwt_service_account_token, f) + + return self._jwt_service_account_token + def shutdown(self): """ Release resources and shutdown connections as needed. """ if self.certs_temp_dir is not None: diff --git a/src/braingeneers/iot/shadows.py b/src/braingeneers/iot/shadows.py index 16b225f..363003e 100644 --- a/src/braingeneers/iot/shadows.py +++ b/src/braingeneers/iot/shadows.py @@ -5,7 +5,6 @@ from typing import Union - class DatabaseInteractor: """ This class provides methods for interacting with the Strapi Shadows database. @@ -45,7 +44,7 @@ class objects: - get_sample: returns a sample object given its id - get_well: returns a well object given its id """ - def __init__(self , credentials: Union[str, io.IOBase] = None, overwrite_endpoint = None, overwrite_api_key = None) -> None: + def __init__(self, credentials: Union[str, io.IOBase] = None, overwrite_endpoint=None, overwrite_api_key=None, jwt_service_token=None) -> None: if credentials is None: credentials = os.path.expanduser('~/.aws/credentials') # default credentials location @@ -66,24 +65,29 @@ def __init__(self , credentials: Union[str, io.IOBase] = None, overwrite_endpoin assert 'api_key' in config['strapi'], 'Your AWS credentials file is malformed, ' \ 'api_key was not found under the [strapi] section.' + # Note that the "token" is a basic auth construct originally implemented with Strapi before full JWT + # authentication was available. It's deprecated, but still in use, if someone wants to reconfigure + # the services to remove it that would be good, but until then it's a superfluous detail. + # The JWT service token is the updated way to authenticate with all web services including Strapi self.endpoint = config['strapi']['endpoint'] self.token = config['strapi']['api_key'] if overwrite_endpoint: self.endpoint = overwrite_endpoint if overwrite_api_key: self.token = overwrite_api_key + self.jwt_service_token = jwt_service_token - class __API_object: """ This class is used to represent objects in the database as python objects """ - def __init__(self, endpoint, api_token, api_object_id): - self.endpoint = endpoint - self.token = api_token - self.id = None - self.attributes = {} - self.api_object_id = api_object_id + def __init__(self, endpoint, api_token, api_object_id, jwt_service_token): + self.endpoint = endpoint + self.token = api_token + self.id = None + self.attributes = {} + self.api_object_id = api_object_id + self.jwt_service_token = jwt_service_token def __str__(self): var_list = filter(lambda x: x not in ["endpoint", "token", "api_object_id"], vars(self)) @@ -119,13 +123,12 @@ def spawn(self): creates a new object in the database """ url = self.endpoint + "/"+self.api_object_id+"?filters[name][$eq]=" + self.attributes["name"] + "&populate=%2A" - headers = {"Authorization": "Bearer " + self.token} + headers = {"Authorization": "Bearer " + self.jwt_service_token['access_token']} response = requests.get(url, headers=headers) if len(response.json()['data']) == 0: api_url = self.endpoint+"/"+self.api_object_id+"?populate=%2A" data = {"data": self.attributes} - response = requests.post(api_url, json=data, headers={ - 'Authorization': 'bearer ' + self.token}) + response = requests.post(api_url, json=data, headers=headers) if response.status_code == 200: self.parse_API_response(response.json()['data']) else: @@ -140,7 +143,7 @@ def push(self): updates the database with the current state of the object """ url = self.endpoint + "/"+self.api_object_id+"/" + str(self.id) + "?populate=%2A" - headers = {"Authorization": "Bearer " + self.token} + headers = {"Authorization": "Bearer " + self.jwt_service_token['access_token']} data = {"data": self.attributes} response = requests.put(url, headers=headers, json=data) self.parse_API_response(response.json()['data']) @@ -150,7 +153,7 @@ def pull(self): updates object with the latest data from the database """ url = self.endpoint + "/"+self.api_object_id+"/" + str(self.id) + "?populate=%2A" - headers = {"Authorization": "Bearer " + self.token} + headers = {"Authorization": "Bearer " + self.jwt_service_token['access_token']} response = requests.get(url, headers=headers) if len(response.json()['data']) == 0: raise Exception("Object not found") @@ -162,7 +165,7 @@ def get_by_name(self, name): gets the object from the database by name """ url = self.endpoint + "/"+self.api_object_id+"?filters[name][$eq]=" + name + "&populate=%2A" - headers = {"Authorization": "Bearer " + self.token} + headers = {"Authorization": "Bearer " + self.jwt_service_token['access_token']} response = requests.get(url, headers=headers) if len(response.json()['data']) == 0: raise Exception("no " + self.api_object_id + " object with name " + name) @@ -174,7 +177,7 @@ def move_to_trash(self): marks the object for deletion """ url = self.endpoint + "/"+self.api_object_id+"/" + str(self.id) + "?populate=%2A" - headers = {"Authorization": "Bearer " + self.token} + headers = {"Authorization": "Bearer " + self.jwt_service_token['access_token']} response = requests.get(url, headers=headers) if len(response.json()['data']) == 0: raise Exception("Object not found") @@ -188,7 +191,7 @@ def recover_from_trash(self): unmarks the object for deletion """ url = self.endpoint + "/"+self.api_object_id+"/" + str(self.id) + "?populate=%2A" - headers = {"Authorization": "Bearer " + self.token} + headers = {"Authorization": "Bearer " + self.jwt_service_token['access_token']} response = requests.get(url, headers=headers) if len(response.json()['data']) == 0: raise Exception("Object not found") @@ -199,8 +202,8 @@ def recover_from_trash(self): self.push() class __Thing(__API_object): - def __init__(self, endpoint, api_token): - super().__init__(endpoint, api_token, "interaction-things") + def __init__(self, endpoint, api_token, jwt_service_token): + super().__init__(endpoint, api_token, "interaction-things", jwt_service_token) def add_to_shadow(self, json): if self.attributes["shadow"] is None: @@ -217,7 +220,6 @@ def add_uuid_to_shadow(self, uuid): self.attributes["shadow"]["uuid"] = uuid self.push() - def set_current_plate(self, plate): """ updates the current plate of the thing and adds the plate to the list of all plates historically associated with the thing. @@ -241,8 +243,8 @@ def set_current_experiment(self, experiment): self.push() class __Experiment(__API_object): - def __init__(self, endpoint, api_token): - super().__init__(endpoint, api_token, "experiments") + def __init__(self, endpoint, api_token, jwt_service_token): + super().__init__(endpoint, api_token, "experiments", jwt_service_token=jwt_service_token) def add_plate(self, plate): # Bidirectional relations have an owner and a related object, plate owns this relation @@ -250,8 +252,8 @@ def add_plate(self, plate): self.pull() class __Plate(__API_object): - def __init__(self, endpoint, api_token): - super().__init__(endpoint, api_token, "plates") + def __init__(self, endpoint, api_token, jwt_service_token): + super().__init__(endpoint, api_token, "plates", jwt_service_token) def add_thing(self, thing): """ @@ -309,12 +311,12 @@ def add_experiment(self, experiment): self.push() class __Well(__API_object): - def __init__(self, endpoint, api_token): - super().__init__(endpoint, api_token, "wells") + def __init__(self, endpoint, api_token, jwt_service_token): + super().__init__(endpoint, api_token, "wells", jwt_service_token) class __Sample(__API_object): - def __init__(self, endpoint, api_token): - super().__init__(endpoint, api_token, "samples") + def __init__(self, endpoint, api_token, jwt_service_token): + super().__init__(endpoint, api_token, "samples", jwt_service_token) def empty_trash(self): """ @@ -323,7 +325,7 @@ def empty_trash(self): object_list = ["interaction-things", "experiments", "plates", "wells", "samples"] for object in object_list: url = self.endpoint + "/"+object+"?filters[marked_for_deletion][$eq]=true&populate=%2A" - headers = {"Authorization": "Bearer " + self.token} + headers = {"Authorization": "Bearer " + self.jwt_service_token['access_token']} response = requests.get(url, headers=headers) for item in response.json()['data']: url = self.endpoint + "/"+object+"/" + str(item['id']) @@ -331,14 +333,14 @@ def empty_trash(self): print("deleted object of type: " + object + " with id " + str(item['id'])) def create_interaction_thing(self, type, name): - thing = self.__Thing(self.endpoint, self.token) + thing = self.__Thing(self.endpoint, self.token, self.jwt_service_token) thing.attributes["name"] = name thing.attributes["type"] = type thing.spawn() return thing - def create_plate(self, name, rows, columns, image_params = {}): - plate = self.__Plate(self.endpoint, self.token) + def create_plate(self, name, rows, columns, image_params={}): + plate = self.__Plate(self.endpoint, self.token, self.jwt_service_token) plate.attributes["name"] = name plate.attributes["rows"] = rows plate.attributes["columns"] = columns @@ -348,7 +350,7 @@ def create_plate(self, name, rows, columns, image_params = {}): if len(plate.attributes["wells"]) == 0 or plate.attributes["wells"] is None: for i in range(1, rows+1): for j in range(1, columns+1): - well = self.__Well(self.endpoint, self.token) + well = self.__Well(self.endpoint, self.token, self.jwt_service_token) well.attributes["name"] = plate.attributes["name"]+"_well_"+str(i)+str(j) well.attributes["position_index"] = str(i) + str(j) well.attributes["plate"] = plate.id @@ -359,7 +361,7 @@ def create_plate(self, name, rows, columns, image_params = {}): def create_experiment(self, name, description): - experiment = self.__Experiment(self.endpoint, self.token) + experiment = self.__Experiment(self.endpoint, self.token, self.jwt_service_token) experiment.attributes["name"] = name experiment.attributes["description"] = description experiment.spawn() @@ -371,14 +373,14 @@ def start_image_capture(self, thing, uuid): group_id = thing.attributes["shadow"]["group-id"] value = { uuid : group_id } if thing.attributes["current_plate"]: - plate = self.__Plate(self.endpoint, self.token) + plate = self.__Plate(self.endpoint, self.token, self.jwt_service_token) plate.id = thing.attributes["current_plate"][0] plate.pull() plate.add_uuid_to_image_params(value) else: raise Exception("no plate associated with thing") - def list_objects(self, api_object_id, filter = "?", hide_deleted = True): + def list_objects(self, api_object_id, filter="?", hide_deleted=True): """ when you need a list of the objects in the database useful for populating dropdown lists in plotly dash @@ -386,11 +388,11 @@ def list_objects(self, api_object_id, filter = "?", hide_deleted = True): if hide_deleted: filter += "&filters[marked_for_deletion][$eq]=false" url = self.endpoint + "/"+ api_object_id + filter +"&populate=%2A" - headers = {"Authorization": "Bearer " + self.token} + headers = {"Authorization": "Bearer " + self.jwt_service_token['access_token']} response = requests.get(url, headers=headers) return response.json()['data'] - def list_objects_with_name_and_id(self, api_object_id, filter = "?", hide_deleted = True): + def list_objects_with_name_and_id(self, api_object_id, filter="?", hide_deleted=True): """ when you need a list of the objects in the database @@ -407,20 +409,20 @@ def list_experiments(self, hide_deleted = True): output.append(i["attributes"]["name"]) return output - def list_BioPlateScopes(self, hide_deleted = True): + def list_BioPlateScopes(self, hide_deleted=True): return self.list_objects_with_name_and_id("interaction-things", "?filters[type][$eq]=BioPlateScope", hide_deleted) def list_devices_by_type(self, thingTypeName, hide_deleted = True): return self.list_objects_with_name_and_id("interaction-things", "?filters[type][$eq]="+thingTypeName, hide_deleted) def get_device_state(self, thing_id): - thing = self.__Thing(self.endpoint, self.token) + thing = self.__Thing(self.endpoint, self.token, self.jwt_service_token) thing.id = thing_id thing.pull() return thing.attributes["shadow"] def get_device_state_by_name(self, thing_name): - thing = self.__Thing(self.endpoint, self.token) + thing = self.__Thing(self.endpoint, self.token, self.jwt_service_token) thing.get_by_name(thing_name) return thing.attributes["shadow"] @@ -430,39 +432,39 @@ def get_device_state_by_name(self, thing_name): Getters for objects from their id numbers """ - def get_device(self, thing_id= None, name = None): + def get_device(self, thing_id=None, name=None): if thing_id is None and name is None: raise Exception("must provide either thing_id or name") if name: - thing = self.__Thing(self.endpoint, self.token) + thing = self.__Thing(self.endpoint, self.token, self.jwt_service_token) thing.get_by_name(name) return thing else: - thing = self.__Thing(self.endpoint, self.token) + thing = self.__Thing(self.endpoint, self.token, self.jwt_service_token) thing.id = thing_id thing.pull() return thing def get_plate(self, plate_id): - plate = self.__Plate(self.endpoint, self.token) + plate = self.__Plate(self.endpoint, self.token, self.jwt_service_token) plate.id = plate_id plate.pull() return plate def get_experiment(self, experiment_id): - experiment = self.__Experiment(self.endpoint, self.token) + experiment = self.__Experiment(self.endpoint, self.token, self.jwt_service_token) experiment.id = experiment_id experiment.pull() return experiment def get_sample(self, sample_id): - sample = self.__Sample(self.endpoint, self.token) + sample = self.__Sample(self.endpoint, self.token, self.jwt_service_token) sample.id = sample_id sample.pull() return sample def get_well(self, well_id): - well = self.__Well(self.endpoint, self.token) + well = self.__Well(self.endpoint, self.token, self.jwt_service_token) well.id = well_id well.pull() return well