diff --git a/.github/workflows/docker-publish.yaml b/.github/workflows/docker-publish.yaml
index 106dbff..316a3ad 100644
--- a/.github/workflows/docker-publish.yaml
+++ b/.github/workflows/docker-publish.yaml
@@ -21,6 +21,40 @@ env:
 
 jobs:
+  lint:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [ "3.8", "3.9", "3.10" ]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install pylint
+      - name: Analysing the code with pylint
+        run: |
+          pylint $(git ls-files '*.py')
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.9"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install pytest pytest-asyncio
+          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+      - name: Test with pytest
+        run: |
+          pytest
   build:
     runs-on: ubuntu-latest
     permissions:
diff --git a/main.py b/main.py
index d079eb4..a84a9c3 100644
--- a/main.py
+++ b/main.py
@@ -13,63 +13,50 @@ import boto3
 from botocore.exceptions import ClientError
 
-# Read config from env
-TOKEN = os.getenv('DISCORD_TOKEN')
-if TOKEN is None:
-    raise Exception('No discord token set. Please set DISCORD_TOKEN.')
-
-GPG_KEY_DIR = os.getenv('GPG_KEY_DIR')
-if GPG_KEY_DIR is None:
-    raise Exception('No GPG key directory set. Please set GPG_KEY_DIR.')
-
-EPHEMERAL_PATH = os.getenv('EPHEMERAL_PATH')
-if EPHEMERAL_PATH is None:
-    print(f'EPHEMERAL_PATH not set. will dump all files in working dir.')
-    EPHEMERAL_PATH = '.'
-
-S3_ENABLED = os.getenv('S3_ENABLED') == '1'
-S3_BUCKET = os.getenv('S3_BUCKET')
-S3_ACCESS_KEY = os.getenv('S3_ACCESS_KEY')
-S3_SECRET_KEY = os.getenv('S3_SECRET_KEY')
-S3_ENDPOINT = os.getenv('S3_ENDPOINT')
+global S3_BUCKET
+S3_ENABLED = False
+key_fingerprints = []
+gpg = GPG()
 
-HEARTBEAT_URL = os.getenv('HEARTBEAT_URL')
-# Send a start signal to heartbeat
-try:
-    requests.get(HEARTBEAT_URL + "/start", timeout=5)
-except requests.exceptions.RequestException:
-    # If the network request fails for any reason, we don't want
-    # it to prevent the main job from running
-    pass
+def get_ephemeral_path():
+    ephemeral_path = os.getenv('EPHEMERAL_PATH')
+    if ephemeral_path is None:
+        ephemeral_path = '.'
+    return ephemeral_path
 
-# Prepare discord connection
-intents = nextcord.Intents.default()
-intents.message_content = True
-client = nextcord.Client(intents=intents)
-s3 = None
+def init_s3():
+    # Read Env
+    S3_ENABLED = os.getenv('S3_ENABLED') == '1'
+    S3_BUCKET = os.getenv('S3_BUCKET')
+    S3_ACCESS_KEY = os.getenv('S3_ACCESS_KEY')
+    S3_SECRET_KEY = os.getenv('S3_SECRET_KEY')
+    S3_ENDPOINT = os.getenv('S3_ENDPOINT')
 
-if S3_ENABLED:
-    # Prepare S3 access
-    if S3_BUCKET is None:
-        raise Exception('S3 enabled but S3_BUCKET not set.')
+    # Setup
+    if S3_ENABLED:
+        # Prepare S3 access
+        if S3_BUCKET is None:
+            raise Exception('S3 enabled but S3_BUCKET not set.')
 
-    if S3_ACCESS_KEY is None:
-        raise Exception('S3 enabled but S3_ACCESS_KEY not set.')
+        if S3_ACCESS_KEY is None:
+            raise Exception('S3 enabled but S3_ACCESS_KEY not set.')
 
-    if S3_SECRET_KEY is None:
-        raise Exception('S3 enabled but S3_SECRET_KEY not set.')
+        if S3_SECRET_KEY is None:
+            raise Exception('S3 enabled but S3_SECRET_KEY not set.')
 
-    if S3_ENDPOINT is None:
-        raise Exception('S3 enabled but S3_ENDPOINT not set.')
+        if S3_ENDPOINT is None:
+            raise Exception('S3 enabled but S3_ENDPOINT not set.')
 
-    s3 = boto3.client(
-        service_name='s3',
-        aws_access_key_id=S3_ACCESS_KEY,
-        aws_secret_access_key=S3_SECRET_KEY,
-        endpoint_url=S3_ENDPOINT,
-    )
+        return boto3.client(
+            service_name='s3',
+            aws_access_key_id=S3_ACCESS_KEY,
+            aws_secret_access_key=S3_SECRET_KEY,
+            endpoint_url=S3_ENDPOINT,
+        )
+    else:
+        return None
 
 
 def hash_string(msg):
@@ -165,7 +152,7 @@ def write_to_storage(backup_msg):
     if S3_ENABLED:
         s3.put_object(Bucket=S3_BUCKET, Key=f'messages/{enc_hash_str}', Body=str(enc_msg))
     else:
