Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
elasticroentgen committed Jan 17, 2024
1 parent fec80ae commit e786dcb
Show file tree
Hide file tree
Showing 4 changed files with 722 additions and 120 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/docker-publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,40 @@ env:


jobs:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10" ]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: "3.9"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-asyncio
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
pytest
build:
runs-on: ubuntu-latest
permissions:
Expand Down
251 changes: 131 additions & 120 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,63 +13,50 @@
import boto3
from botocore.exceptions import ClientError

# Read config from env
TOKEN = os.getenv('DISCORD_TOKEN')
if TOKEN is None:
raise Exception('No discord token set. Please set DISCORD_TOKEN.')

GPG_KEY_DIR = os.getenv('GPG_KEY_DIR')
if GPG_KEY_DIR is None:
raise Exception('No GPG key directory set. Please set GPG_KEY_DIR.')

EPHEMERAL_PATH = os.getenv('EPHEMERAL_PATH')
if EPHEMERAL_PATH is None:
print(f'EPHEMERAL_PATH not set. will dump all files in working dir.')
EPHEMERAL_PATH = '.'

S3_ENABLED = os.getenv('S3_ENABLED') == '1'
S3_BUCKET = os.getenv('S3_BUCKET')
S3_ACCESS_KEY = os.getenv('S3_ACCESS_KEY')
S3_SECRET_KEY = os.getenv('S3_SECRET_KEY')
S3_ENDPOINT = os.getenv('S3_ENDPOINT')
global S3_BUCKET
S3_ENABLED = False
key_fingerprints = []
gpg = GPG()

HEARTBEAT_URL = os.getenv('HEARTBEAT_URL')

# Send a start signal to heartbeat
try:
requests.get(HEARTBEAT_URL + "/start", timeout=5)
except requests.exceptions.RequestException:
# If the network request fails for any reason, we don't want
# it to prevent the main job from running
pass
def get_ephemeral_path():
ephemeral_path = os.getenv('EPHEMERAL_PATH')
if ephemeral_path is None:
ephemeral_path = '.'
return ephemeral_path

# Prepare discord connection
intents = nextcord.Intents.default()
intents.message_content = True
client = nextcord.Client(intents=intents)

s3 = None
def init_s3():
# Read Env
S3_ENABLED = os.getenv('S3_ENABLED') == '1'
S3_BUCKET = os.getenv('S3_BUCKET')
S3_ACCESS_KEY = os.getenv('S3_ACCESS_KEY')
S3_SECRET_KEY = os.getenv('S3_SECRET_KEY')
S3_ENDPOINT = os.getenv('S3_ENDPOINT')

if S3_ENABLED:
# Prepare S3 access
if S3_BUCKET is None:
raise Exception('S3 enabled but S3_BUCKET not set.')
# Setup
if S3_ENABLED:
# Prepare S3 access
if S3_BUCKET is None:
raise Exception('S3 enabled but S3_BUCKET not set.')

if S3_ACCESS_KEY is None:
raise Exception('S3 enabled but S3_ACCESS_KEY not set.')
if S3_ACCESS_KEY is None:
raise Exception('S3 enabled but S3_ACCESS_KEY not set.')

if S3_SECRET_KEY is None:
raise Exception('S3 enabled but S3_SECRET_KEY not set.')
if S3_SECRET_KEY is None:
raise Exception('S3 enabled but S3_SECRET_KEY not set.')

if S3_ENDPOINT is None:
raise Exception('S3 enabled but S3_ENDPOINT not set.')
if S3_ENDPOINT is None:
raise Exception('S3 enabled but S3_ENDPOINT not set.')

s3 = boto3.client(
service_name='s3',
aws_access_key_id=S3_ACCESS_KEY,
aws_secret_access_key=S3_SECRET_KEY,
endpoint_url=S3_ENDPOINT,
)
return boto3.client(
service_name='s3',
aws_access_key_id=S3_ACCESS_KEY,
aws_secret_access_key=S3_SECRET_KEY,
endpoint_url=S3_ENDPOINT,
)
else:
return None


def hash_string(msg):
Expand Down Expand Up @@ -165,7 +152,7 @@ def write_to_storage(backup_msg):
if S3_ENABLED:
s3.put_object(Bucket=S3_BUCKET, Key=f'messages/{enc_hash_str}', Body=str(enc_msg))
else:
with open(os.path.join(EPHEMERAL_PATH, f'{enc_hash_str}.msg'),'w') as msg_file:
with open(os.path.join(get_ephemeral_path(), f'{enc_hash_str}.msg'),'w') as msg_file:
msg_file.write(str(enc_msg))
print(f'Message written: {enc_hash_str}')

