Skip to content

Commit

Permalink
various fixes and refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed Oct 31, 2023
1 parent 853e28b commit be6e5bf
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 74 deletions.
34 changes: 19 additions & 15 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import Literal, TypeVar, Optional, Union, cast, get_args, overload, Any, Tuple
from pathlib import Path

import dotenv
from structlog import get_logger
Expand All @@ -10,16 +11,15 @@
dotenv.load_dotenv()
NOT_PROVIDED = "__NOT_PROVIDED__"

module_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "modules")
module_dir = Path(__file__).parent / 'modules'


def get_all_modules() -> frozenset[str]:
modules = set()
for file_name in os.listdir(module_dir):
if file_name.endswith(".py") and file_name not in ("__init__.py", "module.py"):
modules.add(file_name[:-3])

return frozenset(modules)
return frozenset({
filename.stem
for filename in module_dir.glob('*.py')
if filename.suffix == '.py' and filename.name not in ('__init__.py', 'module.py')
})


ALL_STAMPY_MODULES = get_all_modules()
Expand Down Expand Up @@ -47,8 +47,7 @@ def getenv(env_var: str, default = NOT_PROVIDED) -> str:


def getenv_bool(env_var: str) -> bool:
e = getenv(env_var, default="UNDEFINED")
return e != "UNDEFINED"
return getenv(env_var, default="UNDEFINED") != "UNDEFINED"


# fmt:off
Expand All @@ -64,12 +63,12 @@ def getenv_unique_set(var_name: str, default: T) -> Union[frozenset[str], T]:...


def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozenset, T]:
l = getenv(var_name, default="EMPTY_SET").split(" ")
if l == ["EMPTY_SET"]:
var = getenv(var_name, default='')
if not var.strip():
return default
s = frozenset(l)
assert len(l) == len(s), f"{var_name} has duplicate members! {l}"
return s
items = var.split()
assert len(items) == len(set(items)), f"{var_name} has duplicate members! {sorted(items)}"
return frozenset(items)


maximum_recursion_depth = 30
Expand Down Expand Up @@ -151,6 +150,11 @@ def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozense
channel_whitelist: Optional[frozenset[str]]
disable_prompt_moderation: bool

## Flask settings
if flask_port := getenv('FLASK_PORT', '2300'):
flask_port = int(flask_port)
flask_address = getenv('FLASK_ADDRESS', "0.0.0.0")

is_rob_server = getenv_bool("IS_ROB_SERVER")
if is_rob_server:
# use robmiles server defaults
Expand Down Expand Up @@ -222,7 +226,7 @@ def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozense
bot_dev_roles = getenv_unique_set("BOT_DEV_ROLES", frozenset())
bot_dev_ids = getenv_unique_set("BOT_DEV_IDS", frozenset())
bot_control_channel_ids = getenv_unique_set("BOT_CONTROL_CHANNEL_IDS", frozenset())
bot_private_channel_id = getenv("BOT_PRIVATE_CHANNEL_ID")
bot_private_channel_id = getenv("BOT_PRIVATE_CHANNEL_ID", None)
bot_error_channel_id = getenv("BOT_ERROR_CHANNEL_ID", bot_private_channel_id)
# NOTE: Rob's invite/member management functions, not ported yet
member_role_id = getenv("MEMBER_ROLE_ID", default=None)
Expand Down
5 changes: 1 addition & 4 deletions modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,7 @@ def is_at_me(self, message: ServiceMessage) -> Union[str, Literal[False]]:
)
at_me = True

if at_me:
return text
else:
return False
return at_me and text

def get_guild_and_invite_role(self):
return get_guild_and_invite_role()
Expand Down
4 changes: 3 additions & 1 deletion servicemodules/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@


# TODO: store long responses temporarily for viewing outside of discord
def limit_text_and_notify(response: Response, why_traceback: list[str]) -> str:
def limit_text_and_notify(response: Response, why_traceback: list[str]) -> Union[str, Iterable]:
if isinstance(response.text, str):
wastrimmed = False
wastrimmed, text_to_return = limit_text(response.text, discordLimit)
if wastrimmed:
why_traceback.append(f"I had to trim the output from {response.module}")
return text_to_return
elif isinstance(response.text, (list, tuple)):
return response.text
return ""