-        with open(os.path.join(EPHEMERAL_PATH, f'{enc_hash_str}.msg'),'w') as msg_file:
+        with open(os.path.join(get_ephemeral_path(), f'{enc_hash_str}.msg'),'w') as msg_file:
             msg_file.write(str(enc_msg))
 
     print(f'Message written: {enc_hash_str}')
@@ -185,8 +172,6 @@ def get_signing_key():
     if os.path.splitext(keyfile)[1] != '.pem':
         raise Exception('Signing key file not a pem file. make sure the extension is pem.')
-
-    private_key = None
 
     if os.path.exists(keyfile):
         print(f'Loading signing key from {keyfile}...')
         # To reload the key:
@@ -249,13 +234,13 @@ def seal_manifest(guild_id, channel_id):
 def get_manifest_path(guild_id, channel_id):
     channel_hash, _ = hash_string(str(channel_id))
     guild_hash, _ = hash_string(str(guild_id))
-    return os.path.join(EPHEMERAL_PATH, f'{guild_hash}-{channel_hash}.manifest'), f'manifests/{guild_hash}-{channel_hash}'
+    return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.manifest'), f'manifests/{guild_hash}-{channel_hash}'
 
 
 def get_manifest_seal_path(guild_id, channel_id):
     channel_hash, _ = hash_string(str(channel_id))
     guild_hash, _ = hash_string(str(guild_id))
-    return os.path.join(EPHEMERAL_PATH, f'{guild_hash}-{channel_hash}.seal'), f'manifests/seals/{guild_hash}-{channel_hash}'
+    return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.seal'), f'manifests/seals/{guild_hash}-{channel_hash}'
 
 
 async def backup_channel(channel, last_message_id):
@@ -321,7 +306,7 @@ def get_loc_path(channel):
     """
     guild_hash, _ = hash_string(str(channel.guild.id))
     channel_hash, _ = hash_string(str(channel.id))
-    return os.path.join(EPHEMERAL_PATH, f'{guild_hash}-{channel_hash}.loc'), f'locations/{guild_hash}-{channel_hash}'
+    return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.loc'), f'locations/{guild_hash}-{channel_hash}'
 
 
 async def get_last_message_id(channel):
@@ -370,7 +355,7 @@ async def set_last_message_id(channel, new_last_msg_id):
         file.write(str(new_last_msg_id))
 
 
-def generate_directory_file(target_channels):
+def generate_directory_file(target_channels, current_datetime):
     """
     This method generates a directory file based on the servers and channels
     the bot has access to. see Readme about directory files for more info.