Expand All @@ -185,8 +172,6 @@ def get_signing_key():
if os.path.splitext(keyfile)[1] != '.pem':
raise Exception('Signing key file not a pem file. make sure the extension is pem.')


private_key = None
if os.path.exists(keyfile):
print(f'Loading signing key from {keyfile}...')
# To reload the key:
Expand Down Expand Up @@ -249,13 +234,13 @@ def seal_manifest(guild_id, channel_id):
def get_manifest_path(guild_id, channel_id):
channel_hash, _ = hash_string(str(channel_id))
guild_hash, _ = hash_string(str(guild_id))
return os.path.join(EPHEMERAL_PATH, f'{guild_hash}-{channel_hash}.manifest'), f'manifests/{guild_hash}-{channel_hash}'
return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.manifest'), f'manifests/{guild_hash}-{channel_hash}'


def get_manifest_seal_path(guild_id, channel_id):
channel_hash, _ = hash_string(str(channel_id))
guild_hash, _ = hash_string(str(guild_id))
return os.path.join(EPHEMERAL_PATH, f'{guild_hash}-{channel_hash}.seal'), f'manifests/seals/{guild_hash}-{channel_hash}'
return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.seal'), f'manifests/seals/{guild_hash}-{channel_hash}'


async def backup_channel(channel, last_message_id):
Expand Down Expand Up @@ -321,7 +306,7 @@ def get_loc_path(channel):
"""
guild_hash, _ = hash_string(str(channel.guild.id))
channel_hash, _ = hash_string(str(channel.id))
return os.path.join(EPHEMERAL_PATH, f'{guild_hash}-{channel_hash}.loc'), f'locations/{guild_hash}-{channel_hash}'
return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.loc'), f'locations/{guild_hash}-{channel_hash}'


async def get_last_message_id(channel):
Expand Down Expand Up @@ -370,15 +355,14 @@ async def set_last_message_id(channel, new_last_msg_id):
file.write(str(new_last_msg_id))


def generate_directory_file(target_channels):
def generate_directory_file(target_channels, current_datetime):
"""
This method generates a directory file based on the servers and channels the bot has access to.
see Readme about directory files for more info.
:param target_channels: A list of target channels to include in the directory file.
:return: None
"""
current_datetime = datetime.now()
iso8601_format = current_datetime.isoformat().replace(':', '-').replace('.', '-')

directory = []
Expand Down Expand Up @@ -431,7 +415,7 @@ def generate_directory_file(target_channels):
if S3_ENABLED:
s3.put_object(Bucket=S3_BUCKET, Key=f'directories/{iso8601_format}', Body=str(enc_msg))
else:
with open(os.path.join(EPHEMERAL_PATH, f'{iso8601_format}.dir'), 'w') as file:
with open(os.path.join(get_ephemeral_path(), f'{iso8601_format}.dir'), 'w') as file:
file.write(str(enc_msg))

# generate seal
Expand All @@ -442,91 +426,118 @@ def generate_directory_file(target_channels):
if S3_ENABLED:
s3.put_object(Bucket=S3_BUCKET, Key=f'directories/seals/{iso8601_format}', Body=man_signature)
else:
with open(os.path.join(EPHEMERAL_PATH, f'{iso8601_format}.dirseal'), 'w') as seal_file:
with open(os.path.join(get_ephemeral_path(), f'{iso8601_format}.dirseal'), 'w') as seal_file:
seal_file.write(man_signature)


@client.event
async def on_ready():
"""
This method is an event handler for the `on_ready` event in a Discord bot.
Used here to trigger the backup run after login.
:return: None
"""
print(f'Bot has logged in to discord as {client.user}')

# Grab servers and channels
target_channels = client.get_all_channels()
def load_gpg_keys():
print('Loading GPG keys...')
GPG_KEY_DIR = os.getenv('GPG_KEY_DIR')
if GPG_KEY_DIR is None:
raise Exception('No GPG key directory set. Please set GPG_KEY_DIR.')
key_files = glob.glob(os.path.join(GPG_KEY_DIR, '*.asc'))
imported_keys = [gpg.import_keys_file(key_file) for key_file in key_files]

# Build directory file
generate_directory_file(client.get_all_channels())
# Trust keys
# Set trust for imported keys
for key in imported_keys:
keyid = key.fingerprints[0]
trust_result = gpg.trust_keys([keyid], 'TRUST_ULTIMATE')
print(f'GPG key imported and trusted: {keyid} => {trust_result}')

# Backup channels
for channel in target_channels:
# Only intressted in text channels
if not isinstance(channel, nextcord.TextChannel):
continue
# Get the fingerprints of the imported keys
return [result.fingerprints[0] for result in imported_keys]

print(f'Backing up Channel {channel.id} on {channel.guild.id}')

# Backup channels
try:
last_msg_id = await get_last_message_id(channel)
new_last_msg_id = await backup_channel(channel, last_msg_id)
await set_last_message_id(channel, new_last_msg_id)
if __name__ == '__main__':
# Main starting
print('EF Backup Bot starting...')

except Exception as e:
print(f'Unable to backup: {e}')
# Read config from env
TOKEN = os.getenv('DISCORD_TOKEN')
if TOKEN is None:
raise Exception('No discord token set. Please set DISCORD_TOKEN.')

# Backup threads in channel
for thread in channel.threads:
print(f'Backing up Thread {thread.id} in Channel {channel.id} on {channel.guild.id}')
HEARTBEAT_URL = os.getenv('HEARTBEAT_URL')

try:
last_msg_id = await get_last_message_id(thread)
new_last_msg_id = await backup_channel(thread, last_msg_id)
await set_last_message_id(thread, new_last_msg_id)

except Exception as e:
print(f'Unable to backup: {e}')

# Quit when done
print('Notifying the heartbeat check...')
# Send a start signal to heartbeat
try:
requests.get(HEARTBEAT_URL, timeout=10)
requests.get(HEARTBEAT_URL + "/start", timeout=5)
except requests.exceptions.RequestException:
# If the network request fails for any reason, we don't want
# it to prevent the main job from running
pass

print('Done. exiting.')
await client.close()
# init S3
s3 = init_s3()

# prepare gpg
key_fingerprints = load_gpg_keys()

# Main starting
print('EF Backup Bot starting...')
# load the signing key
signing_key = get_signing_key()

# prepare gpg
gpg = GPG()
# Prepare discord connection
intents = nextcord.Intents.default()
intents.message_content = True
client = nextcord.Client(intents=intents)

print('Loading GPG keys...')
key_files = glob.glob(os.path.join(GPG_KEY_DIR,'*.asc'))
imported_keys = [gpg.import_keys_file(key_file) for key_file in key_files]
@client.event
async def on_ready():
"""
This method is an event handler for the `on_ready` event in a Discord bot.
Used here to trigger the backup run after login.
# Trust keys
# Set trust for imported keys
for key in imported_keys:
keyid = key.fingerprints[0]
trust_result = gpg.trust_keys([keyid], 'TRUST_ULTIMATE')
print(f'GPG key imported and trusted: {keyid} => {trust_result}')
:return: None
"""
print(f'Bot has logged in to discord as {client.user}')

# Get the fingerprints of the imported keys
key_fingerprints = [result.fingerprints[0] for result in imported_keys]
# Grab servers and channels
target_channels = client.get_all_channels()

# load the signing key
signing_key = get_signing_key()
# Build directory file
generate_directory_file(client.get_all_channels(), datetime.now())

# Backup channels
for channel in target_channels:
# Only intressted in text channels
if not isinstance(channel, nextcord.TextChannel):
continue

print(f'Backing up Channel {channel.id} on {channel.guild.id}')

# Backup channels
try:
last_msg_id = await get_last_message_id(channel)
new_last_msg_id = await backup_channel(channel, last_msg_id)
await set_last_message_id(channel, new_last_msg_id)

except Exception as e:
print(f'Unable to backup: {e}')

# Backup threads in channel
for thread in channel.threads:
print(f'Backing up Thread {thread.id} in Channel {channel.id} on {channel.guild.id}')

try:
last_msg_id = await get_last_message_id(thread)
new_last_msg_id = await backup_channel(thread, last_msg_id)
await set_last_message_id(thread, new_last_msg_id)

except Exception as e:
print(f'Unable to backup: {e}')

# Quit when done
print('Notifying the heartbeat check...')
try:
requests.get(HEARTBEAT_URL, timeout=10)
except requests.exceptions.RequestException:
# If the network request fails for any reason, we don't want
# it to prevent the main job from running
pass

# run the bot
client.run(TOKEN)
print('Done. exiting.')
await client.close()

# run the bot - still in main
client.run(TOKEN)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ six==1.16.0
typing_extensions==4.9.0
urllib3==1.26.18
yarl==1.9.4
pytest~=7.4.4
Loading

0 comments on commit e786dcb

Please sign in to comment.