Expand Down
70 changes: 40 additions & 30 deletions servicemodules/flask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask import Response as FlaskResponse
from collections.abc import Iterable
from config import TEST_RESPONSE_PREFIX, maximum_recursion_depth
from config import TEST_RESPONSE_PREFIX, maximum_recursion_depth, flask_port, flask_address
from flask import Flask, request
from modules.module import Response
from structlog import get_logger
Expand Down Expand Up @@ -46,29 +46,51 @@ def process_event(self) -> FlaskResponse:
Keys are currently defined in utilities.flaskutils
"""
if request.is_json:
message = request.get_json()
message[
"content"
] += " s" # This plus s should make it always trigger the is_at_me functions.
message = FlaskMessage.from_dict(request.get_json())
elif request.form:
message = FlaskMessage.from_dict(request.form)
else:
content = (
request.form.get("content") + " s"
) # This plus s should make it always trigger the is_at_me functions.
key = request.form.get("key")
modules = json.loads(
request.form.get("modules", json.dumps(list(self.modules.keys())))
)
message = {"content": content, "key": key, "modules": modules}
response = self.on_message(FlaskMessage(message))
return FlaskResponse("No data provided - aborting", 400)

try:
response = self.on_message(message)
except Exception as e:
response = FlaskResponse(str(e), 400)

log.debug(class_name, response=response, type=type(response))
return response

def process_list_modules(self) -> FlaskResponse:
return FlaskResponse(json.dumps(list(self.modules.keys())))

def _module_responses(self, message):
if message.modules is None:
message.modules = list(self.modules.keys())
elif not message.modules:
raise LookupError('No modules specified')

responses = [Response()]
for key, module in self.modules.items():
if key not in message.modules:
log.info(class_name, msg=f"# Skipping module: {key}")
continue # Skip this module if it's not requested.

log.info(class_name, msg=f"# Asking module: {module}")
response = module.process_message(message)
if response:
response.module = module
if response.callback:
response.confidence -= 0.001
responses.append(response)
return responses

def on_message(self, message: FlaskMessage) -> FlaskResponse:
if is_test_message(message.content) and self.utils.test_mode:
log.info(class_name, type="TEST MESSAGE", message_content=message.content)
elif self.utils.stampy_is_author(message):
for module in self.modules.values():
module.process_message_from_stampy(message)
return FlaskResponse("ok - if that's what I said", 200)

log.info(
class_name,
Expand All @@ -80,30 +102,18 @@ def on_message(self, message: FlaskMessage) -> FlaskResponse:
message_content=message.content,
)

responses = [Response()]
for key in self.modules:
if message.modules and key not in message.modules:
log.info(class_name, msg=f"# Skipping module: {key}")
continue # Skip this module if it's not requested.
module = self.modules[key]
log.info(class_name, msg=f"# Asking module: {module}")
response = module.process_message(message)
if response:
response.module = module
if response.callback:
response.confidence -= 0.001
responses.append(response)
responses = self._module_responses(message)

for i in range(maximum_recursion_depth):
responses = sorted(responses, key=(lambda x: x.confidence), reverse=True)

for response in responses:
args_string = ""
if response.callback:
args_string = ", ".join([a.__repr__() for a in response.args])
args_string = ", ".join([repr(a) for a in response.args])
if response.kwargs:
args_string += ", " + ", ".join(
[f"{k}={v.__repr__()}" for k, v in response.kwargs.items()]
[f"{k}={repr(v)}" for k, v in response.kwargs.items()]
)
log.info(
class_name,
Expand Down Expand Up @@ -159,7 +169,7 @@ def run(self):
app.add_url_rule(
"/list_modules", view_func=self.process_list_modules, methods=["GET"]
)
app.run(host="0.0.0.0", port=2300, debug=False)
app.run(host=flask_address, port=flask_port)

def stop(self):
exit()
Expand Down
7 changes: 7 additions & 0 deletions servicemodules/serviceConstants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,10 @@ def __hash__(self):


default_italics_mark = "*"


def italicise(text: str, message) -> str:
if not text.strip():
return text
im = service_italics_marks.get(message.service, default_italics_mark)
return f'{im}{text}{im}'
9 changes: 4 additions & 5 deletions stam.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,11 @@ def get_stampy_modules() -> dict[str, Module]:
loaded_module_filenames = set()

# filenames of modules that were skipped because not enabled
skipped_module_filenames = set(ALL_STAMPY_MODULES - enabled_modules)
skipped_module_filenames = ALL_STAMPY_MODULES - enabled_modules
if invalid_modules := enabled_modules - ALL_STAMPY_MODULES:
raise AssertionError(f"Non existent modules enabled!: {', '.join(invalid_modules)}")

for filename in enabled_modules:
if filename not in ALL_STAMPY_MODULES:
raise AssertionError(f"Module {filename} enabled but doesn't exist!")

log.info("import", filename=filename)
mod = __import__(f"modules.{filename}", fromlist=[filename])
log.info("import", module_name=mod)
Expand All @@ -60,7 +59,7 @@ def get_stampy_modules() -> dict[str, Module]:
# try instantiating it if it is a `Module`...
if isinstance(cls, type) and issubclass(cls, Module) and cls is not Module:
log.info("import Module Found", module_name=attr_name)
# unless it has a classmethod is_available, which in this particular situation returns False
# unless it has a staticmethod is_available, which in this particular situation returns False
if (
(is_available := getattr(cls, "is_available", None))
and callable(is_available)
Expand Down
52 changes: 37 additions & 15 deletions utilities/flaskutils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from servicemodules.serviceConstants import Services
from utilities.serviceutils import ServiceUser, ServiceServer, ServiceChannel, ServiceMessage
from typing import TYPE_CHECKING
import json
import threading
import time

from typing import TYPE_CHECKING
from utilities.serviceutils import ServiceUser, ServiceServer, ServiceChannel, ServiceMessage
from servicemodules.serviceConstants import Services
from servicemodules.discordConstants import wiki_feed_channel_id

if TYPE_CHECKING:
from servicemodules.flask import FlaskHandler
Expand Down Expand Up @@ -47,8 +48,7 @@ def kill_thread(event: threading.Event, thread: "FlaskHandler"):

class FlaskUser(ServiceUser):
def __init__(self, key: str):
id = str(key)
super().__init__("User", "User", id)
super().__init__("User", "User", str(key))


class FlaskServer(ServiceServer):
Expand All @@ -59,15 +59,37 @@ def __init__(self, key: str):


class FlaskChannel(ServiceChannel):
def __init__(self, server: FlaskServer):
super().__init__("Web Interface", "flask_api", server)
def __init__(self, server: FlaskServer, channel=None):
super().__init__("Web Interface", channel or "flask_api", server)


class FlaskMessage(ServiceMessage):
def __init__(self, msg):
self._message = msg
server = FlaskServer(msg["key"])
id = str(time.time())
service = Services.FLASK
super().__init__(id, msg["content"], FlaskUser(msg["key"]), FlaskChannel(server), service)
self.modules = msg["modules"]

@staticmethod
def from_dict(data):
key = data.get('key')
if not key:
raise ValueError('No key provided')

# FIXME: A very hacky way of allowing HTTP requests to claim to come from stampy
author = data.get('author')
if author == 'stampy':
author = FlaskUser(wiki_feed_channel_id)
else:
author = FlaskUser(key)

modules = data.get('modules')
if not modules:
raise ValueError('No modules provided')
if isinstance(modules, str):
modules = json.loads(modules)

msg = FlaskMessage(
content=data['content'],
service=Services.FLASK,
author=author,
channel=FlaskChannel(FlaskServer(key), data.get('channel')),
id=str(time.time()),
)
msg.modules = modules
return msg
9 changes: 5 additions & 4 deletions utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,14 @@ def stampy_is_author(self, message: ServiceMessage) -> bool:
return self.is_stampy(message.author)

def is_stampy(self, user: ServiceUser) -> bool:
if (
user.id == wiki_feed_channel_id
): # consider wiki-feed ID as stampy to ignore -- is it better to set a wiki user?
if not user:
return False
# consider wiki-feed ID as stampy to ignore -- is it better to set a wiki user?
if user.id == wiki_feed_channel_id:
return True
if self.discord_user:
return user == self.discord_user
if user.id == str(cast(discord.ClientUser, self.client.user).id):
if self.client.user and user.id == str(cast(discord.ClientUser, self.client.user).id):
self.discord_user = user
return True
return False
Expand Down

0 comments on commit be6e5bf

Please sign in to comment.