@@ -378,7 +363,6 @@
     :param target_channels: A list of target channels to include in the directory file.
     :return: None
     """
-    current_datetime = datetime.now()
     iso8601_format = current_datetime.isoformat().replace(':', '-').replace('.', '-')
 
     directory = []
@@ -431,7 +415,7 @@
     if S3_ENABLED:
         s3.put_object(Bucket=S3_BUCKET, Key=f'directories/{iso8601_format}', Body=str(enc_msg))
     else:
-        with open(os.path.join(EPHEMERAL_PATH, f'{iso8601_format}.dir'), 'w') as file:
+        with open(os.path.join(get_ephemeral_path(), f'{iso8601_format}.dir'), 'w') as file:
             file.write(str(enc_msg))
 
     # generate seal
@@ -442,91 +426,118 @@
     if S3_ENABLED:
         s3.put_object(Bucket=S3_BUCKET, Key=f'directories/seals/{iso8601_format}', Body=man_signature)
     else:
-        with open(os.path.join(EPHEMERAL_PATH, f'{iso8601_format}.dirseal'), 'w') as seal_file:
+        with open(os.path.join(get_ephemeral_path(), f'{iso8601_format}.dirseal'), 'w') as seal_file:
             seal_file.write(man_signature)
 
 
-@client.event
-async def on_ready():
-    """
-    This method is an event handler for the `on_ready` event in a Discord bot.
-    Used here to trigger the backup run after login.
-
-    :return: None
-    """
-    print(f'Bot has logged in to discord as {client.user}')
-
-    # Grab servers and channels
-    target_channels = client.get_all_channels()
+def load_gpg_keys():
+    print('Loading GPG keys...')
+    GPG_KEY_DIR = os.getenv('GPG_KEY_DIR')
+    if GPG_KEY_DIR is None:
+        raise Exception('No GPG key directory set. Please set GPG_KEY_DIR.')
+    key_files = glob.glob(os.path.join(GPG_KEY_DIR, '*.asc'))
+    imported_keys = [gpg.import_keys_file(key_file) for key_file in key_files]
 
-    # Build directory file
-    generate_directory_file(client.get_all_channels())
+    # Trust keys
+    # Set trust for imported keys
+    for key in imported_keys:
+        keyid = key.fingerprints[0]
+        trust_result = gpg.trust_keys([keyid], 'TRUST_ULTIMATE')
+        print(f'GPG key imported and trusted: {keyid} => {trust_result}')
 
-    # Backup channels
-    for channel in target_channels:
-        # Only intressted in text channels
-        if not isinstance(channel, nextcord.TextChannel):
-            continue
+    # Get the fingerprints of the imported keys
+    return [result.fingerprints[0] for result in imported_keys]
 
-        print(f'Backing up Channel {channel.id} on {channel.guild.id}')
-        # Backup channels
-        try:
-            last_msg_id = await get_last_message_id(channel)
-            new_last_msg_id = await backup_channel(channel, last_msg_id)
-            await set_last_message_id(channel, new_last_msg_id)
+if __name__ == '__main__':
+    # Main starting
+    print('EF Backup Bot starting...')
 
-        except Exception as e:
-            print(f'Unable to backup: {e}')
+    # Read config from env
+    TOKEN = os.getenv('DISCORD_TOKEN')
+    if TOKEN is None:
+        raise Exception('No discord token set. Please set DISCORD_TOKEN.')
 
-        # Backup threads in channel
-        for thread in channel.threads:
-            print(f'Backing up Thread {thread.id} in Channel {channel.id} on {channel.guild.id}')
+    HEARTBEAT_URL = os.getenv('HEARTBEAT_URL')
 
-            try:
-                last_msg_id = await get_last_message_id(thread)
-                new_last_msg_id = await backup_channel(thread, last_msg_id)
-                await set_last_message_id(thread, new_last_msg_id)
-
-            except Exception as e:
-                print(f'Unable to backup: {e}')
-
-    # Quit when done
-    print('Notifying the heartbeat check...')
+    # Send a start signal to heartbeat
     try:
-        requests.get(HEARTBEAT_URL, timeout=10)
+        requests.get(HEARTBEAT_URL + "/start", timeout=5)
     except requests.exceptions.RequestException:
         # If the network request fails for any reason, we don't want
         # it to prevent the main job from running
         pass
 
-    print('Done. exiting.')
-    await client.close()
+    # init S3
+    s3 = init_s3()
+    # prepare gpg
+    key_fingerprints = load_gpg_keys()
 
-# Main starting
-print('EF Backup Bot starting...')
+    # load the signing key
+    signing_key = get_signing_key()
 
-# prepare gpg
-gpg = GPG()
+    # Prepare discord connection
+    intents = nextcord.Intents.default()
+    intents.message_content = True
+    client = nextcord.Client(intents=intents)
 
-print('Loading GPG keys...')
-key_files = glob.glob(os.path.join(GPG_KEY_DIR,'*.asc'))
-imported_keys = [gpg.import_keys_file(key_file) for key_file in key_files]
+    @client.event
+    async def on_ready():
+        """
+        This method is an event handler for the `on_ready` event in a Discord bot.
+        Used here to trigger the backup run after login.
-# Trust keys
-# Set trust for imported keys
-for key in imported_keys:
-    keyid = key.fingerprints[0]
-    trust_result = gpg.trust_keys([keyid], 'TRUST_ULTIMATE')
-    print(f'GPG key imported and trusted: {keyid} => {trust_result}')
+
+        :return: None
+        """
+        print(f'Bot has logged in to discord as {client.user}')
 
-# Get the fingerprints of the imported keys
-key_fingerprints = [result.fingerprints[0] for result in imported_keys]
+        # Grab servers and channels
+        target_channels = client.get_all_channels()
 
-# load the signing key
-signing_key = get_signing_key()
+        # Build directory file
+        generate_directory_file(client.get_all_channels(), datetime.now())
+
+        # Backup channels
+        for channel in target_channels:
+            # Only intressted in text channels
+            if not isinstance(channel, nextcord.TextChannel):
+                continue
+
+            print(f'Backing up Channel {channel.id} on {channel.guild.id}')
+
+            # Backup channels
+            try:
+                last_msg_id = await get_last_message_id(channel)
+                new_last_msg_id = await backup_channel(channel, last_msg_id)
+                await set_last_message_id(channel, new_last_msg_id)
+
+            except Exception as e:
+                print(f'Unable to backup: {e}')
+
+            # Backup threads in channel
+            for thread in channel.threads:
+                print(f'Backing up Thread {thread.id} in Channel {channel.id} on {channel.guild.id}')
+
+                try:
+                    last_msg_id = await get_last_message_id(thread)
+                    new_last_msg_id = await backup_channel(thread, last_msg_id)
+                    await set_last_message_id(thread, new_last_msg_id)
+
+                except Exception as e:
+                    print(f'Unable to backup: {e}')
+
+        # Quit when done
+        print('Notifying the heartbeat check...')
+        try:
+            requests.get(HEARTBEAT_URL, timeout=10)
+        except requests.exceptions.RequestException:
+            # If the network request fails for any reason, we don't want
+            # it to prevent the main job from running
+            pass
-# run the bot
-client.run(TOKEN)
+        print('Done. exiting.')
+        await client.close()
+    # run the bot - still in main
+    client.run(TOKEN)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 2c15664..8605a5e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,3 +22,4 @@ six==1.16.0
 typing_extensions==4.9.0
 urllib3==1.26.18
 yarl==1.9.4
+pytest~=7.4.4
\ No newline at end of file
diff --git a/test_main.py b/test_main.py
new file mode 100644
index 0000000..bee870a
--- /dev/null
+++ b/test_main.py
@@ -0,0 +1,556 @@
+import base64
+import datetime
+import os
+
+import nextcord
+from cryptography.hazmat.primitives.asymmetric import ed25519
+import main
+import pytest
+
+
+def is_base64_encoded(string):
+    try:
+        decoded = base64.b64decode(string)
+        if isinstance(decoded, bytes):
+            return True
+    except:
+        pass
+    return False
+
+
+class TestPaths:
+    @pytest.mark.parametrize("channel_id, guild_id, expected_path_loc", [
+        [
+            1194636226960035892,
+            1194636226041479298,
+            './38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e74deb2daee135e0bdfc98d276a43d758af37a12afa8cdf5c4ca4aa9af933aa6.loc'
+        ],
+        [
+            1194923425324617778,
+            1194636226041479298,
+            './38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f.loc'
+        ]
+    ])
+    def test_get_loc_path(self,mocker,channel_id, guild_id, expected_path_loc, monkeypatch):
+        monkeypatch.setenv("EPHEMERAL_PATH", "./")
+        channel_mock = mocker.Mock()
+        channel_mock.id = channel_id
+        channel_mock.guild.id = guild_id
+
+        path_loc, _ = main.get_loc_path(channel_mock)
+        assert path_loc == expected_path_loc
+
+    @pytest.mark.parametrize("channel_id, guild_id, expected_path_loc", [
+        [
+            1194636226960035892,
+            1194636226041479298,
+            'locations/38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e74deb2daee135e0bdfc98d276a43d758af37a12afa8cdf5c4ca4aa9af933aa6'
+        ],
+        [
+            1194923425324617778,
+            1194636226041479298,
+            'locations/38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f'
+        ]
+    ])
+    def test_get_loc_s3(self,mocker,channel_id, guild_id, expected_path_loc, monkeypatch):
+        monkeypatch.setenv("EPHEMERAL_PATH", "./")
+        channel_mock = mocker.Mock()
+        channel_mock.id = channel_id
+        channel_mock.guild.id = guild_id
+
+        _, s3_loc = main.get_loc_path(channel_mock)
+        assert s3_loc == expected_path_loc
+
+    @pytest.mark.parametrize("channel_id, guild_id, expected_path_loc", [
+        [
+            1194636226960035892,
+            1194636226041479298,
+            './38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e74deb2daee135e0bdfc98d276a43d758af37a12afa8cdf5c4ca4aa9af933aa6.manifest'
+        ],
+        [
+            1194923425324617778,
+            1194636226041479298,
+            './38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f.manifest'
+        ]
+    ])
+    def test_get_manifest_path(self,mocker,channel_id, guild_id, expected_path_loc):
+        main.EPHEMERAL_PATH = "./"
+        path_loc, _ = main.get_manifest_path(guild_id, channel_id)
+        assert path_loc == expected_path_loc
+
+    @pytest.mark.parametrize("channel_id, guild_id, expected_path_loc", [
+        [
+            1194636226960035892,
+            1194636226041479298,
+            'manifests/38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e74deb2daee135e0bdfc98d276a43d758af37a12afa8cdf5c4ca4aa9af933aa6'
+        ],
+        [
+            1194923425324617778,
+            1194636226041479298,
+            'manifests/38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f'
+        ]
+    ])
+    def test_get_manifest_s3(self,mocker,channel_id, guild_id, expected_path_loc):
+        main.EPHEMERAL_PATH = "./"
+        _, s3_loc = main.get_manifest_path(guild_id, channel_id)
+        assert s3_loc == expected_path_loc
+
+    @pytest.mark.parametrize("channel_id, guild_id, expected_path_loc", [
+        [
+            1194636226960035892,
+            1194636226041479298,
+            './38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e74deb2daee135e0bdfc98d276a43d758af37a12afa8cdf5c4ca4aa9af933aa6.seal'
+        ],
+        [
+            1194923425324617778,
+            1194636226041479298,
+            './38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f.seal'
+        ]
+    ])
+    def test_get_manifest_seal_path(self, mocker, channel_id, guild_id, expected_path_loc):
+        main.EPHEMERAL_PATH = "./"
+        path_loc, _ = main.get_manifest_seal_path(guild_id, channel_id)
+        assert path_loc == expected_path_loc
+
+    @pytest.mark.parametrize("channel_id, guild_id, expected_path_loc", [
+        [
+            1194636226960035892,
+            1194636226041479298,
+            'manifests/seals/38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e74deb2daee135e0bdfc98d276a43d758af37a12afa8cdf5c4ca4aa9af933aa6'
+        ],
+        [
+            1194923425324617778,
+            1194636226041479298,
+            'manifests/seals/38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f'
+        ]
+    ])
+    def test_get_manifest_seal_s3(self, mocker, channel_id, guild_id, expected_path_loc):
+        main.EPHEMERAL_PATH = "./"
+        _, s3_loc = main.get_manifest_seal_path(guild_id, channel_id)
+        assert s3_loc == expected_path_loc
+
+
+class TestConfiguration:
+
+    @pytest.mark.parametrize("s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint", [
+        ["1", "the-bucket", "12345-access", "12345-secret", "http://localhost:9000"],
+    ])
+    def test_successful_s3_init(self, monkeypatch, s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint):
+        # prep env
+        monkeypatch.setenv("S3_ENABLED", s3enabled)
+        monkeypatch.setenv("S3_BUCKET", s3bucket)
+        monkeypatch.setenv("S3_ACCESS_KEY", s3accesskey)
+        monkeypatch.setenv("S3_SECRET_KEY", s3secretkey)
+        monkeypatch.setenv("S3_ENDPOINT", s3endpoint)
+
+        s3_obj = main.init_s3()
+
+        assert s3_obj is not None
+        assert s3_obj.meta.endpoint_url == s3endpoint
+        assert s3_obj.meta.region_name == "us-east-1"
+
+    @pytest.mark.parametrize("s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint", [
+        ["0", "the-bucket", "12345-access", "12345-secret", "http://localhost:9000"],
+    ])
+    def test_disabled_s3_init(self, monkeypatch, s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint):
+        # prep env
+        monkeypatch.setenv("S3_ENABLED", s3enabled)
+        monkeypatch.setenv("S3_BUCKET", s3bucket)
+        monkeypatch.setenv("S3_ACCESS_KEY", s3accesskey)
+        monkeypatch.setenv("S3_SECRET_KEY", s3secretkey)
+        monkeypatch.setenv("S3_ENDPOINT", s3endpoint)
+
+        s3_obj = main.init_s3()
+
+        assert s3_obj is None
+
+    @pytest.mark.parametrize("s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint, expException", [
+        ["1", None, "12345-access", "12345-secret", "http://localhost:9000", "S3 enabled but S3_BUCKET not set."],
+        ["1", "bucket", None, "12345-secret", "http://localhost:9000", "S3 enabled but S3_ACCESS_KEY not set."],
+        ["1", "bucket", "12345-access", None, "http://localhost:9000", "S3 enabled but S3_SECRET_KEY not set."],
+        ["1", "bucket", "12345-access", "12345-secret", None, "S3 enabled but S3_ENDPOINT not set."],
+    ])
+    def test_failed_s3_init(self, monkeypatch, s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint, expException):
+        # prep env
+        monkeypatch.setenv("S3_ENABLED", s3enabled)
+        if s3bucket is not None:
+            monkeypatch.setenv("S3_BUCKET", s3bucket)
+        if s3accesskey is not None:
+            monkeypatch.setenv("S3_ACCESS_KEY", s3accesskey)
+        if s3secretkey is not None:
+            monkeypatch.setenv("S3_SECRET_KEY", s3secretkey)
+        if s3endpoint is not None:
+            monkeypatch.setenv("S3_ENDPOINT", s3endpoint)
+
+        with pytest.raises(Exception, match=expException):
+            main.init_s3()
+
+
+class TestMisc:
+    @pytest.mark.parametrize("hash_input, expected", [
+        ["hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"],
+    ])
+    def test_hash_string(self, hash_input, expected):
+        r_str, r_bytes = main.hash_string(hash_input)
+        assert r_str == expected
+
+    @pytest.mark.asyncio
+    async def test_last_message_id(self,monkeypatch,tmp_path,mocker):
+        msg_id = 123456543453
+        monkeypatch.setenv("EPHEMERAL_PATH", str(tmp_path.absolute()))
+        channel_mock = mocker.Mock()
+        channel_mock.id = 1234
+        channel_mock.guild.id = 5678
+
+        # Get empty id
+        empty_channel = await main.get_last_message_id(channel_mock)
+        assert empty_channel == -1
+
+        # store an id
+        await main.set_last_message_id(channel_mock, msg_id)
+
+        # recall id
+        non_empty_channel = await main.get_last_message_id(channel_mock)
+        assert non_empty_channel == msg_id
+
+class TestMessage:
+    def test_extract_basic_message(self, mocker):
+
+        utc_date = datetime.datetime.utcnow()
+        expected_iso_date = utc_date.isoformat()
+
+        mock_msg = mocker.Mock()
+        mock_msg.attachments = []
+        mock_msg.author.id = 12345
+        mock_msg.author.name = 'Mr. Test on Server'
+        mock_msg.author.global_name = 'Mr. Test is Global'
+        mock_msg.guild.id = 56789
+        mock_msg.guild.name = "The Servers Name"
+        mock_msg.channel.id = 13579
+        mock_msg.channel.name = "The Channel Name"
+        mock_msg.channel.category.name = "The Channel Category"
+        mock_msg.content = "Hello World! This is a fun test!"
+        mock_msg.created_at = utc_date
+
+        # make sure no parent exists
+        del mock_msg.channel.parent
+
+        message = main.extract_message(mock_msg)
+
+        # see if everything made it into the message
+        assert message['author']['id'] == mock_msg.author.id
+        assert message['author']['name'] == mock_msg.author.name
+        assert message['author']['global_name'] == mock_msg.author.global_name
+        assert message['server']['id'] == mock_msg.guild.id
+        assert message['server']['name'] == mock_msg.guild.name
+        assert message['channel']['id'] == mock_msg.channel.id
+        assert message['channel']['name'] == mock_msg.channel.name
+        assert message['category'] == mock_msg.channel.category.name
+        assert message['content'] == mock_msg.content
+        assert message['created_at'] == expected_iso_date
+        assert message['parent'] == ''
+        assert len(message['attachments']) == 0
+
+    def test_extract_message_with_parent(self, mocker):
+
+        utc_date = datetime.datetime.utcnow()
+        expected_iso_date = utc_date.isoformat()
+
+        mock_msg = mocker.Mock()
+        mock_msg.attachments = []
+        mock_msg.author.id = 12345
+        mock_msg.author.name = 'Mr. Test on Server'
+        mock_msg.author.global_name = 'Mr. Test is Global'
+        mock_msg.guild.id = 56789
+        mock_msg.guild.name = "The Servers Name"
+        mock_msg.channel.id = 13579
+        mock_msg.channel.name = "The Channel Name"
+        mock_msg.channel.category.name = "The Channel Category"
+        mock_msg.content = "Hello World! This is a fun test!"
+        mock_msg.created_at = utc_date
+        mock_msg.channel.parent.name = "Parent Channel"
+
+        message = main.extract_message(mock_msg)
+
+        # see if everything made it into the message
+        assert message['author']['id'] == mock_msg.author.id
+        assert message['author']['name'] == mock_msg.author.name
+        assert message['author']['global_name'] == mock_msg.author.global_name
+        assert message['server']['id'] == mock_msg.guild.id
+        assert message['server']['name'] == mock_msg.guild.name
+        assert message['channel']['id'] == mock_msg.channel.id
+        assert message['channel']['name'] == mock_msg.channel.name
+        assert message['category'] == mock_msg.channel.category.name
+        assert message['content'] == mock_msg.content
+        assert message['created_at'] == expected_iso_date
+        assert message['parent'] == mock_msg.channel.parent.name
+        assert len(message['attachments']) == 0
+
+    def test_extract_message_with_attachment(self, mocker):
+        # this is using a real attachment
+
+        utc_date = datetime.datetime.utcnow()
+        expected_iso_date = utc_date.isoformat()
+
+        mock_attach = mocker.Mock()
+        mock_attach.content_type = 'image/jpeg'
+        mock_attach.filename = 'some-image.jpg'
+        mock_attach.url = 'https://cdn.discordapp.com/attachments/1194636226960035892/1194637333845254144/8-Co4JPO3JEypnyGC.png?ex=65b113b7&is=659e9eb7&hm=abff144dfa9c75be4ec91130ffd35fbbf16ffd7a12d0373cd11f35418ab273d5&'
+
+        mock_msg = mocker.Mock()
+        mock_msg.attachments = [mock_attach]
+        mock_msg.author.id = 12345
+        mock_msg.author.name = 'Mr. Test on Server'
+        mock_msg.author.global_name = 'Mr. Test is Global'
+        mock_msg.guild.id = 56789
+        mock_msg.guild.name = "The Servers Name"
+        mock_msg.channel.id = 13579
+        mock_msg.channel.name = "The Channel Name"
+        mock_msg.channel.category.name = "The Channel Category"
+        mock_msg.content = "Hello World! This is a fun test!"
+        mock_msg.created_at = utc_date
+
+        message = main.extract_message(mock_msg)
+
+        # see if everything made it into the message
+        assert message['author']['id'] == mock_msg.author.id
+        assert message['author']['name'] == mock_msg.author.name
+        assert message['author']['global_name'] == mock_msg.author.global_name
+        assert message['server']['id'] == mock_msg.guild.id
+        assert message['server']['name'] == mock_msg.guild.name
+        assert message['channel']['id'] == mock_msg.channel.id
+        assert message['channel']['name'] == mock_msg.channel.name
+        assert message['category'] == mock_msg.channel.category.name
+        assert message['content'] == mock_msg.content
+        assert message['created_at'] == expected_iso_date
+        assert len(message['attachments']) == 1
+        assert message['attachments'][0]['type'] == mock_attach.content_type
+        assert message['attachments'][0]['origin_name'] == mock_attach.filename
+        assert is_base64_encoded(message['attachments'][0]['content'])
+
+
+class TestSigning:
+    def test_load_sign_key(self, tmp_path, monkeypatch):
+        key_file = tmp_path / 'priv_key.pem'
+        monkeypatch.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
+
+        # Should generate a new signing key
+        priv_key = main.get_signing_key()
+
+        # check ifs an ED25519 private key
+        assert isinstance(priv_key, ed25519.Ed25519PrivateKey)
+
+        # check if the pubkey was generated
+        assert os.path.exists(str(key_file.absolute()).replace('.pem','.pub'))
+
+        # check if the perms are correct
+        key_file_stat = os.stat(str(key_file.absolute()))
+        assert key_file_stat.st_mode == 0o100600
+
+        # Should load the key from disk
+        priv_key_second = main.get_signing_key()
+
+        # should be the same key - compare pubkeys
+        assert priv_key.public_key() == priv_key_second.public_key()
+
+    @pytest.mark.parametrize('filename, expException',[
+        ['key.foo', 'Signing key file not a pem file. make sure the extension is pem.'],
+        [None, 'Signing key not configured. Please set SIGN_KEY_PEM.']
+    ])
+    def test_load_sign_key_fail(self, tmp_path, monkeypatch, filename, expException):
+        if filename is not None:
+            key_file = tmp_path / filename
+            monkeypatch.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
+
+        # Should generate a new signing key
+        with pytest.raises(Exception, match=expException):
+            main.get_signing_key()
+
+
+class TestStorage:
+    def test_write_to_storage_fail_encrypt_no_keys(self, tmp_path, monkeypatch):
+        # message doesn't need to conform to format completely
+        dt_iso = datetime.datetime.utcnow().isoformat()
+        msg = {
+            'content': 'hello world',
+            'author': 'Testy McTestface',
+            'server': {
+                'id': 1234
+            },
+            'channel': {
+                'id': 1234
+            },
+            'created_at': dt_iso
+        }
+
+        # set the ephemeral path
+        monkeypatch.setenv('EPHEMERAL_PATH', str(tmp_path))
+
+        # No GPG keys loaded so it should fail here
+        with pytest.raises(Exception, match="No recipients specified with asymmetric encryption"):
+            main.write_to_storage(msg)
+
+    def test_write_to_storage(self, tmp_path, monkeypatch):
+        id = 1234
+        hash_for_id = "03ac674216f3e15c761ee1a5e255f067953623c8b388b4459e13f978d7c846f4"
+
+        # message doesn't need to conform to format completely
+        dt_iso = datetime.datetime.utcnow().isoformat()
+        msg = {
+            'content': 'hello world',
+            'author': 'Testy McTestface',
+            'server': {
+                'id': id
+            },
+            'channel': {
+                'id': id
+            },
+            'created_at': dt_iso
+        }
+
+        # set the ephemeral path
+        monkeypatch.setenv('EPHEMERAL_PATH', str(tmp_path))
+
+        # generate a signing key
+        key_file = tmp_path / 'priv_key.pem'
+        monkeypatch.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
+        main.signing_key = main.get_signing_key()
+
+        # load a GPG key
+        pub_key = '''-----BEGIN PGP PUBLIC KEY BLOCK-----
+
+mDMEZZ6XXxYJKwYBBAHaRw8BAQdAea3323zBNgy12RVKkCWWgfDe5vSLW3R9/6LS
+pqE/hxG0MUdQRyB0ZXN0IGtleSAoT05MWSBGT1IgVEVTVElORykgPG1hcmt1c0B0
+ZXN0Lm9yZz6IkwQTFgoAOxYhBMuek9p7pwAmbyYg1qWXo028DaarBQJlnpdfAhsD
+BQsJCAcCAiICBhUKCQgLAgQWAgMBAh4HAheAAAoJEKWXo028DaarUU8BAOyAmxed
+yWBHajYaEoyn0wfSEGIFVCXatsvcbYpL6hc+AQCrn/t+oC/OqrO4HWPhQDAEgYtW
+9TWOC3A6CYyodYdPD7g4BGWel18SCisGAQQBl1UBBQEBB0DLccDTMTVh0a7Su94Z
+ktDBAzTjYzQ5j2sxKe/OkK2VGQMBCAeIeAQYFgoAIBYhBMuek9p7pwAmbyYg1qWX
+o028DaarBQJlnpdfAhsMAAoJEKWXo028DaarD+EA/0SIgap5bj9FqE+TwVNILLuO
+UiwX/3AQaMi36RJ9oZYKAP9gIkwaL/m0Xu8WQiUNkATCHFsmauptqQw5V8GkSp0l
+Ag==
+=IhBg
+-----END PGP PUBLIC KEY BLOCK-----
+'''
+
+        monkeypatch.setenv('GPG_KEY_DIR', str(tmp_path.absolute()))
+        gpg_pub_key = tmp_path / 'gpg_pub_key.asc'
+        gpg_pub_key.write_text(pub_key)
+
+        main.key_fingerprints = main.load_gpg_keys()
+
+        main.write_to_storage(msg)
+
+        # a manifest should be in the tmp path now
+        result_manifest_path = tmp_path / f'{hash_for_id}-{hash_for_id}.manifest'
+        assert os.path.exists(result_manifest_path)
+
+        # read manifest
+        with open(result_manifest_path, 'r') as f:
+            manifest_content = f.read()
+        manifest_fields = manifest_content.split(',')
+
+        # we expect 3 fields in the manifest
+        assert len(manifest_fields) == 3
+
+        # we expect the first to be the iso date of the message
+        assert manifest_fields[0] == dt_iso
+
+        # check if the msg file exists
+        msg_hash = manifest_fields[1]
+
+        # Check if the filename is correct
+        msg_path = tmp_path / f'{msg_hash}.msg'
+        assert os.path.exists(str(msg_path.absolute()))
+
+    def test_write_directory_file(self, tmp_path, monkeypatch, mocker):
+        # disable the test for nextcord.TextChannel
+        mocker.patch('__main__.isinstance', return_value=True)
+
+        # set the ephemeral path
+        monkeypatch.setenv('EPHEMERAL_PATH', str(tmp_path))
+
+        # generate a signing key
+        key_file = tmp_path / 'priv_key.pem'
+        monkeypatch.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
+        main.signing_key = main.get_signing_key()
+
+        # load a GPG key
+        pub_key = '''-----BEGIN PGP PUBLIC KEY BLOCK-----
+
+        mDMEZZ6XXxYJKwYBBAHaRw8BAQdAea3323zBNgy12RVKkCWWgfDe5vSLW3R9/6LS
+        pqE/hxG0MUdQRyB0ZXN0IGtleSAoT05MWSBGT1IgVEVTVElORykgPG1hcmt1c0B0
+        ZXN0Lm9yZz6IkwQTFgoAOxYhBMuek9p7pwAmbyYg1qWXo028DaarBQJlnpdfAhsD
+        BQsJCAcCAiICBhUKCQgLAgQWAgMBAh4HAheAAAoJEKWXo028DaarUU8BAOyAmxed
+        yWBHajYaEoyn0wfSEGIFVCXatsvcbYpL6hc+AQCrn/t+oC/OqrO4HWPhQDAEgYtW
+        9TWOC3A6CYyodYdPD7g4BGWel18SCisGAQQBl1UBBQEBB0DLccDTMTVh0a7Su94Z
+        ktDBAzTjYzQ5j2sxKe/OkK2VGQMBCAeIeAQYFgoAIBYhBMuek9p7pwAmbyYg1qWX
+        o028DaarBQJlnpdfAhsMAAoJEKWXo028DaarD+EA/0SIgap5bj9FqE+TwVNILLuO
+        UiwX/3AQaMi36RJ9oZYKAP9gIkwaL/m0Xu8WQiUNkATCHFsmauptqQw5V8GkSp0l
+        Ag==
+        =IhBg
+        -----END PGP PUBLIC KEY BLOCK-----
+        '''
+
+        monkeypatch.setenv('GPG_KEY_DIR', str(tmp_path.absolute()))
+        gpg_pub_key = tmp_path / 'gpg_pub_key.asc'
+        gpg_pub_key.write_text(pub_key)
+
+        main.key_fingerprints = main.load_gpg_keys()
+
+        # the server list mock from discord
+        target_channels = []
+
+        channel1 = mocker.Mock(spec=nextcord.TextChannel)
+        channel1.guild.id = 1111
+        channel1.guild.name = "Server 1"
+        channel1.id = 110011
+        channel1.name = "Channel 1"
+        channel1.threads = []
+        target_channels.append(channel1)
+
+        channel2 = mocker.Mock(spec=nextcord.TextChannel)
+        channel2.guild.id = 1111
+        channel2.guild.name = "Server 1"
+        channel2.id = 220011
+        channel2.name = "Channel 2"
+        channel2.threads = []
+        target_channels.append(channel2)
+
+        channel3 = mocker.Mock(spec=nextcord.TextChannel)
+        channel3.guild.id = 2222
+        channel3.guild.name = "Server 2"
+        channel3.id = 220022
+        channel3.name = "Channel 1"
+        channel3.threads = []
+        target_channels.append(channel3)
+
+        # with threads
+        thread1 = mocker.Mock()
+        thread1.name = "Thread 1"
+        thread1.id = 121212
+
+        channel4 = mocker.Mock(spec=nextcord.TextChannel)
+        channel4.guild.id = 2222
+        channel4.guild.name = "Server 2"
+        channel4.id = 330022
+        channel4.name = "Channel 2"
+        channel4.threads = [thread1]
+        target_channels.append(channel4)
+
+        channel5 = mocker.Mock(spec=nextcord.VoiceChannel)
+        target_channels.append(channel5)
+
+
+        dt = datetime.datetime.now()
+        iso8601_format = dt.isoformat().replace(':', '-').replace('.', '-')
+
+        main.generate_directory_file(target_channels, dt)
+
+        # check if a directory was written
+        result_manifest_path = tmp_path / f'{iso8601_format}.dir'
+        assert os.path.exists(result_manifest_path)
+
+        # check if a directory seal was written
+        result_manifest_path = tmp_path / f'{iso8601_format}.dirseal'
+        assert os.path.exists(result_manifest_path)
\ No newline at end of file