From 71c6c09b2135a8e265cff6494671cc3f057e2a6e Mon Sep 17 00:00:00 2001 From: Vadim Suharnikov Date: Fri, 28 Jun 2024 09:26:14 +0400 Subject: [PATCH] Rework of models and tests --- .github/workflows/python-package.yml | 2 +- .pre-commit-config.yaml | 2 +- .vscode/settings.json | 2 +- asyncord/base64_image.py | 15 +- .../client/applications/models/requests.py | 128 ++--- .../client/applications/models/responses.py | 16 +- asyncord/client/applications/resources.py | 24 +- asyncord/client/auth/models.py | 6 + asyncord/client/auth/resources.py | 9 +- asyncord/client/bans/models/responses.py | 5 + asyncord/client/bans/resources.py | 38 +- .../channels/models/requests/creation.py | 60 +-- .../channels/models/requests/updating.py | 53 +- asyncord/client/channels/models/responses.py | 2 +- asyncord/client/channels/resources.py | 44 +- asyncord/client/commands/models/requests.py | 52 +- .../emoji.py => emojis/models/responses.py} | 4 +- asyncord/client/emojis/resources.py | 24 +- .../client/guild_templates/models/requests.py | 14 +- asyncord/client/guild_templates/resources.py | 12 +- asyncord/client/guilds/models/requests.py | 115 ++-- asyncord/client/guilds/models/responses.py | 19 +- asyncord/client/guilds/resources.py | 25 +- asyncord/client/http/middleware/base.py | 2 +- asyncord/client/http/models.py | 2 +- .../client/interactions/models/requests.py | 44 +- asyncord/client/interactions/resources.py | 28 +- asyncord/client/invites/resources.py | 6 +- asyncord/client/members/resources.py | 34 +- .../messages/models/requests/components.py | 74 ++- .../messages/models/requests/messages.py | 58 +- .../messages/models/responses/messages.py | 4 +- asyncord/client/messages/resources.py | 21 +- asyncord/client/models/attachments.py | 45 +- asyncord/client/models/automoderation.py | 7 +- asyncord/client/models/permissions.py | 16 +- asyncord/client/polls/models/requests.py | 128 +++-- asyncord/client/polls/models/responses.py | 32 +- asyncord/client/polls/resources.py | 5 +- asyncord/client/reactions/resources.py | 19 +- asyncord/client/roles/resources.py | 11 +- .../scheduled_events/models/requests.py | 68 +-- asyncord/client/scheduled_events/resources.py | 2 +- .../client/stage_instances/models/requests.py | 6 +- asyncord/client/stage_instances/resources.py | 54 +- asyncord/client/stickers/models/requests.py | 9 +- asyncord/client/stickers/resources.py | 15 +- asyncord/client/threads/models/requests.py | 16 +- asyncord/client/threads/resources.py | 18 +- asyncord/client/users/models/requests.py | 6 +- asyncord/client/users/resources.py | 7 +- asyncord/client/webhooks/models/requests.py | 16 +- asyncord/client/webhooks/models/responces.py | 2 +- asyncord/client/webhooks/resources.py | 92 ++-- asyncord/client_hub.py | 11 +- asyncord/gateway/client/client.py | 24 +- asyncord/gateway/client/heartbeat.py | 29 +- asyncord/gateway/events/event_map.py | 2 + asyncord/gateway/events/guilds.py | 10 +- asyncord/gateway/message.py | 20 +- asyncord/typedefs.py | 45 +- asyncord/urls.py | 1 + pdm.lock | 140 +++-- pyproject.toml | 22 +- tests/client/__init__.py | 0 tests/client/test_interactions.py | 126 +++++ tests/client/test_rest.py | 43 ++ tests/conftest.py | 25 +- tests/gateway/conftest.py | 19 + tests/gateway/test_client.py | 501 ++++++++++++++++++ tests/gateway/test_connection_data.py | 35 ++ tests/gateway/test_dispatcher.py | 69 ++- tests/gateway/test_heartbeat.py | 203 +++++++ tests/gateway/test_heartbeat_factory.py | 89 ++++ tests/gateway/test_opcode_handlers.py | 152 ++++++ .../client/componenets/test_action_row.py | 60 +++ .../test_creation.py} | 142 ++--- .../client/componenets/test_emoji.py | 17 + tests/integration/client/conftest.py | 32 ++ .../client/messages/test_attachments.py | 2 +- .../client/messages/test_messages.py | 152 +++++- tests/integration/client/test_applications.py | 93 +++- tests/integration/client/test_bans.py | 62 ++- tests/integration/client/test_channels.py | 173 ++++-- tests/integration/client/test_emoji.py | 78 +-- .../client/test_guild_templates.py | 76 ++- tests/integration/client/test_guilds.py | 286 +++++++++- tests/integration/client/test_invites.py | 55 +- tests/integration/client/test_members.py | 70 ++- tests/integration/client/test_polls.py | 128 +++-- tests/integration/client/test_reactions.py | 110 ++-- tests/integration/client/test_roles.py | 41 +- .../client/test_scheduled_events.py | 162 +++++- .../client/test_stages_instances.py | 47 +- tests/integration/client/test_stickers.py | 140 +++-- tests/integration/client/test_users.py | 42 +- tests/integration/client/test_webhooks.py | 208 +++++--- tests/integration/conftest.py | 92 +++- tests/test_base64_image.py | 59 ++- tests/test_client_hub.py | 69 +++ tests/test_color.py | 12 + tests/test_heartbeat.py | 76 --- tests/test_http_client.py | 41 +- tests/test_hub.py | 30 -- tests/test_strflag.py | 98 ++++ 105 files changed, 4277 insertions(+), 1460 deletions(-) rename asyncord/client/{models/emoji.py => emojis/models/responses.py} (94%) create mode 100644 tests/client/__init__.py create mode 100644 tests/client/test_interactions.py create mode 100644 tests/client/test_rest.py create mode 100644 tests/gateway/conftest.py create mode 100644 tests/gateway/test_client.py create mode 100644 tests/gateway/test_connection_data.py create mode 100644 tests/gateway/test_heartbeat.py create mode 100644 tests/gateway/test_heartbeat_factory.py create mode 100644 tests/gateway/test_opcode_handlers.py create mode 100644 tests/integration/client/componenets/test_action_row.py rename tests/integration/client/{messages/test_components.py => componenets/test_creation.py} (64%) create mode 100644 tests/integration/client/componenets/test_emoji.py create mode 100644 tests/integration/client/conftest.py create mode 100644 tests/test_client_hub.py delete mode 100644 tests/test_heartbeat.py delete mode 100644 tests/test_hub.py create mode 100644 tests/test_strflag.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 05d2c67..7b02490 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -51,7 +51,7 @@ jobs: ASYNCORD_TEST_APP_ID: ${{ vars.ASYNCORD_TEST_APP_ID }} ASYNCORD_TEST_ROLE_ID: ${{ vars.ASYNCORD_TEST_ROLE_ID }} ASYNCORD_TEST_USER_TO_BAN: ${{ vars.ASYNCORD_TEST_USER_TO_BAN }} - ASYNCORD_TEST_STAGE_ID: ${{ vars.ASYNCORD_TEST_STAGE_ID }} ASYNCORD_TEST_WORKERS: ${{ vars.ASYNCORD_TEST_WORKERS }} + ASYNCORD_TEST_ROLE_TO_PRUNE: ${{ vars.ASYNCORD_TEST_ROLE_TO_PRUNE }} run: pdm run pytest -n $ASYNCORD_TEST_WORKERS diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d9345de..bc1e265 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.4 + rev: v0.4.10 hooks: # Run the linter. - id: ruff diff --git a/.vscode/settings.json b/.vscode/settings.json index 6a61a57..0086728 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -16,4 +16,4 @@ }, "editor.defaultFormatter": "charliermarsh.ruff" }, -} \ No newline at end of file +} diff --git a/asyncord/base64_image.py b/asyncord/base64_image.py index 7f900c4..61e8027 100644 --- a/asyncord/base64_image.py +++ b/asyncord/base64_image.py @@ -61,8 +61,6 @@ def build(cls, image_data: bytes | str, image_type: str | None = None) -> Self: encoded_image = base64.b64encode(image_data).decode() return cls(f'data:{image_type};base64, {encoded_image}') - raise ValueError('Invalid image data type') - @classmethod def from_file(cls, file_path: str | Path) -> Self: """Build Base64Image from file path. @@ -92,13 +90,17 @@ def validate(cls, value: bytes | str | Self) -> Self: Raises: ValueError: If value is not """ + if isinstance(value, cls): + return value + if isinstance(value, bytes | str): return cls.build(value) - if isinstance(value, cls): - return value + if isinstance(value, Path): + return cls.from_file(value) - raise ValueError('Invalid value type') + # This should never happen because of the pydantic schema + raise ValueError('Invalid value type') # pragma: no cover @classmethod def __get_pydantic_core_schema__( @@ -118,6 +120,7 @@ def __get_pydantic_core_schema__( schema = core_schema.union_schema([ core_schema.bytes_schema(), core_schema.str_schema(), + core_schema.is_instance_schema(Path), core_schema.is_instance_schema(cls), ]) @@ -143,7 +146,7 @@ def __str__(self) -> str: return self.image_data -Base64ImageInputType = Annotated[Base64Image | bytes | str, Base64Image] +Base64ImageInputType = Annotated[Base64Image | bytes | str | Path, Base64Image] """Base64Image input type for pydantic models. Base64ImageInput must validate and convert other types to Base64Image. diff --git a/asyncord/client/applications/models/requests.py b/asyncord/client/applications/models/requests.py index 06c5ae0..93d6e4a 100644 --- a/asyncord/client/applications/models/requests.py +++ b/asyncord/client/applications/models/requests.py @@ -1,8 +1,9 @@ """This module contains models related to Discord applications.""" from enum import IntEnum +from typing import Annotated -from pydantic import BaseModel, Field, ValidationInfo, field_validator +from pydantic import BaseModel, Field, HttpUrl, field_validator from asyncord.base64_image import Base64ImageInputType from asyncord.client.applications.models.common import ( @@ -13,6 +14,22 @@ from asyncord.client.models.permissions import PermissionFlag from asyncord.locale import LocaleInputType +__all__ = ( + 'ApplicationIntegrationType', + 'ApplicationIntegrationTypeConfig', + 'InstallParams', + 'UpdateApplicationRequest', + 'UpdateApplicationRoleConnectionMetadataRequest', +) + +APPLICATION_ALLOWED_TYPES = ( + ApplicationFlag.GATEWAY_PRESENCE_LIMITED + | ApplicationFlag.GATEWAY_GUILD_MEMBERS_LIMITED + | ApplicationFlag.GATEWAY_MESSAGE_CONTENT_LIMITED +) + +"""Allowed application uodate flags.""" + class InstallParams(BaseModel): """Application install parameters. @@ -73,13 +90,7 @@ class UpdateApplicationRequest(BaseModel): install_params: InstallParams | None = None """Settings for the app's default in-app authorization link, if enabled.""" - integration_types_config: ( - dict[ - ApplicationIntegrationType, - ApplicationIntegrationTypeConfig, - ] - | None - ) = None + integration_types_config: dict[ApplicationIntegrationType, ApplicationIntegrationTypeConfig] | None = None """Default scopes and permissions for each supported installation context. Value for each key is an integration type configuration object. @@ -94,53 +105,40 @@ class UpdateApplicationRequest(BaseModel): cover_image: Base64ImageInputType | None = None """Default rich presence invite cover image for the app.""" - interactions_endpoint_url: str | None = None + interactions_endpoint_url: HttpUrl | None = None """Interactions endpoint URL for the app.""" - tags: list[str] | None = None + tags: ( + Annotated[ + set[Annotated[str, Field(max_length=20)]], + Field(max_length=5), + ] + | None + ) = None """List of tags describing the content and functionality of the app. Maximum of 5 tags. Maximum of 20 characters per tag. """ - @field_validator('tags') - @classmethod - def validate_tags( - cls, - tags: list[str] | None, - field_info: ValidationInfo, - ) -> list[str] | None: - """Ensures that the length of tags is less than or equal to 5. - - And each tag is less than or equal to 20 characters. - """ - max_tags = 5 - max_tag_length = 20 - - if tags is not None: - if len(tags) > max_tags: - raise ValueError('Maximum of 5 tags allowed.') - for tag in tags: - if len(tag) > max_tag_length: - raise ValueError('Maximum of 20 characters per tag allowed.') - return tags - @field_validator('flags') @classmethod - def validate_flags( - cls, - tags: ApplicationFlag | None, - ) -> ApplicationFlag: - """Ensures that the flag is valid.""" - if tags is not None: - if tags not in { - ApplicationFlag.GATEWAY_PRESENCE_LIMITED, - ApplicationFlag.GATEWAY_GUILD_MEMBERS_LIMITED, - ApplicationFlag.GATEWAY_MESSAGE_CONTENT_LIMITED, - }: - raise ValueError('Invalid flag.') - return tags + def validate_flags(cls, flags: ApplicationFlag | None) -> ApplicationFlag | None: + """Ensures that the flag is one of the allowed types.""" + if not flags: + return None + + if (flags & APPLICATION_ALLOWED_TYPES) != flags: + err_msg = 'Invalid flag. Must be one of the following: ' + ', '.join( + [ + 'GATEWAY_PRESENCE_LIMITED', + 'GATEWAY_GUILD_MEMBERS_LIMITED', + 'GATEWAY_MESSAGE_CONTENT_LIMITED', + ], + ) + raise ValueError(err_msg) + + return flags class UpdateApplicationRoleConnectionMetadataRequest(BaseModel): @@ -153,48 +151,20 @@ class UpdateApplicationRoleConnectionMetadataRequest(BaseModel): type: ApplicationRoleConnectionMetadataType """Type of metadata value.""" - key: str + key: Annotated[str, Field(min_length=1, max_length=50, pattern=r'^[a-z0-9_]{1,50}$')] """Dictionary key for the metadata field. - Must be a - z, 0 - 9, or _ characters; + Must be a - z, 0 - 9, or _ characters. 1 - 50 characters. """ - name: str = Field(None, min_length=1, max_length=100) - """Name of the metadata field. - - (1 - 100 characters). - """ + name: Annotated[str, Field(min_length=1, max_length=100)] + """Name of the metadata field.""" name_localizations: dict[LocaleInputType, str] | None = None """Translations of the name.""" - description: str = Field(None, min_length=1, max_length=200) - """Description of the metadata field. - - (1 - 200 characters). - """ + description: Annotated[str, Field(min_length=1, max_length=200)] + """Description of the metadata field.""" description_localizations: dict[LocaleInputType, str] | None = None """Translations of the description.""" - - @field_validator('key') - @classmethod - def validate_key( - cls, - key: str, - field_info: ValidationInfo, - ) -> list[str] | None: - """Ensures that the length of key is 1 - 50 characters. - - And a - z, 0 - 9, or _ characters. - """ - max_length = 50 - allowed_symbols = set('abcdefghijklmnopqrstuvwxyz0123456789_') - - if not 1 <= len(key) <= max_length: - raise ValueError('Key length must be between 1 and 50 characters.') - - if not set(key).issubset(allowed_symbols): - raise ValueError('Key must contain only a - z, 0 - 9, or _ characters.') - - return key diff --git a/asyncord/client/applications/models/responses.py b/asyncord/client/applications/models/responses.py index 3f52e4d..cb1f4b6 100644 --- a/asyncord/client/applications/models/responses.py +++ b/asyncord/client/applications/models/responses.py @@ -20,6 +20,20 @@ from asyncord.locale import LocaleInputType from asyncord.snowflake import Snowflake +__all__ = ( + 'ApplicationCommandPermissionOut', + 'ApplicationOut', + 'ApplicationRoleConnectionMetadataOut', + 'ApplicationUserOut', + 'BotApplicationOut', + 'GuildApplicationCommandPermissionsOut', + 'InstallParamsOut', + 'InviteCreateEventApplication', + 'TeamMemberOut', + 'TeamMemberUserOut', + 'TeamOut', +) + class InstallParamsOut(BaseModel): """Application install parameters. @@ -264,7 +278,7 @@ class ApplicationOut(BaseModel): role_connections_verification_url: AnyHttpUrl | None = None """Application's default role connection verification url.""" - tags: list[str] = Field(default_factory=list, max_length=5) + tags: set[str] = Field(default_factory=list, max_length=5) """Tags describing the content and functionality of the application. Maximum of 5 tags. diff --git a/asyncord/client/applications/resources.py b/asyncord/client/applications/resources.py index 1c2f27e..caf8b4c 100644 --- a/asyncord/client/applications/resources.py +++ b/asyncord/client/applications/resources.py @@ -1,16 +1,24 @@ """This module contains the applications resource for the client.""" -from asyncord.client.applications.models.requests import ( - UpdateApplicationRequest, - UpdateApplicationRoleConnectionMetadataRequest, -) +from __future__ import annotations + +from typing import TYPE_CHECKING + from asyncord.client.applications.models.responses import ApplicationOut, ApplicationRoleConnectionMetadataOut from asyncord.client.commands.resources import CommandResource from asyncord.client.resources import APIResource -from asyncord.snowflake import SnowflakeInputType -from asyncord.typedefs import list_model +from asyncord.typedefs import CURRENT_USER, list_model from asyncord.urls import REST_API_URL +if TYPE_CHECKING: + from asyncord.client.applications.models.requests import ( + UpdateApplicationRequest, + UpdateApplicationRoleConnectionMetadataRequest, + ) + from asyncord.snowflake import SnowflakeInputType + +__all__ = ('ApplicationResource',) + class ApplicationResource(APIResource): """Represents the applications resource for the client. @@ -38,7 +46,7 @@ async def get_application(self) -> ApplicationOut: Returns: Application object. """ - resp = await self._http_client.get(url=self.apps_url / '@me') + resp = await self._http_client.get(url=self.apps_url / CURRENT_USER) return ApplicationOut.model_validate(resp.body) async def update_application( @@ -55,7 +63,7 @@ async def update_application( """ payload = application_data.model_dump(mode='json', exclude_unset=True) resp = await self._http_client.patch( - url=self.apps_url / '@me', + url=self.apps_url / CURRENT_USER, payload=payload, ) return ApplicationOut.model_validate(resp.body) diff --git a/asyncord/client/auth/models.py b/asyncord/client/auth/models.py index 721f78c..14ebf5b 100644 --- a/asyncord/client/auth/models.py +++ b/asyncord/client/auth/models.py @@ -10,6 +10,12 @@ from asyncord.client.users.models.responses import UserResponse from asyncord.typedefs import StrFlag +__all__ = ( + 'AuthorizationInfoApplication', + 'AuthorizationInfoResponse', + 'OAuthScope', +) + @enum.unique class OAuthScope(StrFlag): diff --git a/asyncord/client/auth/resources.py b/asyncord/client/auth/resources.py index 1384ad5..f20cc1a 100644 --- a/asyncord/client/auth/resources.py +++ b/asyncord/client/auth/resources.py @@ -4,11 +4,16 @@ https://discord.com/developers/docs/topics/oauth2 """ +from __future__ import annotations + from asyncord.client.applications.models.responses import ApplicationOut from asyncord.client.auth.models import AuthorizationInfoResponse from asyncord.client.resources import APIResource +from asyncord.typedefs import CURRENT_USER from asyncord.urls import REST_API_URL +__all__ = ('OAuthResource',) + class OAuthResource(APIResource): """Represents an OAuth2 resource. @@ -31,7 +36,7 @@ async def get_current_application_info(self) -> ApplicationOut: Reference: https://discord.com/developers/docs/topics/oauth2#get-current-bot-application-information """ - url = self.oauth_url / 'applications' / '@me' + url = self.oauth_url / 'applications' / CURRENT_USER resp = await self._http_client.get(url=url) return ApplicationOut.model_validate(resp.body) @@ -41,6 +46,6 @@ async def get_current_authorization_info(self) -> AuthorizationInfoResponse: Reference: https://discord.com/developers/docs/topics/oauth2#get-current-authorization-information """ - url = self.oauth_url / '@me' + url = self.oauth_url / CURRENT_USER resp = await self._http_client.get(url=url) return AuthorizationInfoResponse.model_validate(resp.body) diff --git a/asyncord/client/bans/models/responses.py b/asyncord/client/bans/models/responses.py index 462d67c..4dce1f2 100644 --- a/asyncord/client/bans/models/responses.py +++ b/asyncord/client/bans/models/responses.py @@ -5,6 +5,11 @@ from asyncord.client.users.models.responses import UserResponse from asyncord.snowflake import SnowflakeInputType +__all__ = ( + 'BanResponse', + 'BulkBanResponse', +) + class BanResponse(BaseModel): """Ban object. diff --git a/asyncord/client/bans/resources.py b/asyncord/client/bans/resources.py index a1b82c0..26465a4 100644 --- a/asyncord/client/bans/resources.py +++ b/asyncord/client/bans/resources.py @@ -2,14 +2,21 @@ from __future__ import annotations +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + from asyncord.client.bans.models.responses import BanResponse, BulkBanResponse -from asyncord.client.http.client import HttpClient from asyncord.client.http.headers import AUDIT_LOG_REASON from asyncord.client.resources import APIResource -from asyncord.snowflake import SnowflakeInputType from asyncord.typedefs import list_model from asyncord.urls import REST_API_URL +if TYPE_CHECKING: + from asyncord.client.http.client import HttpClient + from asyncord.snowflake import SnowflakeInputType + +__all__ = ('BanResource',) + class BanResource(APIResource): """Base class for ban resources. @@ -73,7 +80,7 @@ async def get_list( async def ban( self, user_id: SnowflakeInputType, - delete_message_days: int | None = None, + delete_message_seconds: int | None = None, reason: str | None = None, ) -> None: """Ban a user from a guild. @@ -83,19 +90,19 @@ async def ban( Args: user_id: ID of a user to ban. - delete_message_days: Number of days to delete messages for. - Should be between 0 and 7. Defaults to 0. + delete_message_seconds: number of seconds to delete messages for. + between 0 and 604800 (7 days). Defaults to 0. reason: Reason for banning the user. Defaults to None. """ url = self.bans_url / str(user_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} - if delete_message_days is not None: - payload = {'delete_message_days': delete_message_days} + if delete_message_seconds is not None: + payload = {'delete_message_seconds': delete_message_seconds} else: payload = None @@ -111,7 +118,7 @@ async def unban(self, user_id: SnowflakeInputType, reason: str | None = None) -> user_id: ID of the user to unban. reason: Reason for unbanning the user. Defaults to None. """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -120,7 +127,7 @@ async def unban(self, user_id: SnowflakeInputType, reason: str | None = None) -> async def bulk_ban( self, - user_ids: list[SnowflakeInputType], + user_ids: Sequence[SnowflakeInputType], delete_message_seconds: int | None = None, reason: str | None = None, ) -> BulkBanResponse: @@ -135,20 +142,19 @@ async def bulk_ban( between 0 and 604800 (7 days). Defaults to 0. reason: Reason for banning the users. Defaults to None. """ - url = self.bans_url + url = self.guilds_url / str(self.guild_id) / 'bulk-ban' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} - payload = {} - payload['user_ids'] = user_ids + payload: dict[str, Any] = { + 'user_ids': user_ids, + } if delete_message_seconds is not None: payload['delete_message_seconds'] = delete_message_seconds - else: - payload = None resp = await self._http_client.post( url=url, diff --git a/asyncord/client/channels/models/requests/creation.py b/asyncord/client/channels/models/requests/creation.py index b014b3a..33ef366 100644 --- a/asyncord/client/channels/models/requests/creation.py +++ b/asyncord/client/channels/models/requests/creation.py @@ -5,7 +5,7 @@ """ import logging -from typing import Any, Literal, Self +from typing import Annotated, Any, Literal, Self from pydantic import BaseModel, Field, model_validator @@ -131,20 +131,20 @@ class BaseCreateChannel(BaseModel): class CreateCategoryChannelRequest(BaseCreateChannel): """Data to create a guild category with.""" - type: Literal[ChannelType.GUILD_CATEGORY] = ChannelType.GUILD_CATEGORY + type: Literal[ChannelType.GUILD_CATEGORY] = ChannelType.GUILD_CATEGORY # type: ignore """Type of channel.""" class CreateTextChannelRequest(BaseCreateChannel): """Data to create a text channel with.""" - type: Literal[ChannelType.GUILD_TEXT] = ChannelType.GUILD_TEXT + type: Literal[ChannelType.GUILD_TEXT] = ChannelType.GUILD_TEXT # type: ignore """Type of channel.""" - topic: str | None = Field(None, max_length=1024) + topic: Annotated[str | None, Field(max_length=1024)] = None """Channel topic.""" - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission @@ -164,7 +164,7 @@ class CreateTextChannelRequest(BaseCreateChannel): after recent activity. """ - default_thread_rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + default_thread_rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message in a thread. Should be between 0 and 21600. Bots, as well as users with the permission @@ -175,10 +175,10 @@ class CreateTextChannelRequest(BaseCreateChannel): class CreateAnoncementChannelRequest(BaseCreateChannel): """Data to create an announcement channel with.""" - type: Literal[ChannelType.GUILD_ANNOUNCEMENT] = ChannelType.GUILD_ANNOUNCEMENT + type: Literal[ChannelType.GUILD_ANNOUNCEMENT] = ChannelType.GUILD_ANNOUNCEMENT # type: ignore """Type of channel.""" - topic: str | None = Field(None, max_length=1024) + topic: Annotated[str | None, Field(max_length=1024)] = None """Channel topic.""" parent_id: SnowflakeInputType | None = None @@ -194,7 +194,7 @@ class CreateAnoncementChannelRequest(BaseCreateChannel): after recent activity. """ - default_thread_rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + default_thread_rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message in a thread. Should be between 0 and 21600. Bots, as well as users with the permission @@ -205,13 +205,13 @@ class CreateAnoncementChannelRequest(BaseCreateChannel): class CreateForumChannelRequest(BaseCreateChannel): """Data to create a forum channel with.""" - type: Literal[ChannelType.GUILD_FORUM] = ChannelType.GUILD_FORUM + type: Literal[ChannelType.GUILD_FORUM] = ChannelType.GUILD_FORUM # type: ignore """Type of channel.""" - topic: str | None = Field(None, max_length=1024) + topic: Annotated[str | None, Field(max_length=1024)] = None """Channel topic.""" - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission @@ -236,7 +236,7 @@ class CreateForumChannelRequest(BaseCreateChannel): default_forum_layout: DefaultForumLayoutType | None = None """Default layout for the forum channel.""" - default_thread_rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + default_thread_rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message in a thread. Should be between 0 and 21600. Bots, as well as users with the permission @@ -247,13 +247,13 @@ class CreateForumChannelRequest(BaseCreateChannel): class CreateMediaChannelRequest(BaseCreateChannel): """Data to create a media channel with.""" - type: Literal[ChannelType.GUILD_MEDIA] = ChannelType.GUILD_MEDIA + type: Literal[ChannelType.GUILD_MEDIA] = ChannelType.GUILD_MEDIA # type: ignore """Type of channel.""" - topic: str | None = Field(None, max_length=1024) + topic: Annotated[str | None, Field(max_length=1024)] = None """Channel topic.""" - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission @@ -272,7 +272,7 @@ class CreateMediaChannelRequest(BaseCreateChannel): available_tags: list[Tag] | None = None """List of available tags for the media channel.""" - default_thread_rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + default_thread_rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message in a thread. Should be between 0 and 21600. Bots, as well as users with the permission @@ -283,10 +283,10 @@ class CreateMediaChannelRequest(BaseCreateChannel): class CreateVoiceChannelRequest(BaseCreateChannel): """Data to create a voice channel with.""" - type: Literal[ChannelType.GUILD_VOICE] = ChannelType.GUILD_VOICE + type: Literal[ChannelType.GUILD_VOICE] = ChannelType.GUILD_VOICE # type: ignore """Type of channel.""" - bitrate: int | None = Field(None, ge=MIN_BITRATE, le=MAX_BITRATE) + bitrate: Annotated[int | None, Field(ge=MIN_BITRATE, le=MAX_BITRATE)] = None """Bitrate (in bits) of the voice channel. For voice channels, normal servers can set bitrate up to 96000. @@ -296,13 +296,13 @@ class CreateVoiceChannelRequest(BaseCreateChannel): Bitrate can be set up to 64000. """ - user_limit: int | None = Field(None, ge=0, le=99) + user_limit: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """User limit of the voice channel. No limit if set to 0. """ - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission @@ -328,10 +328,10 @@ class CreateVoiceChannelRequest(BaseCreateChannel): class CreateStageChannelRequest(BaseCreateChannel): """Data to create a stage channel with.""" - type: Literal[ChannelType.GUILD_STAGE_VOICE] = ChannelType.GUILD_STAGE_VOICE + type: Literal[ChannelType.GUILD_STAGE_VOICE] = ChannelType.GUILD_STAGE_VOICE # type: ignore """Type of channel.""" - bitrate: int | None = Field(None, ge=MIN_BITRATE, le=MAX_BITRATE) + bitrate: Annotated[int | None, Field(ge=MIN_BITRATE, le=MAX_BITRATE)] = None """Bitrate (in bits) of the voice channel. For voice channels, normal servers can set bitrate up to 96000. @@ -340,13 +340,13 @@ class CreateStageChannelRequest(BaseCreateChannel): feature can set up to 384000. For stage channels. """ - user_limit: int | None = Field(None, ge=0, le=99) + user_limit: Annotated[int | None, Field(ge=0, le=99)] = None """User limit of the voice channel. No limit if set to 0. """ - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int | None, Field(ge=0, le=MAX_RATELIMIT)] = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission @@ -376,10 +376,10 @@ class ChannelInviteRequest(BaseModel): https://discord.com/developers/docs/resources/channel#create-channel-invite-json-params """ - max_age: int | None = Field(None, le=604800) + max_age: Annotated[int | None, Field(le=604800)] = None """Duration of invite in seconds before expiry, or 0 for never.""" - max_uses: int | None = Field(None, ge=0, le=100) + max_uses: Annotated[int | None, Field(ge=0, le=100)] = None """Max number of uses or 0 for unlimited.""" temporary: bool | None = None @@ -388,7 +388,7 @@ class ChannelInviteRequest(BaseModel): unique: bool | None = None """If true, don't try to reuse a similar invite. - (useful for creating many unique one time use invites). + Useful for creating many unique one time use invites. """ target_type: InviteTargetType | None = None @@ -397,13 +397,13 @@ class ChannelInviteRequest(BaseModel): target_user_id: SnowflakeInputType | None = None """Id of the user whose stream to display for this invite. - Required if target_type is 1. + Required if target_type is `InviteTargetType.STREAM`. """ target_application_id: SnowflakeInputType | None = None """Id of the embedded application to open for this invite. - Required if target_type is 2. + Required if target_type is `InviteTargetType.EMBEDDED_APPLICATION`. """ @model_validator(mode='after') diff --git a/asyncord/client/channels/models/requests/updating.py b/asyncord/client/channels/models/requests/updating.py index 26cbcc2..6e7d80b 100644 --- a/asyncord/client/channels/models/requests/updating.py +++ b/asyncord/client/channels/models/requests/updating.py @@ -4,9 +4,9 @@ https://discord.com/developers/docs/resources/channel """ -from typing import Literal +from typing import Annotated, Literal -from pydantic import BaseModel, Field, ValidationInfo, field_validator +from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator from asyncord.base64_image import Base64ImageInputType from asyncord.client.channels.models.common import ( @@ -24,6 +24,7 @@ Overwrite, Tag, ) +from asyncord.client.models.permissions import PermissionFlag from asyncord.snowflake import SnowflakeInputType __all__ = ( @@ -46,7 +47,7 @@ class BaseUpdateChannel(BaseModel): """Data to create a channel with.""" - name: str | None = Field(None, min_length=1, max_length=100) + name: Annotated[str, Field(min_length=1, max_length=100)] | None = None """Channel name.""" position: int | None = None @@ -74,20 +75,20 @@ class UpdateChannelRequest(BaseUpdateChannel): with the "NEWS" feature. """ - topic: str | None = Field(None, max_length=1024) + topic: Annotated[str, Field(max_length=1024)] | None = None """Character channel topic.""" nsfw: bool | None = None """Whether the channel is nsfw.""" - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int, Field(ge=0, le=MAX_RATELIMIT)] | None = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission manage_messages or manage_channel, are unaffected. """ - bitrate: int | None = Field(None, ge=MIN_BITRATE, le=MAX_BITRATE) + bitrate: Annotated[int, Field(ge=MIN_BITRATE, le=MAX_BITRATE)] | None = None """Bitrate (in bits) of the voice channel. For voice channels, normal servers can set bitrate up to 96000. @@ -96,7 +97,7 @@ class UpdateChannelRequest(BaseUpdateChannel): feature can set up to 384000. For stage channels. """ - user_limit: int | None = Field(None, ge=0, le=99) + user_limit: Annotated[int, Field(ge=0, le=99)] | None = None """User limit of the voice channel. No limit if set to 0. @@ -130,7 +131,7 @@ class UpdateChannelRequest(BaseUpdateChannel): default_reaction_emoji: DefaultReaction | None = None """Default reaction emoji for the forum channel.""" - default_thread_rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + default_thread_rate_limit_per_user: Annotated[int, Field(ge=0, le=MAX_RATELIMIT)] | None = None """Amount of seconds a user has to wait before sending another message in a thread. Should be between 0 and 21600. Bots, as well as users with the permission @@ -166,7 +167,7 @@ class UpdateGroupDMChannelRequest(BaseModel): https://discord.com/developers/docs/resources/channel#modify-channel-json-params-group-dm """ - name: str | None = Field(None, min_length=1, max_length=100) + name: Annotated[str, Field(min_length=1, max_length=100)] | None = None """Character channel name.""" icon: Base64ImageInputType | None = None @@ -199,7 +200,7 @@ class UpdateTextChannelRequest(BaseUpdateChannel): after recent activity. """ - default_thread_rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + default_thread_rate_limit_per_user: Annotated[int, Field(ge=0, le=MAX_RATELIMIT)] | None = None """Amount of seconds a user has to wait before sending another message in a thread. Should be between 0 and 21600. Bots, as well as users with the permission @@ -297,7 +298,7 @@ class UpdateMediaChannelRequest(BaseUpdateChannel): default_reaction_emoji: DefaultReaction | None = None """Default reaction emoji for the forum channel.""" - default_thread_rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + default_thread_rate_limit_per_user: Annotated[int, Field(ge=0, le=MAX_RATELIMIT)] | None = None """Amount of seconds a user has to wait before sending another message in a thread. Should be between 0 and 21600. Bots, as well as users with the permission @@ -387,18 +388,34 @@ class UpdateChannelPermissionsRequest(BaseModel): https://discord.com/developers/docs/resources/channel#edit-channel-permissions """ - allow: str | None = None - """The bitwise value of all allowed permissions.""" + type: Literal['role', 'member'] + """Type of the permission overwrite.""" - deny: str | None = None - """The bitwise value of all disallowed permissions.""" + allow: PermissionFlag | None = None + """The bitwise value of all allowed permissions. - type: Literal[0, 1] | None = None - """Type of the permission overwrite. + Default to None. - 0 for role, 1 for member. + If None, the value will be set 0 at discord's end. """ + deny: PermissionFlag | None = None + """The bitwise value of all disallowed permissions. + + Default to None. + + If None, the value will be set 0 at discord's end. + """ + + @field_serializer('type', when_used='json') + @classmethod + def serialize_type(cls, type_value: Literal['role', 'member']) -> int: + """Serialize type to number for JSON. + + 0 if role, 1 if member. + """ + return int(type_value == 'member') + class UpdateChannelPositionRequest(BaseModel): """Data to update a channel's position with. diff --git a/asyncord/client/channels/models/responses.py b/asyncord/client/channels/models/responses.py index e8ee96c..ea0216c 100644 --- a/asyncord/client/channels/models/responses.py +++ b/asyncord/client/channels/models/responses.py @@ -297,4 +297,4 @@ class FollowedChannelResponse(BaseModel): """Source channel id.""" webhook_id: Snowflake - """Created target webhook id.""" + """Created webhook id.""" diff --git a/asyncord/client/channels/resources.py b/asyncord/client/channels/resources.py index 2f9d11e..effec3d 100644 --- a/asyncord/client/channels/resources.py +++ b/asyncord/client/channels/resources.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING +from asyncord.client.channels.models.requests.updating import UpdateChannelPermissionsRequest from asyncord.client.channels.models.responses import ChannelResponse, FollowedChannelResponse from asyncord.client.guilds.models.responses import InviteResponse from asyncord.client.http.headers import AUDIT_LOG_REASON @@ -24,7 +25,6 @@ CreateChannelRequestType, ) from asyncord.client.channels.models.requests.updating import ( - UpdateChannelPermissionsRequest, UpdateChannelPositionRequest, UpdateChannelRequestType, ) @@ -111,7 +111,7 @@ async def create_channel( """ url = REST_API_URL / 'guilds' / str(guild_id) / 'channels' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -142,7 +142,7 @@ async def update( """ url = self.channels_url / str(channel_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -189,7 +189,7 @@ async def delete(self, channel_id: SnowflakeInputType, reason: str | None = None """ url = self.channels_url / str(channel_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -199,7 +199,7 @@ async def delete(self, channel_id: SnowflakeInputType, reason: str | None = None async def update_permissions( self, channel_id: SnowflakeInputType, - overwrite_id: SnowflakeInputType, + role_or_user_id: SnowflakeInputType, permission_data: UpdateChannelPermissionsRequest, reason: str | None = None, ) -> None: @@ -212,19 +212,18 @@ async def update_permissions( Args: channel_id: Channel id. - overwrite_id: Role or user id. + role_or_user_id: Role or user id. permission_data: The data to update the permissions with. reason: Reason for updating the permissions. """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} payload = permission_data.model_dump(mode='json', exclude_unset=True) - url = self.channels_url / str(channel_id) / 'permissions' / str(overwrite_id) - + url = self.channels_url / str(channel_id) / 'permissions' / str(role_or_user_id) await self._http_client.put( url=url, payload=payload, @@ -234,22 +233,22 @@ async def update_permissions( async def delete_permission( self, channel_id: SnowflakeInputType, - overwrite_id: SnowflakeInputType, + role_or_user_id: SnowflakeInputType, reason: str | None = None, ) -> None: """Delete a channel permission overwrite for a user or role in a channel. Args: channel_id: Channel id. - overwrite_id: Role or user id. + role_or_user_id: Role or user id. reason: Reason for deleting the permission. """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} - url = self.channels_url / str(channel_id) / 'permissions' / str(overwrite_id) + url = self.channels_url / str(channel_id) / 'permissions' / str(role_or_user_id) await self._http_client.delete(url=url, headers=headers) @@ -272,7 +271,7 @@ async def get_channel_invites(self, channel_id: SnowflakeInputType) -> list[Invi async def create_channel_invite( self, channel_id: SnowflakeInputType, - invite_request: ChannelInviteRequest | None = None, + invite_data: ChannelInviteRequest | None = None, reason: str | None = None, ) -> InviteResponse: """Create a new invite for a channel. @@ -282,7 +281,8 @@ async def create_channel_invite( Args: channel_id: Channel id. - invite_request: Data for creating the invite. + invite_data: Data for creating the invite. Default is None. + If None, a default invite will be created. reason: Reason for creating the invite. Returns: @@ -290,15 +290,15 @@ async def create_channel_invite( """ url = self.channels_url / str(channel_id) / 'invites' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} payload = {} - if invite_request: - payload = invite_request.model_dump(mode='json', exclude_unset=True) + if invite_data: + payload = invite_data.model_dump(mode='json', exclude_unset=True) resp = await self._http_client.post(url=url, payload=payload, headers=headers) return InviteResponse.model_validate(resp.body) @@ -306,17 +306,21 @@ async def create_channel_invite( async def follow_announcement_channel( self, channel_id: SnowflakeInputType, - webhook_channel_id: SnowflakeInputType, + target_channel_id: SnowflakeInputType, ) -> FollowedChannelResponse: """Follow an Announcement channel to send messages to a target channel. + Args: + channel_id: Channel id. + target_channel_id: Channel id to send messages to. + Reference: https://discord.com/developers/docs/resources/channel#follow-announcement-channel """ url = self.channels_url / str(channel_id) / 'followers' payload = { - 'webhook_channel_id': str(webhook_channel_id), + 'webhook_channel_id': str(target_channel_id), } resp = await self._http_client.post(url=url, payload=payload) diff --git a/asyncord/client/commands/models/requests.py b/asyncord/client/commands/models/requests.py index d4dd719..1e43809 100644 --- a/asyncord/client/commands/models/requests.py +++ b/asyncord/client/commands/models/requests.py @@ -36,10 +36,6 @@ 'CreateApplicationCommandRequest', ) -_APP_COMMAND_NAME_PATTERN: Final[str] = r'^[-_\p{L}\p{N}\p{sc=Deva}\p{sc=Thai}]{1,32}$' -_NameAnnotation = Annotated[str, Field(min_length=1, max_length=32, pattern=_APP_COMMAND_NAME_PATTERN)] -_DescriptionAnnotation = Annotated[str, Field(min_length=1, max_length=100)] - class ApplicationCommandOptionChoice(BaseModel): """Represents an option choice for a Discord application command. @@ -103,7 +99,7 @@ class ApplicationCommandSubCommandOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.SUB_COMMAND] = AppCommandOptionType.SUB_COMMAND + type: Literal[AppCommandOptionType.SUB_COMMAND] = AppCommandOptionType.SUB_COMMAND # type: ignore options: list[ApplicationCommandOption] | None = None """List of options for subcommand and subcommand group types.""" @@ -116,7 +112,7 @@ class ApplicationCommandSubCommandGroupOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.SUB_COMMAND_GROUP] = AppCommandOptionType.SUB_COMMAND_GROUP + type: Literal[AppCommandOptionType.SUB_COMMAND_GROUP] = AppCommandOptionType.SUB_COMMAND_GROUP # type: ignore options: list[ApplicationCommandOption] | None = None """List of options for subcommand and subcommand group types.""" @@ -129,18 +125,18 @@ class ApplicationCommandStringOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.STRING] = AppCommandOptionType.STRING + type: Literal[AppCommandOptionType.STRING] = AppCommandOptionType.STRING # type: ignore - choices: list[ApplicationCommandOptionChoice] | None = Field(None, max_length=25) + choices: _ChoiceType | None = None """List of choices for string and number types. Max length is 25. """ - min_length: int | None = Field(None, ge=0, le=6000) + min_length: Annotated[int | None, Field(ge=0, le=6000)] = None """Minimum length for the option.""" - max_length: int | None = Field(None, ge=0, le=6000) + max_length: Annotated[int | None, Field(ge=0, le=6000)] = None """Maximum length for the option.""" autocomplete: bool | None = None @@ -157,9 +153,9 @@ class ApplicationCommandIntegerOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.INTEGER] = AppCommandOptionType.INTEGER + type: Literal[AppCommandOptionType.INTEGER] = AppCommandOptionType.INTEGER # type: ignore - choices: list[ApplicationCommandOptionChoice] | None = Field(None, max_length=25) + choices: _ChoiceType | None = None """List of choices for string and number types. Max length is 25. @@ -185,7 +181,7 @@ class ApplicationCommandBooleanOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.BOOLEAN] = AppCommandOptionType.BOOLEAN + type: Literal[AppCommandOptionType.BOOLEAN] = AppCommandOptionType.BOOLEAN # type: ignore class ApplicationCommandUserOption(BaseApplicationCommandOption): @@ -195,7 +191,7 @@ class ApplicationCommandUserOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.USER] = AppCommandOptionType.USER + type: Literal[AppCommandOptionType.USER] = AppCommandOptionType.USER # type: ignore class ApplicationCommandChannelOption(BaseApplicationCommandOption): @@ -205,7 +201,7 @@ class ApplicationCommandChannelOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.CHANNEL] = AppCommandOptionType.CHANNEL + type: Literal[AppCommandOptionType.CHANNEL] = AppCommandOptionType.CHANNEL # type: ignore channel_types: list[ChannelType] | None = None """List of available channel types if the option type is a `CHANNEL`.""" @@ -218,7 +214,7 @@ class ApplicationCommandRoleOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.ROLE] = AppCommandOptionType.ROLE + type: Literal[AppCommandOptionType.ROLE] = AppCommandOptionType.ROLE # type: ignore class ApplicationCommandMentionableOption(BaseApplicationCommandOption): @@ -228,7 +224,7 @@ class ApplicationCommandMentionableOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.MENTIONABLE] = AppCommandOptionType.MENTIONABLE + type: Literal[AppCommandOptionType.MENTIONABLE] = AppCommandOptionType.MENTIONABLE # type: ignore class ApplicationCommandNumberOption(BaseApplicationCommandOption): @@ -238,9 +234,9 @@ class ApplicationCommandNumberOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.NUMBER] = AppCommandOptionType.NUMBER + type: Literal[AppCommandOptionType.NUMBER] = AppCommandOptionType.NUMBER # type: ignore - choices: list[ApplicationCommandOptionChoice] | None = Field(None, max_length=25) + choices: _ChoiceType | None = None """List of choices for string and number types. Max length is 25. @@ -266,7 +262,7 @@ class ApplicationCommandAttachmentOption(BaseApplicationCommandOption): https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-structure """ - type: Literal[AppCommandOptionType.ATTACHMENT] = AppCommandOptionType.ATTACHMENT + type: Literal[AppCommandOptionType.ATTACHMENT] = AppCommandOptionType.ATTACHMENT # type: ignore class CreateApplicationCommandRequest(BaseModel): @@ -300,7 +296,7 @@ class CreateApplicationCommandRequest(BaseModel): description_localizations: dict[LocaleInputType, _DescriptionAnnotation] | None = None """Dictionary of language codes to localized descriptions. Defaults to None.""" - options: list[ApplicationCommandOption] | None = Field(None, max_length=25) + options: Annotated[list[ApplicationCommandOption], Field(max_length=25)] | None = None """List of options for the command. Must be 0-25 long. Defaults to None. @@ -350,3 +346,17 @@ def validate_options( | ApplicationCommandNumberOption | ApplicationCommandAttachmentOption ) + + +_APP_COMMAND_NAME_PATTERN: Final[str] = r'^[-_\p{L}\p{N}\p{sc=Deva}\p{sc=Thai}]{1,32}$' +"""Pattern for the application command name. + +The name must be 1-32 characters long. +It can contain letters, numbers, and the following characters: `-`, `_`. +""" +_NameAnnotation = Annotated[str, Field(min_length=1, max_length=32, pattern=_APP_COMMAND_NAME_PATTERN)] +"""Annotated name field for application commands.""" +_DescriptionAnnotation = Annotated[str, Field(min_length=1, max_length=100)] +"""Annotated description field for application commands.""" +_ChoiceType = Annotated[list[ApplicationCommandOptionChoice], Field(max_length=25)] +"""Annotated choice type for application commands.""" diff --git a/asyncord/client/models/emoji.py b/asyncord/client/emojis/models/responses.py similarity index 94% rename from asyncord/client/models/emoji.py rename to asyncord/client/emojis/models/responses.py index f3b0bbd..0861df4 100644 --- a/asyncord/client/models/emoji.py +++ b/asyncord/client/emojis/models/responses.py @@ -5,10 +5,10 @@ from asyncord.client.users.models.responses import UserResponse from asyncord.snowflake import Snowflake -__all__ = ('Emoji',) +__all__ = ('EmojiResponse',) -class Emoji(BaseModel): +class EmojiResponse(BaseModel): """Represents a custom emoji that can be used in messages. Reference: diff --git a/asyncord/client/emojis/resources.py b/asyncord/client/emojis/resources.py index 5125780..d3ad0ed 100644 --- a/asyncord/client/emojis/resources.py +++ b/asyncord/client/emojis/resources.py @@ -4,8 +4,8 @@ from typing import TYPE_CHECKING +from asyncord.client.emojis.models.responses import EmojiResponse from asyncord.client.http.headers import AUDIT_LOG_REASON -from asyncord.client.models.emoji import Emoji from asyncord.client.resources import APIResource from asyncord.typedefs import list_model from asyncord.urls import REST_API_URL @@ -40,7 +40,7 @@ def __init__( async def get_guild_emoji( self, emoji_id: SnowflakeInputType, - ) -> Emoji: + ) -> EmojiResponse: """Returns an emoji object for the given guild. Args: @@ -51,22 +51,22 @@ async def get_guild_emoji( """ url = self.emojis_url / str(emoji_id) resp = await self._http_client.get(url=url) - return Emoji.model_validate(resp.body) + return EmojiResponse.model_validate(resp.body) - async def get_guild_emojis(self) -> list[Emoji]: + async def get_guild_emojis(self) -> list[EmojiResponse]: """Returns a list of emoji objects for the given guild. Reference: https://discord.com/developers/docs/resources/emoji#list-guild-emojis """ resp = await self._http_client.get(url=self.emojis_url) - return list_model(Emoji).validate_python(resp.body) + return list_model(EmojiResponse).validate_python(resp.body) async def create_guild_emoji( self, emoji_data: CreateEmojiRequest, reason: str | None = None, - ) -> Emoji: + ) -> EmojiResponse: """Create a new emoji for the guild. Args: @@ -76,21 +76,21 @@ async def create_guild_emoji( Reference: https://discord.com/developers/docs/resources/emoji#create-guild-emoji """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} payload = emoji_data.model_dump(mode='json', exclude_unset=True) resp = await self._http_client.post(url=self.emojis_url, payload=payload, headers=headers) - return Emoji.model_validate(resp.body) + return EmojiResponse.model_validate(resp.body) async def update_guild_emoji( self, emoji_id: SnowflakeInputType, emoji_data: UpdateEmojiRequest, reason: str | None = None, - ) -> Emoji: + ) -> EmojiResponse: """Update the given emoji. Args: @@ -101,7 +101,7 @@ async def update_guild_emoji( Reference: https://discord.com/developers/docs/resources/emoji#modify-guild-emoji """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -109,7 +109,7 @@ async def update_guild_emoji( payload = emoji_data.model_dump(mode='json', exclude_unset=True) url = self.emojis_url / str(emoji_id) resp = await self._http_client.patch(url=url, payload=payload, headers=headers) - return Emoji.model_validate(resp.body) + return EmojiResponse.model_validate(resp.body) async def delete_guild_emoji( self, @@ -125,7 +125,7 @@ async def delete_guild_emoji( Reference: https://discord.com/developers/docs/resources/emoji#delete-guild-emoji """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} diff --git a/asyncord/client/guild_templates/models/requests.py b/asyncord/client/guild_templates/models/requests.py index ec0953a..5927acb 100644 --- a/asyncord/client/guild_templates/models/requests.py +++ b/asyncord/client/guild_templates/models/requests.py @@ -1,5 +1,7 @@ """Models for guild template requests.""" +from typing import Annotated + from pydantic import BaseModel, Field from asyncord.base64_image import Base64ImageInputType @@ -18,10 +20,10 @@ class CreateGuildFromTemplateRequest(BaseModel): https://discord.com/developers/docs/resources/guild-template#create-guild-from-guild-template-json-params """ - name: str = Field(None, min_length=2, max_length=100) + name: str | None = Field(min_length=2, max_length=100) """Name of the guild (2-100 characters).""" - icon: Base64ImageInputType | None + icon: Base64ImageInputType | None = None """Base64 encoded 128x128 image for the guild icon.""" @@ -32,10 +34,10 @@ class CreateGuildTemplateRequest(BaseModel): https://discord.com/developers/docs/resources/guild-template#create-guild-template-json-params """ - name: str = Field(None, min_length=1, max_length=100) + name: str = Field(min_length=1, max_length=100) """Name of the template (1-100 characters).""" - description: str | None = Field(None, max_length=120) + description: Annotated[str, Field(max_length=120)] | None = None """Description for the template (0-120 characters).""" @@ -46,8 +48,8 @@ class UpdateGuildTemplateRequest(BaseModel): https://discord.com/developers/docs/resources/guild-template#modify-guild-template-json-params """ - name: str | None = Field(None, min_length=1, max_length=100) + name: Annotated[str, Field(min_length=1, max_length=100)] | None = None """Name of the template (1-100 characters).""" - description: str | None = Field(None, max_length=120) + description: Annotated[str, Field(max_length=120)] | None = None """Description for the template (0-120 characters).""" diff --git a/asyncord/client/guild_templates/resources.py b/asyncord/client/guild_templates/resources.py index ef84e50..f5808c5 100644 --- a/asyncord/client/guild_templates/resources.py +++ b/asyncord/client/guild_templates/resources.py @@ -54,6 +54,7 @@ async def get_template( template_code: The template code. """ url = self.guilds_url / 'templates' / str(template_code) + resp = await self._http_client.get(url=url) return GuildTemplateResponse.model_validate(resp.body) @@ -80,9 +81,11 @@ async def create_guild_from_template( Reference: https://discord.com/developers/docs/resources/guild-template#create-guild-from-guild-template """ - url = self.templates_url / str(template_code) - payload = create_data.model_dump(mode='json', exclude_unset=True) + url = self.guilds_url / 'templates' / str(template_code) + + payload = create_data.model_dump(mode='json', exclude_none=True) resp = await self._http_client.post(url=url, payload=payload) + return GuildResponse.model_validate(resp.body) async def create_guild_template( @@ -98,6 +101,7 @@ async def create_guild_template( template_data: The template data. """ payload = template_data.model_dump(mode='json', exclude_unset=True) + resp = await self._http_client.post(url=self.templates_url, payload=payload) return GuildTemplateResponse.model_validate(resp.body) @@ -114,6 +118,7 @@ async def sync_guild_template( template_code: The template code. """ url = self.templates_url / str(template_code) + resp = await self._http_client.put(url=url) return GuildTemplateResponse.model_validate(resp.body) @@ -132,7 +137,9 @@ async def update_guild_template( template_data: The template data. """ url = self.templates_url / str(template_code) + payload = template_data.model_dump(mode='json', exclude_unset=True) + resp = await self._http_client.patch(url=url, payload=payload) return GuildTemplateResponse.model_validate(resp.body) @@ -151,5 +158,6 @@ async def delete_guild_template( template_code: The template code. """ url = self.templates_url / str(template_code) + resp = await self._http_client.delete(url=url) return GuildTemplateResponse.model_validate(resp.body) diff --git a/asyncord/client/guilds/models/requests.py b/asyncord/client/guilds/models/requests.py index bf53e32..5feb2d4 100644 --- a/asyncord/client/guilds/models/requests.py +++ b/asyncord/client/guilds/models/requests.py @@ -1,9 +1,9 @@ """Request models for guilds.""" -from typing import Self +from collections.abc import Sequence +from typing import Annotated, Any, Self -from fbenum.adapter import FallbackAdapter -from pydantic import BaseModel, Field, field_serializer, model_validator +from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator from asyncord.base64_image import Base64ImageInputType from asyncord.client.channels.models.requests.creation import CreateChannelRequestType @@ -128,7 +128,7 @@ class UpdateWelcomeScreenRequest(BaseModel): enabled: bool | None = None """Whether the welcome screen is enabled.""" - welcome_channels: list[WelcomeScreenChannel] | None = Field(None, max_length=5) + welcome_channels: Annotated[list[WelcomeScreenChannel], Field(max_length=5)] | None = None """Channels shown in the welcome screen. Up to 5 channels can be specified. @@ -169,13 +169,13 @@ class CreateAutoModerationRuleRequest(BaseModel): False by default. """ - exempt_roles: list[SnowflakeInputType] | None = Field(None, max_length=20) + exempt_roles: Annotated[list[SnowflakeInputType], Field(max_length=20)] | None = None """Role ids that should not be affected by the rule. Maximum of 20. """ - exempt_channels: list[SnowflakeInputType] | None = Field(None, max_length=50) + exempt_channels: Annotated[list[SnowflakeInputType], Field(max_length=50)] | None = None """Channel ids that should not be affected by the rule. Maximum of 50. @@ -206,36 +206,37 @@ class UpdateAutoModerationRuleRequest(BaseModel): https://discord.com/developers/docs/resources/auto-moderation#modify-auto-moderation-rule-json-params """ - name: str + name: str | None = None """Name of the rule.""" - event_type: FallbackAdapter[AutoModerationRuleEventType] + event_type: AutoModerationRuleEventType | None = None """Rule event type.""" trigger_metadata: TriggerMetadata | None = None """Rule trigger metadata. Required, but can be omited based on trigger type. + Reference: https://discord.com/developers/docs/resources/auto-moderation#auto-moderation-rule-object-trigger-metadata """ - actions: list[RuleAction] + actions: list[RuleAction] | None = None """Actions which will execute when the rule is triggered.""" - enabled: bool + enabled: bool | None = None """Whether the rule is enabled. False by default. """ - exempt_roles: list[SnowflakeInputType] = Field(max_length=20) + exempt_roles: Annotated[list[SnowflakeInputType], Field(max_length=20)] | None = None """Role ids that should not be affected by the rule. Maximum of 20. """ - exempt_channels: list[SnowflakeInputType] = Field(max_length=50) + exempt_channels: Annotated[list[SnowflakeInputType], Field(max_length=50)] | None = None """Channel ids that should not be affected by the rule. Maximum of 50. @@ -263,13 +264,19 @@ class OnboardingPromptOption(BaseModel): https://discord.com/developers/docs/resources/guild#guild-onboarding-object-prompt-option-structure """ - id: SnowflakeInputType + id: SnowflakeInputType | None = None """Id of the prompt option.""" - channel_ids: list[SnowflakeInputType] + title: str = Field(min_length=1, max_length=50) + """Title of the option.""" + + description: Annotated[str, Field(max_length=100)] | None = None + """Description of the option.""" + + channel_ids: list[SnowflakeInputType] | None = None """Id for channels a member is added to when the option is selected.""" - role_ids: list[SnowflakeInputType] + role_ids: set[SnowflakeInputType] | None = None """Id for roles assigned to a member when the option is selected.""" emoji_id: SnowflakeInputType | None = None @@ -281,11 +288,25 @@ class OnboardingPromptOption(BaseModel): emoji_animated: bool | None = None """Whether the emoji is animated.""" - title: str - """Title of the option.""" + @model_validator(mode='after') + def validate_emoji(self) -> Self: + """Check if emoji_id or emoji_name is provided, but not both.""" + if self.emoji_id and self.emoji_name: + raise ValueError('Emoji ID and name cannot be provided together.') - description: str | None = None - """Description of the option.""" + if self.emoji_animated and not self.emoji_id: + raise ValueError('Emoji animated cannot be provided without emoji ID.') + + return self + + @field_validator('channel_ids', 'role_ids', mode='before') + @classmethod + def validate_channel_role_ids(cls, channel_or_role_list: list[SnowflakeInputType]) -> list[SnowflakeInputType]: + """Validate channels and roles.""" + if not channel_or_role_list: + raise ValueError('Channel or role ids must be provided.') + + return channel_or_role_list class OnboardingPrompt(BaseModel): @@ -295,28 +316,28 @@ class OnboardingPrompt(BaseModel): https://discord.com/developers/docs/resources/guild#guild-onboarding-object-onboarding-prompt-structure """ - id: SnowflakeInputType + id: SnowflakeInputType | None = None """ID of the prompt.""" - type: OnboardingPromptType + type: OnboardingPromptType | None = None """Type of prompt.""" - options: list[OnboardingPromptOption] - """Options available within the prompt.""" - - title: str + title: str = Field(min_length=1, max_length=100) """Title of the prompt.""" - single_select: bool + options: list[OnboardingPromptOption] = Field(min_length=1, max_length=50) + """Options available within the prompt.""" + + single_select: bool | None = None """Indicates whether users are limited to selecting one option for the prompt.""" - required: bool + required: bool | None = None """Indicates whether the prompt is required. Before a user completes the onboarding flow. """ - in_boarding: bool + in_onboarding: bool | None = None """Indicates whether the prompt is present in the onboarding flow. If false, the prompt will only appear in the Channels & Roles tab. @@ -327,21 +348,43 @@ class UpdateOnboardingRequest(BaseModel): """Update onboarding settings. Reference: - https://discord.com/developers/docs/resources/guild#modify-guild-onboarding-json-params + https: // discord.com / developers / docs / resources / guild # modify-guild-onboarding-json-params """ - prompts: list[OnboardingPrompt] - """Prompts shown during onboarding and in customize community.""" + prompts: list[OnboardingPrompt] | None = None + """Prompts shown during onboarding and in customize community. + + If prompt has no ID, it will be set automatically to dummy because + the API requires it. + """ - default_channel_ids: list[SnowflakeInputType] + default_channel_ids: list[SnowflakeInputType] | None = None """Channel IDs that members get opted into automatically""" - enabled: bool + enabled: bool | None = None """Whether onboarding is enabled.""" - mode: OnboardingMode + mode: OnboardingMode | None = None """Current mode of onboarding.""" + @field_validator('prompts', mode='before') + @classmethod + def validate_prompts( + cls, + prompts: Sequence[OnboardingPrompt | dict[str, Any]], + ) -> Sequence[OnboardingPrompt | dict[str, Any]]: + """Set prompt ID if not set.""" + counter = 0 + for prompt in prompts: + if isinstance(prompt, dict): + if prompt.get('id'): + prompt['id'] = counter + elif not prompt.id: + prompt.id = counter + + counter += 1 # noqa: SIM113 + return prompts + class PruneRequest(BaseModel): """Data for pruning guild members. @@ -350,13 +393,13 @@ class PruneRequest(BaseModel): https://discord.com/developers/docs/resources/guild#begin-guild-prune """ - days: int + days: int | None = None """Number of days to prune members for.""" compute_prune_count: bool | None = None """Whether to compute the number of pruned members.""" - include_roles: set[SnowflakeInputType] | None = None + include_roles: Sequence[SnowflakeInputType] | None = None """Roles to include in the prune.""" @field_serializer('include_roles', when_used='json-unless-none', return_type=str) diff --git a/asyncord/client/guilds/models/responses.py b/asyncord/client/guilds/models/responses.py index 3ceb137..786417f 100644 --- a/asyncord/client/guilds/models/responses.py +++ b/asyncord/client/guilds/models/responses.py @@ -5,9 +5,12 @@ from fbenum.adapter import FallbackAdapter from pydantic import BaseModel +from yarl import URL +from asyncord import urls from asyncord.client.channels.models.common import ChannelType from asyncord.client.commands.models.responses import ApplicationCommandResponse +from asyncord.client.emojis.models.responses import EmojiResponse from asyncord.client.guilds.models.common import ( AuditLogEvents, ExpireBehaviorOut, @@ -17,7 +20,6 @@ OnboardingPromptType, ) from asyncord.client.models.automoderation import AutoModerationRule -from asyncord.client.models.emoji import Emoji from asyncord.client.models.stickers import Sticker from asyncord.client.roles.models.responses import RoleResponse from asyncord.client.scheduled_events.models.responses import ScheduledEventResponse @@ -153,7 +155,7 @@ class GuildResponse(BaseModel): roles: list[RoleResponse] | None = None """Roles in the guild.""" - emojis: list[Emoji] | None = None + emojis: list[EmojiResponse] | None = None """Custom guild emojis.""" features: list[str] @@ -271,7 +273,7 @@ class GuildPreviewResponse(BaseModel): Only present for guilds with the "DISCOVERABLE" feature. """ - emojis: list[Emoji] | None = None + emojis: list[EmojiResponse] | None = None """Custom guild emojis.""" features: list[str] @@ -300,7 +302,7 @@ class PruneResponse(BaseModel): https://discord.com/developers/docs/resources/guild#get-guild-prune-count """ - pruned: int + pruned: int | None """Number of members pruned.""" @@ -449,6 +451,11 @@ class InviteResponse(BaseModel): and contains a valid id. """ + @property + def url(self) -> URL: + """Invite URL.""" + return urls.INVITE_BASE_URL / self.code + class IntegrationAccountOut(BaseModel): """Integration Account Structure. @@ -845,7 +852,7 @@ class OnboardingPromptOptionOut(BaseModel): role_ids: list[Snowflake] """Id for roles assigned to a member when the option is selected.""" - emoji: Emoji | None = None + emoji: EmojiResponse | None = None """Emoji of the option.""" title: str @@ -882,7 +889,7 @@ class OnboardingPromptOut(BaseModel): Before a user completes the onboarding flow. """ - in_boarding: bool + in_onboarding: bool """Indicates whether the prompt is present in the onboarding flow. If false, the prompt will only appear in the Channels & Roles tab. diff --git a/asyncord/client/guilds/resources.py b/asyncord/client/guilds/resources.py index f314457..bf00e78 100644 --- a/asyncord/client/guilds/resources.py +++ b/asyncord/client/guilds/resources.py @@ -7,6 +7,7 @@ from __future__ import annotations import datetime +from collections.abc import Sequence from typing import TYPE_CHECKING from asyncord.base64_image import Base64Image @@ -35,7 +36,7 @@ from asyncord.client.resources import APIResource from asyncord.client.roles.resources import RoleResource from asyncord.client.scheduled_events.resources import ScheduledEventsResource -from asyncord.typedefs import list_model +from asyncord.typedefs import CURRENT_USER, list_model from asyncord.urls import REST_API_URL if TYPE_CHECKING: @@ -203,14 +204,14 @@ async def update_mfa(self, guild_id: SnowflakeInputType, level: MFALevel) -> MFA """ url = self.guilds_url / str(guild_id) / 'mfa' payload = {'level': level} - resp = await self._http_client.patch(url=url, payload=payload) + resp = await self._http_client.post(url=url, payload=payload) return MFALevel(int(resp.raw_body)) async def get_prune_count( self, guild_id: SnowflakeInputType, days: int | None = None, - include_roles: list[SnowflakeInputType] | None = None, + include_roles: Sequence[SnowflakeInputType] | None = None, reason: str | None = None, ) -> PruneResponse: """Get the number of members that would be removed from a guild if pruned. @@ -233,7 +234,7 @@ async def get_prune_count( if include_roles is not None: url_params['include_roles'] = ','.join(map(str, include_roles)) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -272,7 +273,7 @@ async def begin_prune( """ payload = prune_data.model_dump(mode='json', exclude_unset=True) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -363,10 +364,10 @@ async def delete_integration( """ url = self.guilds_url / str(guild_id) / 'integrations' / str(integration_id) - if reason is None: - headers = {} - else: + if reason: headers = {AUDIT_LOG_REASON: reason} + else: + headers = {} await self._http_client.delete(url=url, headers=headers) @@ -413,7 +414,7 @@ async def update_widget( """ url = self.guilds_url / str(guild_id) / 'widget' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -477,7 +478,7 @@ async def update_onboarding( """ url = self.guilds_url / str(guild_id) / 'onboarding' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -523,7 +524,7 @@ async def update_welcome_screen( """ url = self.guilds_url / str(guild_id) / 'welcome-screen' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -550,7 +551,7 @@ async def update_current_user_voice_state( suppress: Whether the current user should be suppressed. request_to_speak_timestamp: Time at which the current user requested to speak. """ - url = self.guilds_url / str(guild_id) / 'voice-states' / '@me' + url = self.guilds_url / str(guild_id) / 'voice-states' / CURRENT_USER payload = {} if channel_id is not None: diff --git a/asyncord/client/http/middleware/base.py b/asyncord/client/http/middleware/base.py index 52594ef..3d6b994 100644 --- a/asyncord/client/http/middleware/base.py +++ b/asyncord/client/http/middleware/base.py @@ -50,7 +50,7 @@ async def __call__( ... -class BaseMiddleware(ABC): +class BaseMiddleware(ABC): # pragma: no cover """Base middleware class.""" @abstractmethod diff --git a/asyncord/client/http/models.py b/asyncord/client/http/models.py index c356655..40661c8 100644 --- a/asyncord/client/http/models.py +++ b/asyncord/client/http/models.py @@ -107,7 +107,7 @@ def __iter__(self) -> Iterator[tuple[str, FormField]]: """Iterate over the form fields.""" yield from ((name, field) for name, field in self._fields.items()) - def __len__(self) -> int: + def __len__(self) -> int: # pragma: no cover """Return the number of fields in the form data.""" return len(self._fields) diff --git a/asyncord/client/interactions/models/requests.py b/asyncord/client/interactions/models/requests.py index 0a8162b..a82932e 100644 --- a/asyncord/client/interactions/models/requests.py +++ b/asyncord/client/interactions/models/requests.py @@ -22,10 +22,10 @@ from asyncord.client.messages.models.common import MessageFlags from asyncord.client.messages.models.requests.components import ActionRow, Component, TextInput from asyncord.client.messages.models.requests.embeds import Embed -from asyncord.client.messages.models.requests.messages import AllowedMentions, Attachment, BaseMessage +from asyncord.client.messages.models.requests.messages import AllowedMentions, BaseMessage +from asyncord.client.models.attachments import Attachment, AttachmentContentType __all__ = ( - 'INTERACTIONS_CAN_CONTAIN_FILES', 'InteractionRespAutocompleteRequest', 'InteractionRespDeferredMessageRequest', 'InteractionRespMessageRequest', @@ -63,10 +63,10 @@ class InteractionRespMessageRequest(BaseMessage): """ - tts: bool = False + tts: bool | None = None """True if this is a TTS message.""" - content: str | None = Field(None, max_length=2000) + content: Annotated[str | None, Field(max_length=2000)] = None """Message content.""" embeds: list[Embed] | None = None @@ -84,7 +84,7 @@ class InteractionRespMessageRequest(BaseMessage): components: Sequence[Component] | Component | None = None """List of components included in the message..""" - attachments: list[Attachment] | None = None + attachments: Sequence[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None """List of attachment objects with filename and description. Reference: @@ -103,10 +103,10 @@ class InteractionRespUpdateMessageRequest(BaseMessage): """ - tts: bool = False + tts: bool | None = None """True if this is a TTS message.""" - content: str | None = Field(None, max_length=2000) + content: Annotated[str | None, Field(max_length=2000)] = None """Message content.""" embeds: list[Embed] | None = None @@ -124,7 +124,7 @@ class InteractionRespUpdateMessageRequest(BaseMessage): components: Sequence[Component] | Component | None = None """List of components included in the message..""" - attachments: list[Attachment] | None = None + attachments: Sequence[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None """List of attachment objects with filename and description. Reference: @@ -142,7 +142,7 @@ class InteractionRespDeferredMessageRequest(BaseMessage): https://discord.com/developers/docs/interactions/receiving-and-responding#interaction-response-object-messages """ - tts: bool = False + tts: bool | None = None """True if this is a TTS message.""" content: Annotated[str | None, Field(max_length=2000)] = None @@ -163,7 +163,7 @@ class InteractionRespDeferredMessageRequest(BaseMessage): components: Sequence[Component] | Component | None = None """List of components.""" - attachments: list[Attachment] | None = None + attachments: Sequence[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None """Attachment objects with filename and description. Reference: @@ -200,7 +200,7 @@ class InteractionRespUpdateDeferredMessageRequest(BaseMessage): components: Sequence[Component] | Component | None = None """List of components.""" - attachments: list[Attachment] | None = None + attachments: Sequence[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None """Attachment objects with filename and description. Reference: @@ -305,14 +305,6 @@ def validate_components( - InteractionRespModalRequest: A request model for sending a Modal interaction response. """ -INTERACTIONS_CAN_CONTAIN_FILES = ( - InteractionRespMessageRequest - | InteractionRespUpdateMessageRequest - | InteractionRespDeferredMessageRequest - | InteractionRespUpdateDeferredMessageRequest -) -"""Interaction response types that can contain files.""" - class RootInteractionResponse(BaseModel): """Root interaction response request data. @@ -324,23 +316,26 @@ class RootInteractionResponse(BaseModel): data: InteractionResponseRequestType """Interaction response data.""" - type: InteractionResponseType | None = None + type: Annotated[InteractionResponseType | None, Field(validate_default=True)] = None """Interaction response type. This field always has a value, but it is not required to be provided. """ - @field_validator('type') + @field_validator('type', mode='before') @classmethod def validate_type( cls, - type_value: InteractionResponseType | None, + type_value: int | InteractionResponseType | None, field_info: ValidationInfo, ) -> InteractionResponseType: """Validate the type of interaction response request. If type is not provided, extract it from data. """ + if type_value is not None: + type_value = InteractionResponseType(type_value) + data_value = cast(InteractionResponseRequestType, field_info.data.get('data')) calculated_type = cls._calculate_data_type(data_value) @@ -350,6 +345,9 @@ def validate_type( if type_value != calculated_type: raise ValueError('Provided type is not valid for the given data') + # type was set correctly by user, just return it + # it's very rare event that we reach here + # because user shouldn't set type of root model and shouldn't use it in general return type_value @classmethod @@ -360,6 +358,8 @@ def _calculate_data_type(cls, data_value: InteractionResponseRequestType) -> Int match data_value: case InteractionRespMessageRequest(): calculated_type = InteractionResponseType.CHANNEL_MESSAGE_WITH_SOURCE + case InteractionRespUpdateMessageRequest(): + calculated_type = InteractionResponseType.UPDATE_MESSAGE case InteractionRespDeferredMessageRequest(): calculated_type = InteractionResponseType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE case InteractionRespUpdateDeferredMessageRequest(): diff --git a/asyncord/client/interactions/resources.py b/asyncord/client/interactions/resources.py index 95f271e..93ff332 100644 --- a/asyncord/client/interactions/resources.py +++ b/asyncord/client/interactions/resources.py @@ -6,14 +6,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from asyncord.client.interactions.models.requests import ( - INTERACTIONS_CAN_CONTAIN_FILES, InteractionRespPongRequest, RootInteractionResponse, ) -from asyncord.client.models.attachments import make_attachment_payload +from asyncord.client.models.attachments import Attachment, make_payload_with_attachments from asyncord.client.resources import APIResource from asyncord.snowflake import SnowflakeInputType from asyncord.urls import REST_API_URL @@ -45,15 +44,24 @@ async def send_response( """ url = self.interactions_url / str(interaction_id) / interaction_token / 'callback' if isinstance(interaction_response, InteractionRespPongRequest): - # Pong response doesn't have a data key - wraped_model = interaction_response + # Pong response doesn't have a data key, only type + # It's already as a root model, but without data key + root_model = interaction_response else: - wraped_model = RootInteractionResponse(data=interaction_response) + root_model = RootInteractionResponse(data=interaction_response) - if isinstance(interaction_response, INTERACTIONS_CAN_CONTAIN_FILES): - payload = make_attachment_payload(interaction_response, wraped_model) - else: - payload = interaction_response.model_dump(mode='json') + # if we have any attachments, we need to send them as form + # Attachment after model mapping is a list of Attachment objects alwayszs or None + attachments = cast( + list[Attachment] | None, + getattr(interaction_response, 'attachments', None), + ) + payload = make_payload_with_attachments( + root_model, + attachments=attachments, + exclude_unset=False, + exclude_none=True, + ) await self._http_client.post(url=url, payload=payload) diff --git a/asyncord/client/invites/resources.py b/asyncord/client/invites/resources.py index dcc9c2d..146e846 100644 --- a/asyncord/client/invites/resources.py +++ b/asyncord/client/invites/resources.py @@ -57,8 +57,8 @@ async def get_invite( query_param['with_counts'] = str(with_counts) if with_expiration is not None: query_param['with_expiration'] = str(with_expiration) - if guild_scheduled_event_id is not None: - query_param['guild_scheduled_event_id'] = guild_scheduled_event_id + if guild_scheduled_event_id: + query_param['guild_scheduled_event_id'] = str(guild_scheduled_event_id) url = self.invites_url / str(invite_code) % query_param @@ -75,7 +75,7 @@ async def delete_invite( Reference: https://discord.com/developers/docs/resources/invite#delete-invite """ - if reason is not None: + if reason: headers = {'X-Audit-Log-Reason': reason} else: headers = {} diff --git a/asyncord/client/members/resources.py b/asyncord/client/members/resources.py index 9b5e3ea..5f4864f 100644 --- a/asyncord/client/members/resources.py +++ b/asyncord/client/members/resources.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from asyncord.client.http.headers import AUDIT_LOG_REASON from asyncord.client.members.models.responses import MemberResponse @@ -48,7 +48,11 @@ async def get(self, user_id: SnowflakeInputType) -> MemberResponse: resp = await self._http_client.get(url=url) return MemberResponse.model_validate(resp.body) - async def get_list(self, limit: int | None = None, after: SnowflakeInputType | None = None) -> list[MemberResponse]: + async def get_list( + self, + limit: int | None = None, + after: SnowflakeInputType | None = None, + ) -> list[MemberResponse]: """List members of guild. This endpoint is restricted according to whether the GUILD_MEMBERS Privileged @@ -75,6 +79,9 @@ async def get_list(self, limit: int | None = None, after: SnowflakeInputType | N async def search(self, nick_or_name: str, limit: int | None = None) -> list[MemberResponse]: """Search members of a guild by username or nickname. + Reference: + https://discord.com/developers/docs/resources/guild#search-guild-members + Args: nick_or_name: Name or nickname of the member to search for. limit: Maximum number of members to return. @@ -89,13 +96,14 @@ async def search(self, nick_or_name: str, limit: int | None = None) -> list[Memb if limit is not None: url_params['limit'] = limit - url = self.members_url % url_params + url = self.members_url / 'search' % url_params resp = await self._http_client.get(url=url) return list_model(MemberResponse).validate_python(resp.body) async def update( self, - user_id: SnowflakeInputType, + *, + user_id: SnowflakeInputType | Literal['@me'], member_data: UpdateMemberRequest, reason: str | None = None, ) -> MemberResponse: @@ -110,15 +118,18 @@ async def update( The updated member. """ url = self.members_url / str(user_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} - resp = await self._http_client.patch(url=url, payload=member_data, headers=headers) + payload = member_data.model_dump(exclude_unset=True) + + resp = await self._http_client.patch(url=url, payload=payload, headers=headers) return MemberResponse(**resp.body) async def update_current_member( self, + *, nickname: str | None, reason: str | None = None, ) -> MemberResponse: @@ -133,7 +144,7 @@ async def update_current_member( The updated member. """ url = self.members_url / '@me' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -143,6 +154,7 @@ async def update_current_member( async def add_role( self, + *, user_id: SnowflakeInputType, role_id: SnowflakeInputType, reason: str | None = None, @@ -155,7 +167,7 @@ async def add_role( reason: Reason for adding the role to the member. Defaults to None. """ url = self.members_url / str(user_id) / 'roles' / str(role_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -163,6 +175,7 @@ async def add_role( async def remove_role( self, + *, user_id: SnowflakeInputType, role_id: SnowflakeInputType, reason: str | None = None, @@ -175,7 +188,7 @@ async def remove_role( reason: Reason for removing the role from the member. Defaults to None. """ url = self.members_url / str(user_id) / 'roles' / str(role_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -189,7 +202,7 @@ async def kick(self, user_id: SnowflakeInputType, reason: str | None = None) -> reason: Reason for kicking the member. Defaults to None. """ url = self.members_url / str(user_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -199,6 +212,7 @@ async def kick(self, user_id: SnowflakeInputType, reason: str | None = None) -> # Needs users oauth2 token. async def add_guild_member( self, + *, user_id: SnowflakeInputType, add_member_data: AddMemberRequest, ) -> MemberResponse | None: diff --git a/asyncord/client/messages/models/requests/components.py b/asyncord/client/messages/models/requests/components.py index a639587..ee4ea89 100644 --- a/asyncord/client/messages/models/requests/components.py +++ b/asyncord/client/messages/models/requests/components.py @@ -8,7 +8,7 @@ from collections import Counter from collections.abc import Sequence -from typing import Annotated, Any, Literal, Self +from typing import Annotated, Literal, Self from typing import get_args as get_typing_args from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator @@ -39,22 +39,29 @@ class BaseComponent(BaseModel): """Base component class.""" - type: ComponentType - """Type of the component.""" + type: ComponentType = None # type: ignore + """Type of the component. + + None value just helps to avoid a warning about the required field. + This field must be set in subclasses. + """ - def __init__(self, **data: Any) -> None: # noqa: ANN401 - """Initialize the component.""" - super().__init__(**data) - # Add `type` to `model_fields_set` to make `dict(exclude_unset)` work properly. - # We don't need to set 'type' field because it's already set in a component class, - # but we need to send it to Discord excluding another unset fields. + @model_validator(mode='after') + def set_type_field_set(self) -> Self: + """Set `type` field in `model_fields_set`. + + Add `type` to `model_fields_set` to make `dict(exclude_unset)` work properly. + We don't need to set 'type' field because it's already set in a component subclasses class, + but we need to send it to Discord excluding another unset fields. + """ self.model_fields_set.add('type') + return self class ComponentEmoji(BaseModel): """Emoji to be displayed on the button. - At least one of `name` or `id` must be provided. + At least one of `name` or `id` must be provided, and it can be only one of them. Name is used for unicode emojis, Id is a snowflake of custom emojis. """ @@ -74,6 +81,9 @@ def name_or_id_required(self) -> Self: if not self.name and not self.id: raise ValueError('At least one of `name` or `id` must be provided') + if self.name and self.id: + raise ValueError('Only one of `name` or `id` must be provided') + return self @@ -90,7 +100,7 @@ class Button(BaseComponent): https://discord.com/developers/docs/interactions/message-components#buttons """ - type: Literal[ComponentType.BUTTON] = ComponentType.BUTTON + type: Literal[ComponentType.BUTTON] = ComponentType.BUTTON # type: ignore """Type of the component. Only `ComponentType.BUTTON` is allowed. @@ -99,7 +109,7 @@ class Button(BaseComponent): style: ButtonStyle = ButtonStyle.PRIMARY """Style of the button.""" - label: str | None = Field(None, max_length=80) + label: Annotated[str, Field(max_length=80)] | None = None """Text to be displayed on the button. Max 80 characters. @@ -108,7 +118,7 @@ class Button(BaseComponent): emoji: ComponentEmoji | None = None """Emoji to be displayed on the button.""" - custom_id: str | None = Field(None, max_length=100) + custom_id: Annotated[str, Field(max_length=100)] | None = None """Developer-defined identifier for the button. Max 100 characters. @@ -159,7 +169,7 @@ class SelectMenuOption(BaseModel): Max 100 characters. """ - description: str | None = Field(None, max_length=100) + description: Annotated[str, Field(max_length=100)] | None = None """Additional description of the option. Max 100 characters. @@ -197,10 +207,10 @@ class SelectMenu(BaseComponent): https://discord.com/developers/docs/interactions/message-components#select-menus """ - type: Literal[SelectComponentType] = ComponentType.STRING_SELECT + type: Literal[SelectComponentType] = ComponentType.STRING_SELECT # type: ignore """Type of the component of select menu.""" - custom_id: str = Field(None, max_length=100) + custom_id: Annotated[str, Field(max_length=100)] | None = None """Developer-defined identifier for the select menu. Max 100 characters. @@ -216,7 +226,7 @@ class SelectMenu(BaseComponent): channel_types: list[ChannelType] = Field(default_factory=list) """List of channel types to include in the channel select component""" - placeholder: str | None = Field(None, max_length=150) + placeholder: Annotated[str, Field(max_length=150)] | None = None """Placeholder text if nothing is selected; max 150 characters.""" default_values: list[SelectDefaultValue] | None = None @@ -229,10 +239,10 @@ class SelectMenu(BaseComponent): Number of default values must be in the range defined by min_values and max_values. """ - min_values: int = Field(1, ge=0, le=25) + min_values: Annotated[int, Field(ge=0, le=25)] = 1 """Minimum number of items that must be chosen; default 1, min 0, max 25.""" - max_values: int = Field(1, ge=0, le=25) + max_values: Annotated[int, Field(ge=0, le=25)] = 1 """Maximum number of items that can be chosen; default 1, max 25.""" disabled: bool = False @@ -313,7 +323,7 @@ class TextInput(BaseComponent): https://discord.com/developers/docs/interactions/message-components#text-input-object-text-input-structure """ - type: Literal[ComponentType.TEXT_INPUT] = ComponentType.TEXT_INPUT + type: Literal[ComponentType.TEXT_INPUT] = ComponentType.TEXT_INPUT # type: ignore """Type of the component. Only `ComponentType.TEXT_INPUT` is allowed. @@ -334,13 +344,13 @@ class TextInput(BaseComponent): Max 45 characters. """ - min_length: int | None = Field(None, ge=0, le=4000) + min_length: Annotated[int, Field(ge=0, le=4000)] | None = None """Minimum length of the text input. Max 4000 characters. """ - max_length: int | None = Field(None, ge=1, le=4000) + max_length: Annotated[int, Field(ge=1, le=4000)] | None = None """Maximum length of the text input. Max 4000 characters. @@ -349,22 +359,28 @@ class TextInput(BaseComponent): required: bool = True """Whether the text input is required to be filled.""" - value: str | None = Field(None, max_length=4000) + value: Annotated[str, Field(max_length=4000)] | None = None """Pre-filled value for this component. Max 4000 characters. """ - placeholder: str | None = Field(None, max_length=100) + placeholder: Annotated[str, Field(max_length=100)] | None = None """Placeholder text. Max 100 characters. """ - def __init__(self, **data: Any) -> None: # noqa: ANN401 - """Create a new text input component.""" - super().__init__(**data) # type: ignore + @model_validator(mode='after') + def set_style_field_set(self) -> Self: + """Set `style` field in `model_fields_set`. + + Add `style` to `model_fields_set` to make `dict(exclude_unset)` work properly. + We don't need to set 'style' field because it's already set in a component subclasses class, + but we need to send it to Discord excluding another unset fields. + """ self.model_fields_set.add('style') + return self @field_validator('max_length') def validate_length(cls, max_length: int | None, field_info: ValidationInfo) -> int | None: @@ -423,9 +439,9 @@ def __init__(self, components: Sequence[Component | TextInput] | Component | Tex Args: components: Components in the action row. """ - super().__init__(components=components) + super().__init__(components=components) # type: ignore - type: Literal[ComponentType.ACTION_ROW] = ComponentType.ACTION_ROW + type: Literal[ComponentType.ACTION_ROW] = ComponentType.ACTION_ROW # type: ignore """Type of the component.""" components: Annotated[Sequence[Component | TextInput], Field(min_length=1, max_length=5)] | Component | TextInput diff --git a/asyncord/client/messages/models/requests/messages.py b/asyncord/client/messages/models/requests/messages.py index 3fb5429..c8ac351 100644 --- a/asyncord/client/messages/models/requests/messages.py +++ b/asyncord/client/messages/models/requests/messages.py @@ -9,7 +9,7 @@ from collections.abc import Sequence from io import BufferedReader, IOBase from pathlib import Path -from typing import Annotated, Any, Literal +from typing import Annotated, Any, Literal, Self from pydantic import ( BaseModel, @@ -24,7 +24,7 @@ from asyncord.client.messages.models.requests.components import ActionRow, Component from asyncord.client.messages.models.requests.embeds import Embed from asyncord.client.models.attachments import Attachment, AttachmentContentType -from asyncord.client.polls.models.requests import PollRequest +from asyncord.client.polls.models.requests import Poll from asyncord.snowflake import SnowflakeInputType __ALL__ = ( @@ -48,8 +48,27 @@ class BaseMessage(BaseModel): Contains axillary validation methods. """ - @model_validator(mode='before') - def has_any_content(cls, values: dict[str, Any]) -> dict[str, Any]: + content: Annotated[str | None, Field(max_length=2000)] = None + """Message content.""" + + embeds: list[Embed] | None = None + """Embedded rich content.""" + + components: Sequence[Component] | Component | None = None + """Components to include with the message.""" + + sticker_ids: list[SnowflakeInputType] | None = None + """Sticker ids to include with the message.""" + + attachments: Sequence[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None + """List of attachment object. + + Reference: + https://discord.com/developers/docs/reference#uploading-files + """ + + @model_validator(mode='after') + def has_any_content(self) -> Self: """Validate message content. Reference: @@ -64,20 +83,22 @@ def has_any_content(cls, values: dict[str, Any]) -> dict[str, Any]: Raises: ValueError: If the message has no content or embeds. """ + # fmt: off has_any_content = bool( - values.get('content', False) - or values.get('embeds', False) - or values.get('sticker_ids', False) - or values.get('components', False) - or values.get('attachments', False), + self.content + or self.embeds + or self.sticker_ids + or self.components + or self.attachments, ) + # fmt: on if not has_any_content: raise ValueError( 'Message must have content, embeds, stickers, components or files.', ) - return values + return self @field_validator('embeds', check_fields=False) def validate_embeds(cls, embeds: list[Embed] | None) -> list[Embed] | None: @@ -110,7 +131,7 @@ def validate_embeds(cls, embeds: list[Embed] | None) -> list[Embed] | None: return embeds @field_validator('attachments', mode='before', check_fields=False) - def convert_files_to_attachments( + def convert_attachments( cls, attachments: Sequence[Attachment | AttachmentContentType], ) -> list[Attachment]: @@ -163,7 +184,12 @@ def validate_attachments(cls, attachments: list[Attachment] | None) -> list[Atta raise ValueError('Do not attach attachments must have content') if attachment.id is None: - attachment.id = index + # we do not want to modify the original attachment + # it helps to use the same object in many requests + # otherwise, this object can be conflicting with other because + # it will have the id after the first request + new_attachment_obj_with_id = attachment.model_copy(update={'id': index}) + attachments[index] = new_attachment_obj_with_id return attachments @@ -264,10 +290,10 @@ class AllowedMentions(BaseModel): parse: list[AllowedMentionType] | None = None """Array of allowed mention types to parse from the content.""" - roles: list[SnowflakeInputType] | None = Field(None, max_length=100) + roles: Annotated[list[SnowflakeInputType], Field(max_length=100)] | None = None """Array of role IDs to mention.""" - users: list[SnowflakeInputType] | None = Field(None, max_length=100) + users: Annotated[list[SnowflakeInputType], Field(max_length=100)] | None = None """Array of user IDs to mention.""" replied_user: bool | None = None @@ -335,7 +361,7 @@ class CreateMessageRequest(BaseMessage): sticker_ids: list[SnowflakeInputType] | None = None """Sticker ids to include with the message.""" - attachments: list[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None + attachments: Sequence[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None """List of attachment object. Reference: @@ -355,7 +381,7 @@ class CreateMessageRequest(BaseMessage): I set it to True because it will be default behavior in the near future. """ - poll: PollRequest | None = None + poll: Poll | None = None """A poll.""" diff --git a/asyncord/client/messages/models/responses/messages.py b/asyncord/client/messages/models/responses/messages.py index 4b44f8a..67002ce 100644 --- a/asyncord/client/messages/models/responses/messages.py +++ b/asyncord/client/messages/models/responses/messages.py @@ -11,6 +11,7 @@ from asyncord.client.channels.models.common import ChannelType from asyncord.client.channels.models.responses import ChannelResponse +from asyncord.client.emojis.models.responses import EmojiResponse from asyncord.client.interactions.models.common import InteractionType from asyncord.client.members.models.common import GuildMemberFlags from asyncord.client.members.models.responses import MemberResponse @@ -18,7 +19,6 @@ from asyncord.client.messages.models.responses.components import ComponentOut from asyncord.client.messages.models.responses.embeds import EmbedOut from asyncord.client.models.attachments import AttachmentFlags -from asyncord.client.models.emoji import Emoji from asyncord.client.models.stickers import StickerFormatType from asyncord.client.polls.models.responses import PollResponse from asyncord.client.roles.models.responses import RoleResponse @@ -302,7 +302,7 @@ class ReactionOut(BaseModel): me_burst: bool """Whether the current user super-reacted using this emoji.""" - emoji: Emoji + emoji: EmojiResponse """Emoji information.""" burst_colors: list[Color] diff --git a/asyncord/client/messages/resources.py b/asyncord/client/messages/resources.py index 8dd6cc2..31d5dd9 100644 --- a/asyncord/client/messages/resources.py +++ b/asyncord/client/messages/resources.py @@ -2,11 +2,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Sequence +from typing import TYPE_CHECKING, cast from asyncord.client.http.headers import AUDIT_LOG_REASON from asyncord.client.messages.models.responses.messages import MessageResponse -from asyncord.client.models.attachments import make_attachment_payload +from asyncord.client.models.attachments import Attachment, make_payload_with_attachments from asyncord.client.reactions.resources import ReactionResource from asyncord.client.resources import APIResource from asyncord.typedefs import list_model @@ -93,7 +94,8 @@ async def create(self, message_data: CreateMessageRequest) -> MessageResponse: Created message object. """ url = self.messages_url - payload = make_attachment_payload(message_data) + attachments = cast(list[Attachment] | None, message_data.attachments) + payload = make_payload_with_attachments(message_data, attachments=attachments) resp = await self._http_client.post(url=url, payload=payload) return MessageResponse.model_validate(resp.body) @@ -109,7 +111,8 @@ async def update(self, message_id: SnowflakeInputType, message_data: UpdateMessa Updated message object. """ url = self.messages_url / str(message_id) - payload = make_attachment_payload(message_data) + attachments = cast(list[Attachment] | None, message_data.attachments) + payload = make_payload_with_attachments(message_data, attachments) resp = await self._http_client.patch(url=url, payload=payload) return MessageResponse(**resp.body) @@ -123,7 +126,7 @@ async def delete(self, message_id: SnowflakeInputType, reason: str | None = None """ url = self.messages_url / str(message_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -132,7 +135,7 @@ async def delete(self, message_id: SnowflakeInputType, reason: str | None = None async def bulk_delete( self, - message_ids: list[SnowflakeInputType], + message_ids: Sequence[SnowflakeInputType], reason: str | None = None, ) -> None: """Delete multiple messages. @@ -144,7 +147,7 @@ async def bulk_delete( url = self.messages_url / 'bulk-delete' payload = {'messages': [str(message_id) for message_id in message_ids]} - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -194,7 +197,7 @@ async def pin_message( """ url = self.channels_url / str(channel_id) / 'pins' / str(message_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -215,7 +218,7 @@ async def unpin_message( """ url = self.channels_url / str(channel_id) / 'pins' / str(message_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} diff --git a/asyncord/client/models/attachments.py b/asyncord/client/models/attachments.py index 994b43f..c4185c8 100644 --- a/asyncord/client/models/attachments.py +++ b/asyncord/client/models/attachments.py @@ -3,9 +3,10 @@ from __future__ import annotations import enum +from collections.abc import Sequence from io import BufferedReader, IOBase from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, cast +from typing import Annotated, Any from pydantic import AnyHttpUrl, BaseModel, Field @@ -13,10 +14,6 @@ from asyncord.client.http.models import FormField, FormPayload from asyncord.snowflake import SnowflakeInputType -if TYPE_CHECKING: - # It fixes the circular import issue - from asyncord.client.messages.models.requests.messages import BaseMessage - __ALL__ = ( 'AttachmentContentType', 'AttachmentFlags', @@ -117,31 +114,39 @@ class Attachment(BaseModel, arbitrary_types_allowed=True): """ -def make_attachment_payload( - attachment_model_data: BaseMessage, - root_payload: BaseModel | None = None, +def make_payload_with_attachments( + json_payload: BaseModel | dict[str, Any], + attachments: Sequence[Attachment] | None = None, + exclude_unset: bool = True, + exclude_none: bool = False, ) -> FormPayload | dict[str, Any]: - """Convert message model data with attachments to a payload. + """Prepare a payload with possible attachments. + + !!! WARNING + `json_payload` will be converted from model **excluding unset fields**! + It means that if you want to send a message with saved unset values, + you should pass the message already dumped to json or set `exclude_unset` to False. + Be creful! If attachments are present in the message model data, the attachments will be converted to form fields and the message model data will be converted to a json payload field. Otherwise, the message model data will be converted to a json payload - dict. Args: - attachment_model_data: Message model data with attachments. - root_payload: Root payload model data which will be sent as json payload. - It will replace the attachment content in the payload. But the attachment - will still be sent as a form field. + json_payload: Base message model or raw dict data which will be sent as json payload. + attachments: List of attachments to be sent with the message. + exclude_unset: Whether to exclude unset fields from the json payload. Defaults to True. + exclude_none: Whether to exclude None fields from the json payload. Defaults to False. """ - payload_model_data = root_payload or attachment_model_data - json_payload = payload_model_data.model_dump(mode='json', exclude_unset=True) - if not attachment_model_data.attachments: # type: ignore + if isinstance(json_payload, BaseModel): + json_payload = json_payload.model_dump( + mode='json', + exclude_unset=exclude_unset, + exclude_none=exclude_none, + ) + if not attachments: # type: ignore return json_payload - # All base message models have attachments attribute. - # We can safely cast it to a list of attachments. - attachments = cast(list[Attachment], attachment_model_data.attachments) # type: ignore - file_form_fields = { f'files[{attachment.id}]': FormField( value=attachment.content, diff --git a/asyncord/client/models/automoderation.py b/asyncord/client/models/automoderation.py index e041ca6..3a0855b 100644 --- a/asyncord/client/models/automoderation.py +++ b/asyncord/client/models/automoderation.py @@ -3,6 +3,7 @@ from __future__ import annotations import enum +from typing import Annotated from fbenum.adapter import FallbackAdapter from pydantic import BaseModel, Field @@ -88,7 +89,7 @@ class TriggerMetadata(BaseModel): https://discord.com/developers/docs/resources/auto-moderation#auto-moderation-rule-object-keyword-matching-strategies """ - regex_patterns: list[str] | None = Field(None, max_length=10) + regex_patterns: Annotated[list[str] | None, Field(max_length=10)] = None """regular expression patterns which will be matched against content (Maximum of 10) Associated with `TriggerType.KEYWORD`. @@ -104,7 +105,7 @@ class TriggerMetadata(BaseModel): Associated with `TriggerType.KEYWORD_PRESET`. """ - allow_list: list[str] | None = Field(None, max_length=1000) + allow_list: Annotated[list[str] | None, Field(max_length=1000)] = None """Substrings which will be exempt from triggering the preset trigger type. Associated with `TriggerType.KEYWORD` and `TriggerType.KEYWORD_PRESET`. @@ -119,7 +120,7 @@ class TriggerMetadata(BaseModel): https://discord.com/developers/docs/resources/auto-moderation#auto-moderation-rule-object-keyword-matching-strategies """ - mention_total_limit: int | None = Field(None, le=50) + mention_total_limit: Annotated[int | None, Field(le=50)] = None """Total number of mentions(role & user) allowed per message. Maximum of 50. diff --git a/asyncord/client/models/permissions.py b/asyncord/client/models/permissions.py index 4b74ae8..4759b18 100644 --- a/asyncord/client/models/permissions.py +++ b/asyncord/client/models/permissions.py @@ -1,8 +1,10 @@ """This module contains the permission flags used by many Discord models.""" +from __future__ import annotations + import enum from collections.abc import Callable -from typing import Any, Self +from typing import Any from pydantic import BaseModel from pydantic_core import CoreSchema, core_schema @@ -200,6 +202,9 @@ def __get_pydantic_core_schema__( ) -> CoreSchema: """Pydantic auxiliary method to get schema. + Can be converted from string, int, or PermissionFlag. + Serializes to string. + Args: _source: Source of schema. _handler: Handler of schema. @@ -220,7 +225,7 @@ def __get_pydantic_core_schema__( ) @classmethod - def _validate(cls, value: str | int | Self) -> Self: + def _validate(cls, value: str | int | PermissionFlag) -> PermissionFlag: """Pydantic auxiliary validation method. Args: @@ -233,11 +238,12 @@ def _validate(cls, value: str | int | Self) -> Self: Validated permission flags. """ match value: + case PermissionFlag(): + return value case str(): return cls(int(value)) case int(): return cls(value) - case Self(): - return value - raise ValueError('Invalid value type for PermissionFlags') + # This should never happen because of the pydantic schema + raise ValueError('Invalid value type for PermissionFlags') # pragma: no cover diff --git a/asyncord/client/polls/models/requests.py b/asyncord/client/polls/models/requests.py index 28ad177..5844f08 100644 --- a/asyncord/client/polls/models/requests.py +++ b/asyncord/client/polls/models/requests.py @@ -1,83 +1,83 @@ """This module contains the response model for a polls.""" -from pydantic import BaseModel, Field, field_validator +from __future__ import annotations + +from typing import Self + +from pydantic import BaseModel, Field, SerializerFunctionWrapHandler, field_serializer, model_validator from asyncord.client.polls.models.common import PollLayoutType -from asyncord.snowflake import Snowflake +from asyncord.snowflake import SnowflakeInputType __all__ = ( - 'PartialEmoji', - 'PollAnswer', - 'PollMedia', - 'PollRequest', + 'Answer', + 'Poll', + 'PollEmoji', ) -class PartialEmoji(BaseModel): +class PollEmoji(BaseModel): """Represents a custom emoji that can be used in messages. Reference: https://discord.com/developers/docs/resources/emoji#emoji-object """ - id: Snowflake | None + id: SnowflakeInputType | None = None """Emoji id.""" - name: str | None - """Emoji name. + name: str | None = None + """Emoji name.""" - Can be null only in reaction emoji objects. - """ + @model_validator(mode='after') + def validate_id_or_name(self) -> Self: + """Validate that either id or name is set.""" + if not self.id and not self.name: + raise ValueError('Either id or name must be set') + + if self.id and self.name: + raise ValueError('Only one of id or name can be set') + + return self -class PollMedia(BaseModel): - """Poll Media Object. +class Answer(BaseModel): + """Answer object. + + Answer object is not an actual discord object. It's media object but currently + it uses only for answers. I removed media objects from question becauses it + supports only text field. Reference: https://discord.com/developers/docs/resources/poll#poll-media-object-poll-media-object-structure """ - text: str = Field(None, max_length=300) + text: str = Field(max_length=300) """Text of the field. The maximum length of text is 300 for the question, and 55 for any answer. """ - emoji: PartialEmoji | None = None + emoji: PollEmoji | None = None """Partial Emoji object.""" -class PollAnswer(BaseModel): - """Poll Answer Object. - - Reference: - https://discord.com/developers/docs/resources/poll#poll-answer-object-poll-answer-object-structure - """ - - poll_media: PollMedia - """Data of the answer.""" - - @field_validator('poll_media') - def validate_text_length(cls, poll_media: PollMedia) -> PollMedia: - """Validate the text length.""" - max_length = 55 - - if len(poll_media.text) > max_length: - raise ValueError('Text length should be less than 55 characters.') - return poll_media - - -class PollRequest(BaseModel): +class Poll(BaseModel): """Poll object. Reference: https://discord.com/developers/docs/resources/poll#poll-create-request-object """ - question: PollMedia - """The question of the poll. Only text is supported.""" + question: str = Field(max_length=300) + """The question of the poll. Only text is supported. + + Under the hood, the question is a media object but it supports only text now. + So, I decided to use a string field instead of a media object and prepare object + for sending to the API in the serializer. + """ - answers: list[PollAnswer] + answers: list[Answer] """Each of the answers available in the poll.""" duration: int | None = None @@ -89,11 +89,43 @@ class PollRequest(BaseModel): layout_type: PollLayoutType """The layout type of the poll.""" - @field_validator('question') - def validate_text_length(cls, question: PollMedia) -> PollMedia: - """Validate the text length.""" - max_length = 300 - - if len(question.text) > max_length: - raise ValueError('Text length should be less than 55 characters.') - return question + @field_serializer('question', mode='wrap', when_used='json') + @classmethod + def serialize_question( + cls, + question: str, + next_serializer: SerializerFunctionWrapHandler, + ) -> dict[str, str]: + """Prepare question for sending to the API. + + By default the question is wrapped in another structure. Currently for user + it is overengineering to have a separate object for each question. I dicided + to simplify the structure and allow user to pass a poll object directly. + """ + serialized_question = next_serializer(question) + return {'text': serialized_question} + + @field_serializer('answers', mode='wrap', when_used='json') + @classmethod + def serialize_answers( + cls, + answers: list[Answer], + next_serializer: SerializerFunctionWrapHandler, + ) -> list[dict[str, _JsonAnswerRepresentationType]]: + """Prepare answers for sending to the API. + + By default the answers are wrapped in another structure. Currently for user + it is overengineering to have a separate object for each answer. I dicided + to simplify the structure and allow user to pass a list of poll media objects + directly. + """ + serialized_answers = next_serializer(answers) + # fmt: off + return [ + {'poll_media': answer} + for answer in serialized_answers + ] + # fmt: on + + +type _JsonAnswerRepresentationType = dict[str, str | dict[str | None, str | None]] diff --git a/asyncord/client/polls/models/responses.py b/asyncord/client/polls/models/responses.py index a158dc9..0ad9dd2 100644 --- a/asyncord/client/polls/models/responses.py +++ b/asyncord/client/polls/models/responses.py @@ -1,18 +1,19 @@ """This module contains the response model for a polls.""" import datetime +from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, model_validator from asyncord.client.polls.models.common import PollLayoutType from asyncord.client.users.models.responses import UserResponse from asyncord.snowflake import Snowflake __all__ = ( + 'AnswerOut', 'GetAnswerVotersResponse', - 'PartialEmojiOut', 'PollAnswerCountOut', - 'PollAnswerOut', + 'PollEmojiOut', 'PollMediaOut', 'PollResponse', 'PollResultsOut', @@ -29,7 +30,7 @@ class GetAnswerVotersResponse(BaseModel): users: list[UserResponse] -class PartialEmojiOut(BaseModel): +class PollEmojiOut(BaseModel): """Represents a custom emoji that can be used in messages. Reference: @@ -68,7 +69,7 @@ class PartialEmojiOut(BaseModel): class PollMediaOut(BaseModel): - """Poll Media Object. + """Poll Media Object in response. Reference: https://discord.com/developers/docs/resources/poll#poll-media-object-poll-media-object-structure @@ -77,18 +78,18 @@ class PollMediaOut(BaseModel): text: str | None = None """Text of the field.""" - emoji: PartialEmojiOut | None = None + emoji: PollEmojiOut | None = None """Partial Emoji object.""" -class PollAnswerOut(BaseModel): - """Poll Answer Object. +class AnswerOut(BaseModel): + """Poll Answer Object in response. Reference: https://discord.com/developers/docs/resources/poll#poll-answer-object-poll-answer-object-structure """ - asnwer_id: int | None = None + answer_id: int """ID of the answer.""" poll_media: PollMediaOut @@ -122,7 +123,7 @@ class PollResultsOut(BaseModel): is_finalized: bool """Whether the votes have been precisely counted.""" - answer_count: list[PollAnswerCountOut] | None = None + answer_counts: list[PollAnswerCountOut] """Counts for each answer""" @@ -133,10 +134,10 @@ class PollResponse(BaseModel): https://discord.com/developers/docs/resources/poll#poll-object-poll-object-structure """ - question: PollMediaOut + question: str """The question of the poll. Only text is supported.""" - answers: list[PollAnswerOut] + answers: list[AnswerOut] """Each of the answers available in the poll.""" expiry: datetime.datetime | None = None @@ -150,3 +151,10 @@ class PollResponse(BaseModel): results: PollResultsOut | None = None """The results of the poll.""" + + @model_validator(mode='before') + @classmethod + def prepare_question(cls, raw_values: dict[str, Any]) -> dict[str, Any]: + """Prepare question for validation.""" + raw_values['question'] = raw_values['question']['text'] + return raw_values diff --git a/asyncord/client/polls/resources.py b/asyncord/client/polls/resources.py index 43b8d57..d53db86 100644 --- a/asyncord/client/polls/resources.py +++ b/asyncord/client/polls/resources.py @@ -61,10 +61,7 @@ async def get_answer_voters( resp = await self._http_client.get(url=url) return GetAnswerVotersResponse.model_validate(resp.body) - async def end_poll( - self, - message_id: SnowflakeInputType, - ) -> MessageResponse: + async def end_poll(self, message_id: SnowflakeInputType) -> MessageResponse: """Immediately end a poll. You can't end polls from other users. diff --git a/asyncord/client/reactions/resources.py b/asyncord/client/reactions/resources.py index 142ee8c..3c2b710 100644 --- a/asyncord/client/reactions/resources.py +++ b/asyncord/client/reactions/resources.py @@ -2,11 +2,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from asyncord.client.resources import APIResource from asyncord.client.users.models.responses import UserResponse -from asyncord.typedefs import list_model +from asyncord.typedefs import CURRENT_USER, list_model from asyncord.urls import REST_API_URL if TYPE_CHECKING: @@ -69,22 +69,13 @@ async def add(self, emoji: str) -> None: Args: emoji: Emoji to react with. """ - url = self.reactions_url / emoji / '@me' + url = self.reactions_url / emoji / CURRENT_USER await self._http_client.put(url=url) - async def delete_own_reaction(self, emoji: str) -> None: - """Delete a reaction the current user made for the message. - - Args: - emoji: Emoji to delete the reaction for. - """ - url = self.reactions_url / emoji / '@me' - await self._http_client.delete(url=url) - async def delete( self, emoji: str | None = None, - user_id: SnowflakeInputType | None = None, + user_id: SnowflakeInputType | Literal['@me'] | None = None, ) -> None: """Delete a reaction a user made for the message. @@ -93,7 +84,7 @@ async def delete( user_id: ID of the user to delete the reaction for. """ if user_id and not emoji: - raise ValueError('If user_id is specified, emoji must be specified too.') + raise ValueError('Cannot delete a reaction for a user without an emoji.') url = self.reactions_url if emoji is not None: diff --git a/asyncord/client/roles/resources.py b/asyncord/client/roles/resources.py index 54b783b..fcb85a0 100644 --- a/asyncord/client/roles/resources.py +++ b/asyncord/client/roles/resources.py @@ -6,13 +6,14 @@ from asyncord.client.http.headers import AUDIT_LOG_REASON from asyncord.client.resources import APIResource +from asyncord.client.roles.models.requests import RolePositionRequest from asyncord.client.roles.models.responses import RoleResponse from asyncord.typedefs import list_model from asyncord.urls import REST_API_URL if TYPE_CHECKING: from asyncord.client.http.client import HttpClient - from asyncord.client.roles.models.requests import CreateRoleRequest, RolePositionRequest, UpdateRoleRequest + from asyncord.client.roles.models.requests import CreateRoleRequest, UpdateRoleRequest from asyncord.snowflake import SnowflakeInputType __all__ = ('RoleResource',) @@ -62,7 +63,7 @@ async def create( Returns: Created role. """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -83,7 +84,9 @@ async def change_role_positions(self, role_positions: list[RolePositionRequest]) Returns: List of roles in the guild. """ - resp = await self._http_client.patch(url=self.roles_url, payload=role_positions) + payload = list_model(RolePositionRequest).dump_python(role_positions, mode='json') + + resp = await self._http_client.patch(url=self.roles_url, payload=payload) return list_model(RoleResponse).validate_python(resp.body) async def update_role(self, role_id: SnowflakeInputType, role_data: UpdateRoleRequest) -> RoleResponse: @@ -115,7 +118,7 @@ async def delete(self, role_id: SnowflakeInputType, reason: str | None = None) - reason: Reason for deleting the role. """ url = self.roles_url / str(role_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} diff --git a/asyncord/client/scheduled_events/models/requests.py b/asyncord/client/scheduled_events/models/requests.py index 94b5e9f..6270b1a 100644 --- a/asyncord/client/scheduled_events/models/requests.py +++ b/asyncord/client/scheduled_events/models/requests.py @@ -1,7 +1,9 @@ """Models for scheduled events requests.""" +from __future__ import annotations + import datetime -from typing import Self +from typing import Annotated, Self from pydantic import BaseModel, Field, model_validator @@ -22,7 +24,7 @@ class EventEntityMetadata(BaseModel): https://discord.com/developers/docs/resources/guild-scheduled-event#guild-scheduled-event-object-guild-scheduled-event-entity-metadata """ - location: str | None = Field(None, min_length=1, max_length=100) + location: Annotated[str, Field(min_length=1, max_length=100)] | None = None """Location of the event.""" @@ -63,20 +65,7 @@ class CreateScheduledEventRequest(BaseModel): @model_validator(mode='after') def validate_entity_type(self) -> Self: """Validates the entity type of the scheduled event.""" - if self.entity_type is EventEntityType.EXTERNAL: - if not self.entity_metadata: - raise ValueError('`entity_metadata` must be set if `entity_type` is EXTERNAL') - - if not self.entity_metadata.location: - raise ValueError('`entity_metadata.location` must be set if `entity_type` is EXTERNAL') - - if not self.scheduled_end_time: - raise ValueError('`scheduled_end_time` must be set if `entity_type` is EXTERNAL') - - elif not self.channel_id: - raise ValueError('`channel_id` must be set if `entity_type` is STAGE_INSTANCE or VOICE') - - return self + return _validate_entity_type(self) class UpdateScheduledEventRequest(BaseModel): @@ -119,21 +108,36 @@ class UpdateScheduledEventRequest(BaseModel): @model_validator(mode='after') def validate_entity_type(self) -> Self: """Validates the entity type of the scheduled event.""" - if not self.entity_type: - # can't validate if entity type is not set - return self - - if self.entity_type is EventEntityType.EXTERNAL: - if not self.entity_metadata: - raise ValueError('`entity_metadata` must be set if `entity_type` is EXTERNAL') - - if not self.entity_metadata.location: - raise ValueError('`entity_metadata.location` must be set if `entity_type` is EXTERNAL') + return _validate_entity_type(self) - if not self.scheduled_end_time: - raise ValueError('`scheduled_end_time` must be set if `entity_type` is EXTERNAL') - elif not self.channel_id: - raise ValueError('`channel_id` must be set if `entity_type` is STAGE_INSTANCE or VOICE') - - return self +def _validate_entity_type[ModelObjT: UpdateScheduledEventRequest | CreateScheduledEventRequest]( + model_obj: ModelObjT, +) -> ModelObjT: + """Validates the entity type of the scheduled event.""" + match model_obj.entity_type: + case None: + # can't validate if entity type is not set + return model_obj + + case EventEntityType.EXTERNAL: + # fmt: off + has_needed_fields = bool( + model_obj.entity_metadata + and model_obj.entity_metadata.location + and model_obj.scheduled_end_time, + ) + # fmt: on + if not has_needed_fields: + required_fields = ('entity_metadata', 'entity_metadata.location', 'scheduled_end_time') + err_msg = f'EXTERNAL type requires the fields {required_fields} to be set' + raise ValueError(err_msg) + + case EventEntityType.STAGE_INSTANCE | EventEntityType.VOICE: + if not model_obj.channel_id: + raise ValueError('`channel_id` must be set if `entity_type` is STAGE_INSTANCE or VOICE') + + case _: # pragma: no cover + raise ValueError('Invalid entity type') # unreachable in theory + + return model_obj diff --git a/asyncord/client/scheduled_events/resources.py b/asyncord/client/scheduled_events/resources.py index 58d1390..f8635fb 100644 --- a/asyncord/client/scheduled_events/resources.py +++ b/asyncord/client/scheduled_events/resources.py @@ -81,7 +81,7 @@ async def create( Returns: Created event. """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} diff --git a/asyncord/client/stage_instances/models/requests.py b/asyncord/client/stage_instances/models/requests.py index 3f43e5d..222bbb7 100644 --- a/asyncord/client/stage_instances/models/requests.py +++ b/asyncord/client/stage_instances/models/requests.py @@ -1,5 +1,7 @@ """Models for stage instance resource requests.""" +from typing import Annotated + from pydantic import BaseModel, Field from asyncord.client.stage_instances.models.common import StageInstancePrivacyLevel @@ -21,7 +23,7 @@ class CreateStageInstanceRequest(BaseModel): channel_id: SnowflakeInputType """Id of the Stage channel.""" - topic: str = Field(None, min_length=1, max_length=120) + topic: str = Field(min_length=1, max_length=120) """Topic of the Stage instance. 1-120 characters.""" privacy_level: StageInstancePrivacyLevel | None = None @@ -41,7 +43,7 @@ class UpdateStageInstanceRequest(BaseModel): https://discord.com/developers/docs/resources/stage-instance#modify-stage-instance-json-params """ - topic: str | None = Field(None, min_length=1, max_length=120) + topic: Annotated[str, Field(min_length=1, max_length=120)] | None = None """Topic of the stage instance. 1-120 characters""" privacy_level: StageInstancePrivacyLevel | None = None diff --git a/asyncord/client/stage_instances/resources.py b/asyncord/client/stage_instances/resources.py index 70450fa..5f8ba21 100644 --- a/asyncord/client/stage_instances/resources.py +++ b/asyncord/client/stage_instances/resources.py @@ -1,4 +1,7 @@ -"""Stage Instances Resource. +"""Stage Instances Resources. + +Stage instances are a way to host live events in Discord. You need to create a stage channel +before creating a stage instance. Stage instances are associated with stage channels. Reference: https://discord.com/developers/docs/resources/stage-instance @@ -11,6 +14,7 @@ from asyncord.client.http.headers import AUDIT_LOG_REASON from asyncord.client.resources import APIResource from asyncord.client.stage_instances.models.responses import StageInstanceResponse +from asyncord.snowflake import SnowflakeInputType from asyncord.urls import REST_API_URL if TYPE_CHECKING: @@ -25,18 +29,16 @@ class StageInstancesResource(APIResource): """Stage Instance Resource. - These endpoints are for managing stage instances. - Reference: https://discord.com/developers/docs/resources/stage-instance """ stage_instances_url = REST_API_URL / 'stage-instances' - async def get_stage_instance( + async def get( self, - channel_id: str, - ) -> StageInstanceResponse | None: + channel_id: SnowflakeInputType, + ) -> StageInstanceResponse: """Gets the stage instance associated with the Stage channel. If exists. @@ -44,16 +46,14 @@ async def get_stage_instance( Reference: https://discord.com/developers/docs/resources/stage-instance#get-stage-instance - Args: - channel_id (str): The channel id. + Attributes: + channel_id: The channel id. """ url = self.stage_instances_url / str(channel_id) resp = await self._http_client.get(url=url) - if resp.body: - return StageInstanceResponse.model_validate(resp.body) - return None + return StageInstanceResponse.model_validate(resp.body) async def create_stage_instance( self, @@ -66,12 +66,15 @@ async def create_stage_instance( https://discord.com/developers/docs/resources/stage-instance#create-stage-instance Args: - stage_instance_data (CreateStageInstanceRequest): The stage instance data. - reason (str, optional): The reason for creating the stage instance. + stage_instance_data: The stage instance data. + reason: The reason for creating the stage instance. + + Returns: + Created stage instance. """ url = self.stage_instances_url - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -86,9 +89,9 @@ async def create_stage_instance( return StageInstanceResponse.model_validate(resp.body) - async def update_stage_instance( + async def update( self, - channel_id: str, + channel_id: SnowflakeInputType, stage_instance_data: UpdateStageInstanceRequest, reason: str | None = None, ) -> StageInstanceResponse: @@ -98,13 +101,16 @@ async def update_stage_instance( https://discord.com/developers/docs/resources/stage-instance#modify-stage-instance Args: - channel_id (str): The channel id. + channel_id: The channel id. stage_instance_data (UpdateStageInstanceRequest): The stage instance data. - reason (str, optional): The reason for updating the stage instance. + reason: The reason for updating the stage instance. + + Returns: + Updated stage instance. """ url = self.stage_instances_url / str(channel_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -119,9 +125,9 @@ async def update_stage_instance( return StageInstanceResponse.model_validate(resp.body) - async def delete_stage_instance( + async def delete( self, - channel_id: str, + channel_id: SnowflakeInputType, reason: str | None = None, ) -> None: """Deletes the stage instance associated with the Stage channel. @@ -130,12 +136,12 @@ async def delete_stage_instance( https://discord.com/developers/docs/resources/stage-instance#delete-stage-instance Args: - channel_id (str): The channel id. - reason (str, optional): The reason for deleting the stage instance. + channel_id: The channel id. + reason: The reason for deleting the stage instance. """ url = self.stage_instances_url / str(channel_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} diff --git a/asyncord/client/stickers/models/requests.py b/asyncord/client/stickers/models/requests.py index 4a86b66..7a10e45 100644 --- a/asyncord/client/stickers/models/requests.py +++ b/asyncord/client/stickers/models/requests.py @@ -3,7 +3,6 @@ from __future__ import annotations from collections.abc import Sequence -from collections.abc import Set as AbstractSet from typing import Annotated from pydantic import ( @@ -67,7 +66,7 @@ class UpdateGuildStickerRequest(BaseModel): """Autocomplete/suggestion tags for the sticker.""" -def _validate_tags(tags: Sequence[str] | AbstractSet[str] | str) -> set[str]: +def _validate_tags(tags: Sequence[str] | set[str] | str) -> set[str]: """Validate tags length. On serialization, tags converted to a string with a comma and space separator. @@ -78,7 +77,7 @@ def _validate_tags(tags: Sequence[str] | AbstractSet[str] | str) -> set[str]: if isinstance(tags, str): tags = set(tag.strip() for tag in tags.split(',')) else: - tags = set(tag.lower() for tag in tags) + tags = set(tag for tag in tags) total_tags_length = ( sum(len(tag) for tag in tags) # total length of all tags @@ -102,10 +101,10 @@ def _serialize_tags_to_string( type TagsType = Annotated[ - set[str] | str, + Sequence[str] | set[str] | str, BeforeValidator(_validate_tags), WrapSerializer(_serialize_tags_to_string), - set[str], # after all validators, tags must be a set, pydantic will check it + set[str], ] """Type for tags field in sticker requests. diff --git a/asyncord/client/stickers/resources.py b/asyncord/client/stickers/resources.py index fcbdc0a..597365c 100644 --- a/asyncord/client/stickers/resources.py +++ b/asyncord/client/stickers/resources.py @@ -53,9 +53,7 @@ async def get_sticker( return Sticker.model_validate(resp.body) - async def get_sticker_pack_list( - self, - ) -> StickerPackListResponse: + async def get_sticker_pack_list(self) -> StickerPackListResponse: """Returns a list of available sticker packs. Reference: @@ -67,10 +65,7 @@ async def get_sticker_pack_list( return StickerPackListResponse.model_validate(resp.body) - async def get_guild_stickers_list( - self, - guild_id: SnowflakeInputType, - ) -> list[Sticker]: + async def get_guild_stickers_list(self, guild_id: SnowflakeInputType) -> list[Sticker]: """Returns array of sticker objects foriven guild. Reference: @@ -129,7 +124,7 @@ async def create_guild_sticker( """ url = self.guild_url / str(guild_id) / 'stickers' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -172,7 +167,7 @@ async def update_guild_sticker( """ url = self.guild_url / str(guild_id) / 'stickers' / str(sticker_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -201,7 +196,7 @@ async def delete_guild_sticker( """ url = self.guild_url / str(guild_id) / 'stickers' / str(sticker_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} diff --git a/asyncord/client/threads/models/requests.py b/asyncord/client/threads/models/requests.py index e58ef2f..9c5e022 100644 --- a/asyncord/client/threads/models/requests.py +++ b/asyncord/client/threads/models/requests.py @@ -9,8 +9,8 @@ from asyncord.client.messages.models.common import MessageFlags from asyncord.client.messages.models.requests.components import Component from asyncord.client.messages.models.requests.embeds import Embed -from asyncord.client.messages.models.requests.messages import AllowedMentions, Attachment, BaseMessage -from asyncord.client.models.attachments import AttachmentContentType +from asyncord.client.messages.models.requests.messages import AllowedMentions, BaseMessage +from asyncord.client.models.attachments import Attachment, AttachmentContentType from asyncord.client.threads.models.common import ThreadType from asyncord.snowflake import SnowflakeInputType @@ -41,7 +41,7 @@ class CreateThreadRequest(BaseModel): Only available when creating a private thread. """ - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int, Field(ge=0, le=MAX_RATELIMIT)] | None = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission @@ -66,7 +66,7 @@ class CreateThreadFromMessageRequest(BaseModel): auto_archive_duration: Literal[60, 1440, 4320, 10080] | None = None """Duration in minutes to automatically archive the thread after recent activity.""" - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int, Field(ge=0, le=MAX_RATELIMIT)] | None = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission @@ -77,7 +77,7 @@ class CreateThreadFromMessageRequest(BaseModel): class ThreadMessage(BaseMessage): """Message model for a media/forum thread.""" - content: str | None = Field(None, max_length=2000) + content: Annotated[str | None, Field(max_length=2000)] = None """Message content.""" embeds: list[Embed] | None = None @@ -115,7 +115,7 @@ class CreateMediaForumThreadRequest(BaseModel): auto_archive_duration: Literal[60, 1440, 4320, 10080] | None = None """Duration in minutes to automatically archive the thread after recent activity.""" - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int, Field(ge=0, le=MAX_RATELIMIT)] | None = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission @@ -132,7 +132,7 @@ class CreateMediaForumThreadRequest(BaseModel): class UpdateThreadRequest(BaseModel): """Request model for updating a thread.""" - name: str | None = Field(None, min_length=1, max_length=100) + name: Annotated[str, Field(min_length=1, max_length=100)] | None = None """Thread name.""" archived: bool | None = None @@ -150,7 +150,7 @@ class UpdateThreadRequest(BaseModel): auto_archive_duration: Literal[60, 1440, 4320, 10080] | None = None """Duration in minutes to automatically archive the thread after recent activity.""" - rate_limit_per_user: int | None = Field(None, ge=0, le=MAX_RATELIMIT) + rate_limit_per_user: Annotated[int, Field(ge=0, le=MAX_RATELIMIT)] | None = None """Amount of seconds a user has to wait before sending another message. Should be between 0 and 21600. Bots, as well as users with the permission diff --git a/asyncord/client/threads/resources.py b/asyncord/client/threads/resources.py index 32bc808..bd47c73 100644 --- a/asyncord/client/threads/resources.py +++ b/asyncord/client/threads/resources.py @@ -2,12 +2,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from asyncord.client.channels.models.responses import ThreadMemberResponse from asyncord.client.http.headers import AUDIT_LOG_REASON from asyncord.client.messages.resources import MessageResource -from asyncord.client.models.attachments import make_attachment_payload +from asyncord.client.models.attachments import Attachment, make_payload_with_attachments from asyncord.client.resources import APIResource from asyncord.client.threads.models.requests import UpdateThreadRequest from asyncord.client.threads.models.responses import ThreadResponse, ThreadsResponse @@ -140,7 +140,7 @@ async def create_thread_from_message( thread_data: Data to create the thread with. reason: Reason for creating the thread. """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -157,7 +157,7 @@ async def create_thread(self, thread_data: CreateThreadRequest, reason: str | No thread_data: Data to create the thread with. reason: Reason for creating the thread. """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -178,12 +178,12 @@ async def create_media_forum_thread( thread_data: Data to create the thread with. reason: Reason for creating the thread. """ - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} - - payload = make_attachment_payload(thread_data.message, root_payload=thread_data) + attachments = cast(list[Attachment] | None, thread_data.message.attachments) + payload = make_payload_with_attachments(thread_data, attachments=attachments) resp = await self._http_client.post( url=self.threads_url, @@ -201,7 +201,7 @@ async def delete(self, thread_id: SnowflakeInputType, reason: str | None = None) """ url = self.channels_url / str(thread_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -361,7 +361,7 @@ async def update( """ url = self.channels_url / str(thread_id) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} diff --git a/asyncord/client/users/models/requests.py b/asyncord/client/users/models/requests.py index ec25aa9..f567216 100644 --- a/asyncord/client/users/models/requests.py +++ b/asyncord/client/users/models/requests.py @@ -1,5 +1,7 @@ """This module contains the request models for the user endpoints.""" +from typing import Annotated + from pydantic import BaseModel, Field from asyncord.base64_image import Base64ImageInputType @@ -33,10 +35,10 @@ class UpdateApplicationRoleConnectionRequest(BaseModel): https://discord.com/developers/docs/resources/user#update-current-user-application-role-connection-json-params """ - platform_name: str | None = Field(None, max_length=50) + platform_name: Annotated[str, Field(max_length=50)] | None = None """Vanity name of the platform a bot has connected""" - platform_username: str | None = Field(None, max_length=100) + platform_username: Annotated[str, Field(max_length=100)] | None = None """Username of the platform a bot has connected""" metadata: dict[str, str] | None = None diff --git a/asyncord/client/users/resources.py b/asyncord/client/users/resources.py index 5ebc952..f261cec 100644 --- a/asyncord/client/users/resources.py +++ b/asyncord/client/users/resources.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Sequence from typing import TYPE_CHECKING from asyncord.client.channels.models.responses import ChannelResponse @@ -149,7 +150,7 @@ async def create_dm(self, user_id: SnowflakeInputType) -> ChannelResponse: resp = await self._http_client.post(url=url, payload=payload) return ChannelResponse.model_validate(resp.body) - async def create_group_dm(self, user_ids: list[SnowflakeInputType]) -> ChannelResponse: + async def create_group_dm(self, user_ids: Sequence[SnowflakeInputType]) -> ChannelResponse: """Create a group DM. This endpoint was intended to be used with the now-deprecated GameBridge SDK. @@ -226,7 +227,7 @@ async def get_current_user_connections( resp = await self._http_client.get(url=url) return list_model(UserConnectionResponse).validate_python(resp.body) - async def get_current_user_application_role_connection( + async def get_current_user_application_role_connection( # pragma: no cover self, application_id: SnowflakeInputType, ) -> ApplicationRoleConnectionResponse: @@ -244,7 +245,7 @@ async def get_current_user_application_role_connection( resp = await self._http_client.get(url=url) return ApplicationRoleConnectionResponse.model_validate(resp.body) - async def update_current_user_application_role_connection( + async def update_current_user_application_role_connection( # pragma: no cover self, application_id: SnowflakeInputType, update_data: UpdateApplicationRoleConnectionRequest, diff --git a/asyncord/client/webhooks/models/requests.py b/asyncord/client/webhooks/models/requests.py index 003df03..c2506c5 100644 --- a/asyncord/client/webhooks/models/requests.py +++ b/asyncord/client/webhooks/models/requests.py @@ -9,9 +9,9 @@ from asyncord.client.messages.models.common import AllowedMentionType, MessageFlags from asyncord.client.messages.models.requests.components import Component from asyncord.client.messages.models.requests.embeds import Embed -from asyncord.client.messages.models.requests.messages import Attachment, BaseMessage -from asyncord.client.models.attachments import AttachmentContentType -from asyncord.client.polls.models.requests import PollRequest +from asyncord.client.messages.models.requests.messages import BaseMessage +from asyncord.client.models.attachments import Attachment, AttachmentContentType +from asyncord.client.polls.models.requests import Poll from asyncord.snowflake import SnowflakeInputType __ALL__ = ( @@ -81,10 +81,10 @@ class ExecuteWebhookRequest(BaseMessage): allowed_mentions: AllowedMentionType | None = None """Allowed mentions for the message.""" - components: Sequence[Component] | None = None + components: Component | Sequence[Component] | None = None """The components to include with the message.""" - attachments: list[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None + attachments: Sequence[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None """List of attachment object. See Uploading Files: @@ -109,7 +109,7 @@ class ExecuteWebhookRequest(BaseMessage): (requires the webhook channel to be a forum or media channel). """ - poll: PollRequest | None = None + poll: Poll | None = None """A poll.""" @@ -129,10 +129,10 @@ class UpdateWebhookMessageRequest(BaseMessage): allowed_mentions: AllowedMentionType | None = None """Allowed mentions for the message.""" - components: Sequence[Component] | None = None + components: Component | Sequence[Component] | None = None """The components to include with the message.""" - attachments: list[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None + attachments: Sequence[Annotated[Attachment | AttachmentContentType, Attachment]] | None = None """List of attachment object. See Uploading Files: diff --git a/asyncord/client/webhooks/models/responces.py b/asyncord/client/webhooks/models/responces.py index a00d53f..994c48b 100644 --- a/asyncord/client/webhooks/models/responces.py +++ b/asyncord/client/webhooks/models/responces.py @@ -79,7 +79,7 @@ class WebhookResponse(BaseModel): token: str | None = None """The secure token of the webhook. - (returned for Incoming Webhooks). + Returned only for incoming webhook type. """ application_id: Snowflake | None = None diff --git a/asyncord/client/webhooks/resources.py b/asyncord/client/webhooks/resources.py index ff56285..24945fc 100644 --- a/asyncord/client/webhooks/resources.py +++ b/asyncord/client/webhooks/resources.py @@ -6,11 +6,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from asyncord.client.http.headers import AUDIT_LOG_REASON from asyncord.client.messages.models.responses.messages import MessageResponse -from asyncord.client.models.attachments import make_attachment_payload +from asyncord.client.models.attachments import Attachment, make_payload_with_attachments from asyncord.client.resources import APIResource from asyncord.client.webhooks.models.responces import ( WebhookResponse, @@ -78,7 +78,7 @@ async def get_guild_webhooks( async def get_webhook( self, webhook_id: SnowflakeInputType, - webhook_token: str | None = None, + token: str | None = None, ) -> WebhookResponse: """Returns a new webhook object. @@ -87,12 +87,12 @@ async def get_webhook( Args: webhook_id: ID of the webhook to get. - webhook_token: Token of the webhook. + token: Token of the webhook. """ url = self.webhooks_url / str(webhook_id) - if webhook_token is not None: - url /= str(webhook_token) + if token is not None: + url /= str(token) resp = await self._http_client.get(url=url) return WebhookResponse.model_validate(resp.body) @@ -115,7 +115,7 @@ async def create_webhook( """ url = self.channel_url / str(channel_id) / 'webhooks' - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -126,9 +126,10 @@ async def create_webhook( async def update_webhook( self, + *, webhook_id: SnowflakeInputType, update_data: UpdateWebhookRequest, - webhook_token: str | None = None, + token: str | None = None, reason: str | None = None, ) -> WebhookResponse: """Modify a webhook with a token. @@ -140,19 +141,20 @@ async def update_webhook( Args: webhook_id: ID of the webhook to modify. update_data: Webhook data. - webhook_token: Token of the webhook. + token: Token of the webhook. reason: Reason for updating the webhook. """ + if update_data.channel_id and token: + raise ValueError('`channel_id` cannot be set when updating a webhook with a token') + url = self.webhooks_url / str(webhook_id) - payload = update_data.model_dump(mode='json', exclude_unset=True) + if token: + url /= str(token) - if webhook_token is not None: - url /= str(webhook_token) - if 'channel_id' in payload: - del payload['channel_id'] + payload = update_data.model_dump(mode='json', exclude_unset=True) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -163,7 +165,7 @@ async def update_webhook( async def delete_webhook( self, webhook_id: SnowflakeInputType, - webhook_token: str | None = None, + token: str | None = None, reason: str | None = None, ) -> None: """Delete a webhook. @@ -174,14 +176,14 @@ async def delete_webhook( Args: webhook_id: ID of the webhook to delete. - webhook_token: Token of the webhook. + token: Token of the webhook. reason: Reason for deleting the webhook. """ url = self.webhooks_url / str(webhook_id) - if webhook_token is not None: - url /= str(webhook_token) + if token is not None: + url /= str(token) - if reason is not None: + if reason: headers = {AUDIT_LOG_REASON: reason} else: headers = {} @@ -190,11 +192,12 @@ async def delete_webhook( async def execute_webhook( self, + *, webhook_id: SnowflakeInputType, - webhook_token: str, - execute_data: ExecuteWebhookRequest, - wait: bool | None = None, + token: str, + execution_data: ExecuteWebhookRequest, thread_id: SnowflakeInputType | None = None, + wait: bool = True, ) -> MessageResponse | None: """Execute a webhook. @@ -203,31 +206,33 @@ async def execute_webhook( Args: webhook_id: ID of the webhook to execute. - webhook_token: Token of the webhook. - execute_data: Webhook data. - wait: Waits for server confirmation. + token: Token of the webhook. + execution_data: Data to execute the webhook with. thread_id: Send the message to the specified thread. + wait: Whether to wait for the message to be sent. """ params = {} - if wait is not None: - params['wait'] = str(wait) if thread_id is not None: - params['thread_id'] = thread_id + params['thread_id'] = str(thread_id) + if wait: + params['wait'] = str(True) - url = self.webhooks_url / str(webhook_id) / str(webhook_token) % params - payload = make_attachment_payload(execute_data) + url = self.webhooks_url / str(webhook_id) / str(token) % params + attachments = cast(list[Attachment] | None, execution_data.attachments) + payload = make_payload_with_attachments(execution_data, attachments=attachments) message = await self._http_client.post(url=url, payload=payload) if wait: return MessageResponse.model_validate(message.body) - + # If wait is False, discord will return a 204 No Content response return None async def get_webhook_message( self, + *, webhook_id: SnowflakeInputType, - webhook_token: str, + token: str, message_id: SnowflakeInputType, thread_id: SnowflakeInputType | None = None, ) -> MessageResponse: @@ -238,11 +243,11 @@ async def get_webhook_message( Args: webhook_id: ID of the webhook. - webhook_token: Token of the webhook. + token: Token of the webhook. message_id: ID of the message. thread_id: ID of the thread. """ - url = self.webhooks_url / str(webhook_id) / str(webhook_token) / 'messages' / str(message_id) + url = self.webhooks_url / str(webhook_id) / str(token) / 'messages' / str(message_id) if thread_id is not None: params = {'thread_id': str(thread_id)} url %= params @@ -252,8 +257,9 @@ async def get_webhook_message( async def update_webhook_message( self, + *, webhook_id: SnowflakeInputType, - webhook_token: str, + token: str, message_id: SnowflakeInputType, update_data: UpdateWebhookMessageRequest, thread_id: SnowflakeInputType | None = None, @@ -265,25 +271,27 @@ async def update_webhook_message( Args: webhook_id: ID of the webhook. - webhook_token: Token of the webhook. + token: Token of the webhook. message_id: ID of the message. update_data: Message data. thread_id: ID of the thread. """ - url = self.webhooks_url / str(webhook_id) / str(webhook_token) / 'messages' / str(message_id) + url = self.webhooks_url / str(webhook_id) / str(token) / 'messages' / str(message_id) if thread_id is not None: params = {'thread_id': str(thread_id)} url %= params - payload = make_attachment_payload(update_data) + attachments = cast(list[Attachment] | None, update_data.attachments) + payload = make_payload_with_attachments(update_data, attachments=attachments) resp = await self._http_client.patch(url=url, payload=payload) return MessageResponse.model_validate(resp.body) async def delete_webhook_message( self, + *, webhook_id: SnowflakeInputType, - webhook_token: str, + token: str, message_id: SnowflakeInputType, thread_id: SnowflakeInputType | None = None, ) -> None: @@ -294,11 +302,11 @@ async def delete_webhook_message( Args: webhook_id: ID of the webhook. - webhook_token: Token of the webhook. + token: Token of the webhook. message_id: ID of the message. thread_id: ID of the thread. """ - url = self.webhooks_url / str(webhook_id) / str(webhook_token) / 'messages' / str(message_id) + url = self.webhooks_url / str(webhook_id) / str(token) / 'messages' / str(message_id) if thread_id is not None: params = {'thread_id': str(thread_id)} url %= params diff --git a/asyncord/client_hub.py b/asyncord/client_hub.py index fb2daf6..ddf0f30 100644 --- a/asyncord/client_hub.py +++ b/asyncord/client_hub.py @@ -53,10 +53,10 @@ def __init__( """ if session: self.session = session - self._outer_session = True + self._is_outer_session = True else: self.session = aiohttp.ClientSession() - self._outer_session = False + self._is_outer_session = False self.heartbeat_factory = heartbeat_factory_type() self.client_groups: dict[str, ClientGroup] = {} # Added type annotation @@ -161,14 +161,15 @@ async def start(self) -> None: await asyncio.gather(*tasks) except (KeyboardInterrupt, asyncio.CancelledError): logger.info('Shutting down...') - await asyncio.gather(*tasks, return_exceptions=True) + await self.stop() async def stop(self) -> None: """Stop the client hub.""" for client in self.client_groups.values(): await client.close() self.heartbeat_factory.stop() - await self.session.close() + if not self._is_outer_session: + await self.session.close() logger.info(':wave: Shutdown complete', extra={'markup': True}) async def __aenter__(self) -> Self: @@ -218,7 +219,7 @@ def _build_client_group( gateway_client = GatewayClient( token=auth, session=self.session, - heartbeat_class=self.heartbeat_factory.create, + heartbeat_class=self.heartbeat_factory, dispatcher=dispatcher, name=group_name, ) diff --git a/asyncord/gateway/client/client.py b/asyncord/gateway/client/client.py index da218dd..3e2be2a 100644 --- a/asyncord/gateway/client/client.py +++ b/asyncord/gateway/client/client.py @@ -10,7 +10,7 @@ import logging from collections.abc import Mapping from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable import aiohttp from pydantic import BaseModel @@ -79,7 +79,10 @@ def __init__( # noqa: PLR0913 self.session = session self.conn_data = conn_data or ConnectionData(token=token) self.intents = intents - self.heartbeat = heartbeat_class(self, self.conn_data) + if isinstance(heartbeat_class, HeartbeatFactoryProtocol): + self.heartbeat = heartbeat_class.create(self, self.conn_data) + else: + self.heartbeat = heartbeat_class(self, self.conn_data) self.dispatcher = dispatcher or EventDispatcher() self.is_started = False @@ -112,13 +115,15 @@ async def connect(self) -> None: async def close(self) -> None: """Stop the client.""" self.logger.info('Closing gateway client') - if not self.is_started or not self._ws: + if not self.is_started and not self._ws: return self.is_started = False self._need_restart.set() self.heartbeat.stop() - await self._ws.close() + if self._ws: + await self._ws.close() + self._ws = None self.logger.info('Gateway client closed') async def send_command(self, opcode: GatewayCommandOpcode, data: Any) -> None: # noqa: ANN401 @@ -221,12 +226,13 @@ async def _ws_recv_loop(self, ws: aiohttp.ClientWebSocketResponse) -> None: need_restart_task.cancel() message = await msg_task + # when get ending message, message is None if message: await self._handle_message(message) - async def _get_message(self, ws: aiohttp.ClientWebSocketResponse) -> GatewayMessageType | None: + async def _get_message(self, ws_resp: aiohttp.ClientWebSocketResponse) -> GatewayMessageType | None: """Get a message from the websocket.""" - msg = await ws.receive() + msg = await ws_resp.receive() if msg.type is aiohttp.WSMsgType.TEXT: data = msg.json() return GatewayMessageAdapter.validate_python(data) @@ -295,8 +301,10 @@ def stop(self) -> None: """Stop the heartbeat.""" +@runtime_checkable class HeartbeatFactoryProtocol(Protocol): """Protocol for the heartbeat factory class.""" - def __call__(self, client: GatewayClient, conn_data: ConnectionData) -> HeartbeatProtocol: # type: ignore - """Create a heartbeat for the client.""" + def create(self, client: GatewayClient, conn_data: ConnectionData) -> HeartbeatProtocol: + """Create a heartbeat instance.""" + ... diff --git a/asyncord/gateway/client/heartbeat.py b/asyncord/gateway/client/heartbeat.py index 3c5358b..45c7a48 100644 --- a/asyncord/gateway/client/heartbeat.py +++ b/asyncord/gateway/client/heartbeat.py @@ -11,7 +11,7 @@ import logging import random import threading -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING if TYPE_CHECKING: from asyncord.gateway.client.client import ConnectionData, GatewayClient @@ -59,15 +59,15 @@ def stop(self) -> None: self._task = None self._cleanup() - def __repr__(self) -> str: - """Return the representation of the heartbeat.""" - return f'' - @property def is_running(self) -> bool: """Whether the heartbeat is running.""" return self._task is not None + def __repr__(self) -> str: + """Return the representation of the heartbeat.""" + return f'' + def _cleanup(self) -> None: """Cleanup the heartbeat.""" self._ack_event.clear() @@ -91,8 +91,13 @@ async def _run(self, interval: datetime.timedelta) -> None: logger.debug('Keep interval: %i', keep_interval.total_seconds()) try: await asyncio.wait_for(self._wait_heartbeat_ack(), timeout=keep_interval.total_seconds()) - logger.debug('Heartbeat ack received.') except TimeoutError: + logger.error('Heartbeat ack not received in time. Reconnecting...') + self.client.reconnect() + self._task = None + break + except Exception as e: + logger.error('An unexpected error occurred: %s', e) self.client.reconnect() self._task = None break @@ -101,15 +106,16 @@ async def _wait_heartbeat_ack(self) -> None: """Wait for a heartbeat ack.""" for _ in range(100): await self.client.send_heartbeat(seq=self.conn_data.seq) - logger.debug('Heartbeat sent.') + logger.debug('Heartbeat sent') try: await asyncio.wait_for(self._ack_event.wait(), timeout=5) + logger.debug('Heartbeat ack received') return except TimeoutError: pass - logger.warning('Heartbeat ack not received.') + logger.debug('Heartbeat ack not received') - logger.error('Heartbeat ack not received after 100 attempts. Looks weird.') + logger.error('Heartbeat ack not received after 100 attempts. Looks weird') @property def _jittered_sleep_duration(self) -> float: @@ -125,9 +131,6 @@ def _jittered_sleep_duration(self) -> float: class HeartbeatFactory: """Factory for creating heartbeats.""" - HEARTBEAT_CLASS: ClassVar[type[Heartbeat]] = Heartbeat - """Heartbeat class to create.""" - def __init__(self) -> None: """Initialize the factory.""" self.loop = asyncio.new_event_loop() @@ -135,7 +138,7 @@ def __init__(self) -> None: def create(self, client: GatewayClient, conn_data: ConnectionData) -> Heartbeat: """Create a heartbeat.""" - return self.HEARTBEAT_CLASS(client=client, conn_data=conn_data, _loop=self.loop) + return Heartbeat(client=client, conn_data=conn_data, _loop=self.loop) def start(self) -> None: """Start the heartbeat.""" diff --git a/asyncord/gateway/events/event_map.py b/asyncord/gateway/events/event_map.py index bb0d4c8..909604a 100644 --- a/asyncord/gateway/events/event_map.py +++ b/asyncord/gateway/events/event_map.py @@ -11,6 +11,7 @@ channels, guilds, interactions, + invites, messages, moderation, presence, @@ -48,6 +49,7 @@ def _get_all_event_classes(modules: list[object]) -> Generator[type[base.Gateway moderation, presence, scheduled_events, + invites, ]) }) """Mapping of event names to event classes. diff --git a/asyncord/gateway/events/guilds.py b/asyncord/gateway/events/guilds.py index 98fe938..3806cd4 100644 --- a/asyncord/gateway/events/guilds.py +++ b/asyncord/gateway/events/guilds.py @@ -3,9 +3,9 @@ from datetime import datetime from typing import Any +from asyncord.client.emojis.models.responses import EmojiResponse from asyncord.client.guilds.resources import GuildResponse from asyncord.client.members.models.responses import MemberResponse -from asyncord.client.models.emoji import Emoji from asyncord.client.roles.models.responses import RoleResponse from asyncord.client.users.models.responses import UserResponse from asyncord.gateway.events.base import GatewayEvent @@ -143,7 +143,7 @@ class GuildEmojisUpdateEvent(GatewayEvent): guild_id: Snowflake """Guild id.""" - emojis: list[Emoji] + emojis: list[EmojiResponse] """List of emojis.""" @@ -219,16 +219,16 @@ class GuildMemberUpdateEvent(GatewayEvent, MemberResponse): avatar: str | None = None """Member's guild avatar hash.""" - joined_at: datetime | None = None + joined_at: datetime | None = None # type: ignore """When the user joined the guild.""" premium_since: datetime | None = None """When the user started boosting the guild.""" - deaf: bool | None = None + deaf: bool | None = None # type: ignore """Whether the user is deafened in voice channels.""" - mute: bool | None = None + mute: bool | None = None # type: ignore """Whether the user is muted in voice channels.""" pending: bool | None # type: ignore diff --git a/asyncord/gateway/message.py b/asyncord/gateway/message.py index f4d0b21..d75e8a4 100644 --- a/asyncord/gateway/message.py +++ b/asyncord/gateway/message.py @@ -4,6 +4,7 @@ from typing import Annotated, Any, Literal from fbenum.adapter import FallbackAdapter +from fbenum.enum import FallbackEnum from pydantic import BaseModel, Field, TypeAdapter __all__ = ( @@ -45,7 +46,7 @@ class GatewayCommandOpcode(enum.IntEnum): @enum.unique -class GatewayMessageOpcode(enum.IntEnum): +class GatewayMessageOpcode(enum.IntEnum, FallbackEnum): """Gateway message opcodes.""" DISPATCH = 0 @@ -64,7 +65,7 @@ class GatewayMessageOpcode(enum.IntEnum): """Sent in response to receiving a heartbeat to acknowledge that it has been received.""" -class BaseGatewayMessage(BaseModel): +class BaseGatewayMessage(BaseModel, frozen=True): """Base gateway message model.""" opcode: GatewayMessageOpcode = Field(alias='op') @@ -77,7 +78,7 @@ class BaseGatewayMessage(BaseModel): """Message trace information.""" -class DispatchMessage(BaseGatewayMessage): +class DispatchMessage(BaseGatewayMessage, frozen=True): """Dispatch message model.""" opcode: Literal[GatewayMessageOpcode.DISPATCH] = Field(GatewayMessageOpcode.DISPATCH, alias='op') @@ -110,7 +111,7 @@ class HelloMessageData(BaseModel): """Interval (in milliseconds) the client should heartbeat with.""" -class HelloMessage(BaseGatewayMessage): +class HelloMessage(BaseGatewayMessage, frozen=True): """Hello message model.""" opcode: Literal[GatewayMessageOpcode.HELLO] = Field(GatewayMessageOpcode.HELLO, alias='op') @@ -120,7 +121,7 @@ class HelloMessage(BaseGatewayMessage): """Hello message data.""" -class InvalidSessionMessage(BaseGatewayMessage): +class InvalidSessionMessage(BaseGatewayMessage, frozen=True): """Invalid session message model.""" opcode: Literal[GatewayMessageOpcode.INVALID_SESSION] = Field(GatewayMessageOpcode.INVALID_SESSION, alias='op') @@ -130,15 +131,18 @@ class InvalidSessionMessage(BaseGatewayMessage): """Whether the session can be resumed.""" -class DatalessMessage(BaseGatewayMessage): +class DatalessMessage(BaseGatewayMessage, frozen=True): """Other gateway messages that do not have data.""" + data: Annotated[None, Field(alias='d')] = None + """Message data.""" + -class FallbackGatewayMessage(BaseGatewayMessage): +class FallbackGatewayMessage(BaseGatewayMessage, frozen=True): """Gateway message model.""" opcode: Annotated[GatewayMessageOpcode, FallbackAdapter] = Field(alias='op') - data: Any = Field(alias='d') + data: Annotated[Any, Field(alias='d')] = None type GatewayMessageType = ( diff --git a/asyncord/typedefs.py b/asyncord/typedefs.py index ce1e190..b97bd26 100644 --- a/asyncord/typedefs.py +++ b/asyncord/typedefs.py @@ -4,15 +4,30 @@ import enum from functools import lru_cache -from typing import TYPE_CHECKING, NewType +from typing import TYPE_CHECKING, Literal, NewType from pydantic import TypeAdapter from yarl import URL -StrOrURL = str | URL +CURRENT_USER: CurrentUserType = '@me' +"""Literal for the current user endpoint.""" + +type CurrentUserType = Literal['@me'] +"""Type alias for the current user type.""" + + +StrOrURL = URL | str +"""URL in string or yarl.URL format.""" UnsetType = NewType('UnsetType', object) +"""Type of an unset value.""" Unset: UnsetType = UnsetType(object()) +"""Sentinel for an unset value. + +This value is used to represent an unset value in the API. +It can be used to differentiate between a value that is not set and a value that is set to None. +Sentinels described in draft PEP 601 (https://peps.python.org/pep-0661/). +""" # Fix for pydanitc and pylance. Pylance doesn't correctly infer the type # of the list_model function. @@ -49,9 +64,14 @@ def __init__(self, obj_value: str | set[str]): self._value_ = obj_value @classmethod - def _missing_(cls, values: set[str]) -> object: - """Returns the missing value of the scope.""" - # get members by values and fail if any value is not found + def _missing_(cls, values: str | set[str]) -> object: # type: ignore + """Returns the missing value of the scope. + + Try to get members by values and fail if any value is not found. + """ + if isinstance(values, str): + values = {values} + actual_members = [] for value in values: member = cls._value2member_map_.get(value) @@ -78,7 +98,7 @@ def _missing_(cls, values: set[str]) -> object: @property def value(self) -> str: """Returns the value of the scope.""" - return ' '.join(self._value_) + return ' '.join(sorted(self._value_)) def __str__(self) -> str: """Returns the string representation of the scope.""" @@ -92,6 +112,13 @@ def __or__(self, other: StrFlag) -> StrFlag: """Returns the union of the two scopes.""" return self.__class__(self._value_ | other._value_) - def _separated_flags(self) -> list[StrFlag]: - """Returns the separated flags of the scope.""" - return [self.__class__(value) for value in self._value_] + def __eq__(self, other: StrFlag | object) -> bool: + """Compare two flags.""" + if isinstance(other, StrFlag): + return sorted(self._value_) == sorted(other._value_) + + return super().__eq__(other) + + def __hash__(self) -> int: + """Return the hash value of the flag.""" + return hash(self.value) diff --git a/asyncord/urls.py b/asyncord/urls.py index f3aeff5..7f358a2 100644 --- a/asyncord/urls.py +++ b/asyncord/urls.py @@ -10,3 +10,4 @@ API_VERSION: Final[int] = 10 REST_API_URL: Final[URL] = URL(f'{BASE_URL}/api/v{API_VERSION}') GATEWAY_URL: Final[URL] = URL(f'wss://gateway.discord.gg/?v={API_VERSION}&encoding=json') +INVITE_BASE_URL: Final[URL] = URL('https://discord.gg') diff --git a/pdm.lock b/pdm.lock index 504221e..00fc671 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "testing", "docs"] strategy = ["cross_platform"] lock_version = "4.4.1" -content_hash = "sha256:28ebc361ee3fc1645403a7f6270c0ee0a07c662c039f978b2a60b5fc0dd7f47d" +content_hash = "sha256:62cb2dd339dcd01f89795c22fb67e1b0b4a8c00c63d5e3677c892816e0d890ac" [[package]] name = "aiohttp" @@ -482,7 +482,7 @@ files = [ [[package]] name = "mkdocs-material" -version = "9.5.23" +version = "9.5.27" requires_python = ">=3.8" summary = "Documentation that simply works" dependencies = [ @@ -499,8 +499,8 @@ dependencies = [ "requests~=2.26", ] files = [ - {file = "mkdocs_material-9.5.23-py3-none-any.whl", hash = "sha256:ffd08a5beaef3cd135aceb58ded8b98bbbbf2b70e5b656f6a14a63c917d9b001"}, - {file = "mkdocs_material-9.5.23.tar.gz", hash = "sha256:4627fc3f15de2cba2bde9debc2fd59b9888ef494beabfe67eb352e23d14bf288"}, + {file = "mkdocs_material-9.5.27-py3-none-any.whl", hash = "sha256:af8cc263fafa98bb79e9e15a8c966204abf15164987569bd1175fd66a7705182"}, + {file = "mkdocs_material-9.5.27.tar.gz", hash = "sha256:a7d4a35f6d4a62b0c43a0cfe7e987da0980c13587b5bc3c26e690ad494427ec0"}, ] [[package]] @@ -667,58 +667,58 @@ files = [ [[package]] name = "pydantic" -version = "2.7.1" +version = "2.7.4" requires_python = ">=3.8" summary = "Data validation using Python type hints" dependencies = [ "annotated-types>=0.4.0", - "pydantic-core==2.18.2", + "pydantic-core==2.18.4", "typing-extensions>=4.6.1", ] files = [ - {file = "pydantic-2.7.1-py3-none-any.whl", hash = "sha256:e029badca45266732a9a79898a15ae2e8b14840b1eabbb25844be28f0b33f3d5"}, - {file = "pydantic-2.7.1.tar.gz", hash = "sha256:e9dbb5eada8abe4d9ae5f46b9939aead650cd2b68f249bb3a8139dbe125803cc"}, + {file = "pydantic-2.7.4-py3-none-any.whl", hash = "sha256:ee8538d41ccb9c0a9ad3e0e5f07bf15ed8015b481ced539a1759d8cc89ae90d0"}, + {file = "pydantic-2.7.4.tar.gz", hash = "sha256:0c84efd9548d545f63ac0060c1e4d39bb9b14db8b3c0652338aecc07b5adec52"}, ] [[package]] name = "pydantic-core" -version = "2.18.2" +version = "2.18.4" requires_python = ">=3.8" summary = "Core functionality for Pydantic validation and serialization" dependencies = [ "typing-extensions!=4.7.0,>=4.6.0", ] files = [ - {file = "pydantic_core-2.18.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:fb2bd7be70c0fe4dfd32c951bc813d9fe6ebcbfdd15a07527796c8204bd36242"}, - {file = "pydantic_core-2.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6132dd3bd52838acddca05a72aafb6eab6536aa145e923bb50f45e78b7251043"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d904828195733c183d20a54230c0df0eb46ec746ea1a666730787353e87182"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c9bd70772c720142be1020eac55f8143a34ec9f82d75a8e7a07852023e46617f"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b8ed04b3582771764538f7ee7001b02e1170223cf9b75dff0bc698fadb00cf3"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e6dac87ddb34aaec85f873d737e9d06a3555a1cc1a8e0c44b7f8d5daeb89d86f"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca4ae5a27ad7a4ee5170aebce1574b375de390bc01284f87b18d43a3984df72"}, - {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:886eec03591b7cf058467a70a87733b35f44707bd86cf64a615584fd72488b7c"}, - {file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca7b0c1f1c983e064caa85f3792dd2fe3526b3505378874afa84baf662e12241"}, - {file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b4356d3538c3649337df4074e81b85f0616b79731fe22dd11b99499b2ebbdf3"}, - {file = "pydantic_core-2.18.2-cp312-none-win32.whl", hash = "sha256:8b172601454f2d7701121bbec3425dd71efcb787a027edf49724c9cefc14c038"}, - {file = "pydantic_core-2.18.2-cp312-none-win_amd64.whl", hash = "sha256:b1bd7e47b1558ea872bd16c8502c414f9e90dcf12f1395129d7bb42a09a95438"}, - {file = "pydantic_core-2.18.2-cp312-none-win_arm64.whl", hash = "sha256:98758d627ff397e752bc339272c14c98199c613f922d4a384ddc07526c86a2ec"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a1874c6dd4113308bd0eb568418e6114b252afe44319ead2b4081e9b9521fe75"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:ccdd111c03bfd3666bd2472b674c6899550e09e9f298954cfc896ab92b5b0e6d"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e18609ceaa6eed63753037fc06ebb16041d17d28199ae5aba0052c51449650a9"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e5c584d357c4e2baf0ff7baf44f4994be121e16a2c88918a5817331fc7599d7"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43f0f463cf89ace478de71a318b1b4f05ebc456a9b9300d027b4b57c1a2064fb"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e1b395e58b10b73b07b7cf740d728dd4ff9365ac46c18751bf8b3d8cca8f625a"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0098300eebb1c837271d3d1a2cd2911e7c11b396eac9661655ee524a7f10587b"}, - {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:36789b70d613fbac0a25bb07ab3d9dba4d2e38af609c020cf4d888d165ee0bf3"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3f9a801e7c8f1ef8718da265bba008fa121243dfe37c1cea17840b0944dfd72c"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3a6515ebc6e69d85502b4951d89131ca4e036078ea35533bb76327f8424531ce"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20aca1e2298c56ececfd8ed159ae4dde2df0781988c97ef77d5c16ff4bd5b400"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:223ee893d77a310a0391dca6df00f70bbc2f36a71a895cecd9a0e762dc37b349"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2334ce8c673ee93a1d6a65bd90327588387ba073c17e61bf19b4fd97d688d63c"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cbca948f2d14b09d20268cda7b0367723d79063f26c4ffc523af9042cad95592"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b3ef08e20ec49e02d5c6717a91bb5af9b20f1805583cb0adfe9ba2c6b505b5ae"}, - {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6fdc8627910eed0c01aed6a390a252fe3ea6d472ee70fdde56273f198938374"}, - {file = "pydantic_core-2.18.2.tar.gz", hash = "sha256:2e29d20810dfc3043ee13ac7d9e25105799817683348823f305ab3f349b9386e"}, + {file = "pydantic_core-2.18.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6f5c4d41b2771c730ea1c34e458e781b18cc668d194958e0112455fff4e402b2"}, + {file = "pydantic_core-2.18.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2fdf2156aa3d017fddf8aea5adfba9f777db1d6022d392b682d2a8329e087cef"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4748321b5078216070b151d5271ef3e7cc905ab170bbfd27d5c83ee3ec436695"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:847a35c4d58721c5dc3dba599878ebbdfd96784f3fb8bb2c356e123bdcd73f34"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c40d4eaad41f78e3bbda31b89edc46a3f3dc6e171bf0ecf097ff7a0ffff7cb1"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21a5e440dbe315ab9825fcd459b8814bb92b27c974cbc23c3e8baa2b76890077"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01dd777215e2aa86dfd664daed5957704b769e726626393438f9c87690ce78c3"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4b06beb3b3f1479d32befd1f3079cc47b34fa2da62457cdf6c963393340b56e9"}, + {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:564d7922e4b13a16b98772441879fcdcbe82ff50daa622d681dd682175ea918c"}, + {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0eb2a4f660fcd8e2b1c90ad566db2b98d7f3f4717c64fe0a83e0adb39766d5b8"}, + {file = "pydantic_core-2.18.4-cp312-none-win32.whl", hash = "sha256:8b8bab4c97248095ae0c4455b5a1cd1cdd96e4e4769306ab19dda135ea4cdb07"}, + {file = "pydantic_core-2.18.4-cp312-none-win_amd64.whl", hash = "sha256:14601cdb733d741b8958224030e2bfe21a4a881fb3dd6fbb21f071cabd48fa0a"}, + {file = "pydantic_core-2.18.4-cp312-none-win_arm64.whl", hash = "sha256:c1322d7dd74713dcc157a2b7898a564ab091ca6c58302d5c7b4c07296e3fd00f"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:574d92eac874f7f4db0ca653514d823a0d22e2354359d0759e3f6a406db5d55d"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1f4d26ceb5eb9eed4af91bebeae4b06c3fb28966ca3a8fb765208cf6b51102ab"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77450e6d20016ec41f43ca4a6c63e9fdde03f0ae3fe90e7c27bdbeaece8b1ed4"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d323a01da91851a4f17bf592faf46149c9169d68430b3146dcba2bb5e5719abc"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43d447dd2ae072a0065389092a231283f62d960030ecd27565672bd40746c507"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:578e24f761f3b425834f297b9935e1ce2e30f51400964ce4801002435a1b41ef"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:81b5efb2f126454586d0f40c4d834010979cb80785173d1586df845a632e4e6d"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ab86ce7c8f9bea87b9d12c7f0af71102acbf5ecbc66c17796cff45dae54ef9a5"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:90afc12421df2b1b4dcc975f814e21bc1754640d502a2fbcc6d41e77af5ec312"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:51991a89639a912c17bef4b45c87bd83593aee0437d8102556af4885811d59f5"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:293afe532740370aba8c060882f7d26cfd00c94cae32fd2e212a3a6e3b7bc15e"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b48ece5bde2e768197a2d0f6e925f9d7e3e826f0ad2271120f8144a9db18d5c8"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eae237477a873ab46e8dd748e515c72c0c804fb380fbe6c85533c7de51f23a8f"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:834b5230b5dfc0c1ec37b2fda433b271cbbc0e507560b5d1588e2cc1148cf1ce"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e858ac0a25074ba4bce653f9b5d0a85b7456eaddadc0ce82d3878c22489fa4ee"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2fd41f6eff4c20778d717af1cc50eca52f5afe7805ee530a4fbd0bae284f16e9"}, + {file = "pydantic_core-2.18.4.tar.gz", hash = "sha256:ec3beeada09ff865c344ff3bc2f427f5e6c26401cc6113d77e372c3fdac73864"}, ] [[package]] @@ -747,7 +747,7 @@ files = [ [[package]] name = "pytest" -version = "8.2.0" +version = "8.2.2" requires_python = ">=3.8" summary = "pytest: simple powerful testing with Python" dependencies = [ @@ -757,21 +757,21 @@ dependencies = [ "pluggy<2.0,>=1.5", ] files = [ - {file = "pytest-8.2.0-py3-none-any.whl", hash = "sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233"}, - {file = "pytest-8.2.0.tar.gz", hash = "sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f"}, + {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, + {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, ] [[package]] name = "pytest-asyncio" -version = "0.23.6" -requires_python = ">=3.8" +version = "0.21.2" +requires_python = ">=3.7" summary = "Pytest support for asyncio" dependencies = [ - "pytest<9,>=7.0.0", + "pytest>=7.0.0", ] files = [ - {file = "pytest-asyncio-0.23.6.tar.gz", hash = "sha256:ffe523a89c1c222598c76856e76852b787504ddb72dd5d9b6617ffa8aa2cde5f"}, - {file = "pytest_asyncio-0.23.6-py3-none-any.whl", hash = "sha256:68516fdd1018ac57b846c9846b954f0393b26f094764a28c955eabb0536a4e8a"}, + {file = "pytest_asyncio-0.21.2-py3-none-any.whl", hash = "sha256:ab664c88bb7998f711d8039cacd4884da6430886ae8bbd4eded552ed2004f16b"}, + {file = "pytest_asyncio-0.21.2.tar.gz", hash = "sha256:d67738fc232b94b326b9d060750beb16e0074210b98dd8b58a5239fa2a154f45"}, ] [[package]] @@ -927,37 +927,27 @@ files = [ [[package]] name = "ruff" -version = "0.4.4" +version = "0.4.10" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." files = [ - {file = "ruff-0.4.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:29d44ef5bb6a08e235c8249294fa8d431adc1426bfda99ed493119e6f9ea1bf6"}, - {file = "ruff-0.4.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c4efe62b5bbb24178c950732ddd40712b878a9b96b1d02b0ff0b08a090cbd891"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c8e2f1e8fc12d07ab521a9005d68a969e167b589cbcaee354cb61e9d9de9c15"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:60ed88b636a463214905c002fa3eaab19795679ed55529f91e488db3fe8976ab"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b90fc5e170fc71c712cc4d9ab0e24ea505c6a9e4ebf346787a67e691dfb72e85"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8e7e6ebc10ef16dcdc77fd5557ee60647512b400e4a60bdc4849468f076f6eef"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b9ddb2c494fb79fc208cd15ffe08f32b7682519e067413dbaf5f4b01a6087bcd"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c51c928a14f9f0a871082603e25a1588059b7e08a920f2f9fa7157b5bf08cfe9"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5eb0a4bfd6400b7d07c09a7725e1a98c3b838be557fee229ac0f84d9aa49c36"}, - {file = "ruff-0.4.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b1867ee9bf3acc21778dcb293db504692eda5f7a11a6e6cc40890182a9f9e595"}, - {file = "ruff-0.4.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1aecced1269481ef2894cc495647392a34b0bf3e28ff53ed95a385b13aa45768"}, - {file = "ruff-0.4.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9da73eb616b3241a307b837f32756dc20a0b07e2bcb694fec73699c93d04a69e"}, - {file = "ruff-0.4.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:958b4ea5589706a81065e2a776237de2ecc3e763342e5cc8e02a4a4d8a5e6f95"}, - {file = "ruff-0.4.4-py3-none-win32.whl", hash = "sha256:cb53473849f011bca6e754f2cdf47cafc9c4f4ff4570003a0dad0b9b6890e876"}, - {file = "ruff-0.4.4-py3-none-win_amd64.whl", hash = "sha256:424e5b72597482543b684c11def82669cc6b395aa8cc69acc1858b5ef3e5daae"}, - {file = "ruff-0.4.4-py3-none-win_arm64.whl", hash = "sha256:39df0537b47d3b597293edbb95baf54ff5b49589eb7ff41926d8243caa995ea6"}, - {file = "ruff-0.4.4.tar.gz", hash = "sha256:f87ea42d5cdebdc6a69761a9d0bc83ae9b3b30d0ad78952005ba6568d6c022af"}, -] - -[[package]] -name = "sentinel" -version = "1.0.0" -requires_python = ">=3.6,<4.0" -summary = "Create sentinel objects, akin to None, NotImplemented, Ellipsis" -files = [ - {file = "sentinel-1.0.0-py3-none-any.whl", hash = "sha256:24f02a34cc9f0fcba5a666a23b6c7f56aff332fc624632ee442e7237751a9f60"}, - {file = "sentinel-1.0.0.tar.gz", hash = "sha256:190928f9951af6e94a1f84eefcaed791c28097dd152b88e988906be300451fd2"}, + {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"}, + {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"}, + {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"}, + {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"}, + {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"}, + {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index f01a3fa..79d0536 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ license = { text = "MIT" } requires-python = ">=3.12.0" dependencies = [ "yarl>=1.8.1", - "sentinel>=1.0.0", "rich>=12.5.1", "pydantic>=2.4.2", "aiohttp<4", @@ -32,7 +31,7 @@ dev = [ ] testing = [ "pytest>=7.3.1", - "pytest-asyncio>=0.19.0", + "pytest-asyncio>=0.19.0,<0.23.0", # 0.23.0 has a bugs with scopes "pytest-pretty>=1.2.0", "pytest-mock>=3.12.0", "pytest-xdist>=3.6.1", @@ -48,20 +47,27 @@ source = "call" getter = "scritps.version:get_version" [tool.pdm.scripts] +coverage = "pytest --cov=asyncord --cov-report html --cov-append -n auto tests" docs = "mkdocs serve -f docs/mkdocs.yml" build-docs = "mkdocs build -f docs/mkdocs.yml" [tool.pytest.ini_options] minversion = "6.2" -addopts = "-ra --color=yes" +addopts = "-ra --dist=loadscope --color=yes" asyncio_mode = "auto" testpaths = ["tests"] -[tool.isort] -line_length = 110 -multi_line_output = 3 -include_trailing_comma = true -length_sort = true +[tool.coverage.run] +branch = true + +[tool.coverage.report] +exclude_also = [ + "if TYPE_CHECKING:", + "^class .*\\(Protocol\\):$", + "if reason:", + "@overload", + "NotImplementedError", +] [tool.ruff] line-length = 120 diff --git a/tests/client/__init__.py b/tests/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/client/test_interactions.py b/tests/client/test_interactions.py new file mode 100644 index 0000000..1eaa3a4 --- /dev/null +++ b/tests/client/test_interactions.py @@ -0,0 +1,126 @@ +from typing import cast +from unittest.mock import AsyncMock + +import pytest + +from asyncord.client.http.headers import JSON_CONTENT_TYPE +from asyncord.client.http.models import FormPayload, JsonField +from asyncord.client.interactions.models.common import InteractionResponseType +from asyncord.client.interactions.models.requests import ( + InteractionRespAutocompleteRequest, + InteractionRespDeferredMessageRequest, + InteractionRespMessageRequest, + InteractionRespModalRequest, + InteractionResponseRequestType, + InteractionRespUpdateDeferredMessageRequest, + InteractionRespUpdateMessageRequest, +) +from asyncord.client.interactions.resources import ( + InteractionResource, +) +from asyncord.client.messages.models.requests.components import TextInput +from asyncord.client.models.attachments import Attachment + +TEST_ATTACHMENTS = [Attachment(content=b'png:...')] + + +@pytest.fixture() +def interaction_res() -> InteractionResource: + """Returns an instance of InteractionResource.""" + return InteractionResource(AsyncMock()) + + +@pytest.mark.parametrize( + 'resp', + [ + InteractionRespMessageRequest(content='Hello, World!'), + InteractionRespUpdateMessageRequest(content='Hello, World!'), + InteractionRespDeferredMessageRequest(content='Hello, World!'), + InteractionRespUpdateDeferredMessageRequest(content='Hello, World!'), + InteractionRespAutocompleteRequest(choices=[]), + InteractionRespModalRequest( + custom_id='1234567890', + title='Title', + components=TextInput(custom_id='1234567891', label='Label'), + ), + ], +) +async def test_send_not_pong_response( + resp: InteractionResponseRequestType, + interaction_res: InteractionResource, +) -> None: + """Test send_response method of InteractionResource.""" + interaction_id = '1234567890' + interaction_token = 'token' # noqa: S105 + await interaction_res.send_response(interaction_id, interaction_token, resp) + + method_caller = cast(AsyncMock, interaction_res._http_client.post) + method_caller.assert_called_once() + + request_url = method_caller.call_args.kwargs['url'] + payload = method_caller.call_args.kwargs['payload'] + + assert str(request_url).endswith(f'/interactions/{interaction_id}/{interaction_token}/callback') + assert payload['data'] == resp.model_dump(mode='json', exclude_none=True) + assert payload.get('type') + + +async def test_send_pong(interaction_res: InteractionResource) -> None: + """Test send pong response method of InteractionResource.""" + interaction_id = '1234567890' + interaction_token = 'token' # noqa: S105 + + await interaction_res.send_pong(interaction_id, interaction_token) + + method_caller = cast(AsyncMock, interaction_res._http_client.post) + method_caller.assert_called_once() + + request_url = method_caller.call_args.kwargs['url'] + payload = method_caller.call_args.kwargs['payload'] + + assert str(request_url).endswith(f'/interactions/{interaction_id}/{interaction_token}/callback') + assert payload == {'type': InteractionResponseType.PONG.value} + + +@pytest.mark.parametrize( + 'resp', + [ + InteractionRespMessageRequest(content='Hello, World!', attachments=TEST_ATTACHMENTS), + InteractionRespUpdateMessageRequest(content='Hello, World!', attachments=TEST_ATTACHMENTS), + InteractionRespDeferredMessageRequest(content='Hello, World!', attachments=TEST_ATTACHMENTS), + InteractionRespUpdateDeferredMessageRequest(content='Hello, World!', attachments=TEST_ATTACHMENTS), + ], +) +async def test_attachments_in_response( + resp: InteractionResponseRequestType, + interaction_res: InteractionResource, +) -> None: + """Test attachments in response.""" + interaction_id = '1234567890' + interaction_token = 'token' # noqa: S105 + await interaction_res.send_response(interaction_id, interaction_token, resp) + + method_caller = cast(AsyncMock, interaction_res._http_client.post) + method_caller.assert_called_once() + + request_url = method_caller.call_args.kwargs['url'] + payload = method_caller.call_args.kwargs['payload'] + + assert str(request_url).endswith(f'/interactions/{interaction_id}/{interaction_token}/callback') + assert isinstance(payload, FormPayload) + + (json_field_name, json_payload), (attachment_field_name, attachment) = payload + + assert isinstance(json_payload, JsonField) + assert json_field_name == 'payload_json' + assert json_payload.content_type == JSON_CONTENT_TYPE + assert not json_payload.filename + assert isinstance(json_payload.value, dict) + assert json_payload.value.get('type') + assert json_payload.value['data'] == resp.model_dump(mode='json', exclude_none=True) + + assert attachment_field_name == 'files[0]' + assert not attachment.content_type + assert not attachment.filename + assert isinstance(attachment.value, bytes) + assert attachment.value == TEST_ATTACHMENTS[0].content diff --git a/tests/client/test_rest.py b/tests/client/test_rest.py new file mode 100644 index 0000000..0602f57 --- /dev/null +++ b/tests/client/test_rest.py @@ -0,0 +1,43 @@ +from unittest.mock import Mock + +import pytest + +from asyncord.client.rest import RestClient + + +def test_create_with_http_and_session_fail() -> None: + """Test creating RestClient with both session and http_client.""" + with pytest.raises(ValueError, match='Cannot pass both session and http_client'): + RestClient('token', session=Mock(), http_client=Mock()) + + +def test_create_with_http_client() -> None: + """Test creating RestClient with http_client.""" + http_client = Mock() + client = RestClient('token', http_client=http_client) + assert client._http_client == http_client + + +def test_create_without_auth_fail() -> None: + """Test creating RestClient without auth.""" + with pytest.raises(ValueError, match='Auth strategy is required'): + RestClient(None) + + +def test_create_with_http_client_and_no_auth() -> None: + """Test creating RestClient with http_client and no auth.""" + RestClient(None, http_client=Mock()) + + +def test_create_with_no_rate_limit_strategy() -> None: + """Test creating RestClient with no rate limit strategy.""" + RestClient('token', http_client=Mock(), ratelimit_strategy=None) + + +def test_create_with_castom_auth_strategy() -> None: + """Test creating RestClient with custom auth strategy.""" + auth = Mock() + client = RestClient(auth, http_client=Mock()) + mdlwr_append: Mock = client._http_client.system_middlewares.append # type: ignore + assert mdlwr_append.call_count == 2 + assert mdlwr_append.call_args_list[0][0][0] == auth diff --git a/tests/conftest.py b/tests/conftest.py index b9c92b6..de063e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,27 @@ +import asyncio import os -from dataclasses import dataclass +from collections.abc import Iterator from pathlib import Path from typing import Final import pytest +from pydantic import BaseModel, SecretStr INTEGRATION_TEST_DIR: Final[Path] = Path(__file__).parent / 'integration' -@dataclass -class IntegrationTestData: +@pytest.fixture(scope='module') +def event_loop(request: pytest.FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]: + """Create an instance of the default event loop for each test module.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +class IntegrationTestData(BaseModel): """Data to perform integration tests.""" - token: str + token: SecretStr channel_id: str voice_channel_id: str guild_id: str @@ -24,7 +33,7 @@ class IntegrationTestData: app_id: str role_id: str user_to_ban: str - stage_id: str + role_to_prune: str @pytest.fixture(scope='session') @@ -34,7 +43,7 @@ def integration_data() -> IntegrationTestData: if token is None: raise RuntimeError('ASYNCORD_TEST_TOKEN env variable is not set') return IntegrationTestData( - token=token, + token=token, # type: ignore channel_id=os.environ['ASYNCORD_TEST_CHANNEL_ID'], voice_channel_id=os.environ['ASYNCORD_TEST_VOICE_CHANNEL_ID'], guild_id=os.environ['ASYNCORD_TEST_GUILD_ID'], @@ -46,14 +55,14 @@ def integration_data() -> IntegrationTestData: app_id=os.environ['ASYNCORD_TEST_APP_ID'], role_id=os.environ['ASYNCORD_TEST_ROLE_ID'], user_to_ban=os.environ['ASYNCORD_TEST_USER_TO_BAN'], - stage_id=os.environ['ASYNCORD_TEST_STAGE_ID'], + role_to_prune=os.environ['ASYNCORD_TEST_ROLE_TO_PRUNE'], ) @pytest.fixture(scope='session') def token(integration_data: IntegrationTestData) -> str: """Get token to perform integration tests.""" - return integration_data.token + return integration_data.token.get_secret_value() def pytest_addoption(parser: pytest.Parser) -> None: diff --git a/tests/gateway/conftest.py b/tests/gateway/conftest.py new file mode 100644 index 0000000..3ae520c --- /dev/null +++ b/tests/gateway/conftest.py @@ -0,0 +1,19 @@ +from unittest.mock import Mock + +import aiohttp +import pytest +from pytest_mock import MockFixture + +from asyncord.gateway.client.client import ConnectionData, GatewayClient +from asyncord.gateway.client.heartbeat import Heartbeat + + +@pytest.fixture() +def gw_client(mocker: MockFixture) -> GatewayClient: + """Create a GatewayClient instance.""" + return GatewayClient( + token='token', # noqa: S106 + session=Mock(spec=aiohttp.ClientSession), + conn_data=ConnectionData(token='token'), # noqa: S106 + heartbeat_class=Mock(spec=type(Heartbeat)), + ) diff --git a/tests/gateway/test_client.py b/tests/gateway/test_client.py new file mode 100644 index 0000000..db084d1 --- /dev/null +++ b/tests/gateway/test_client.py @@ -0,0 +1,501 @@ +import asyncio +from types import MappingProxyType +from typing import Literal +from unittest.mock import AsyncMock, Mock, patch + +import aiohttp +import pytest +from pytest_mock import MockFixture + +from asyncord.client.http.middleware.auth import BotTokenAuthStrategy +from asyncord.gateway.client.client import ConnectionData, GatewayClient, GatewayCommandOpcode +from asyncord.gateway.client.errors import ConnectionClosedError +from asyncord.gateway.client.heartbeat import Heartbeat, HeartbeatFactory +from asyncord.gateway.commands import IdentifyCommand, PresenceUpdateData, ResumeCommand +from asyncord.gateway.dispatcher import EventDispatcher +from asyncord.gateway.intents import DEFAULT_INTENTS, Intent +from asyncord.gateway.message import ( + DatalessMessage, + DispatchMessage, + FallbackGatewayMessage, + GatewayMessageOpcode, + HelloMessage, + HelloMessageData, +) + + +@pytest.mark.parametrize('token', ['token', BotTokenAuthStrategy('token')]) +@pytest.mark.parametrize('session', [Mock()]) +@pytest.mark.parametrize('conn_data', [None, ConnectionData(token='token')]) # noqa: S106 +@pytest.mark.parametrize('intents', [DEFAULT_INTENTS, Intent.GUILDS]) +@pytest.mark.parametrize('heartbeat_class', [Heartbeat, HeartbeatFactory(), None]) +@pytest.mark.parametrize('dispatcher', [None, EventDispatcher()]) +@pytest.mark.parametrize('name', [None, 'TestClient']) +def test_init( # noqa: PLR0917, PLR0913 + token: BotTokenAuthStrategy | Literal['token'], + session: aiohttp.ClientSession, + conn_data: None | ConnectionData, + intents: Intent, + heartbeat_class: type[Heartbeat] | HeartbeatFactory | None, + dispatcher: None | EventDispatcher, + name: None | Literal['TestClient'], + mocker: MockFixture, +) -> None: + """Test initializing the GatewayClient. + + Init logic looks like overkill, but it's working fine at the moment. + I don't think that is a good idea to make separate tests for each parameter + for at the moment. + Candidate for refactoring. + """ + mocker.patch('asyncio.get_event_loop', return_value=Mock()) + if heartbeat_class: + client = GatewayClient( + token=token, + session=session, + conn_data=conn_data, + intents=intents, + heartbeat_class=heartbeat_class, + dispatcher=dispatcher, + name=name, + ) + else: + client = GatewayClient( + token=token, + session=session, + conn_data=conn_data, + intents=intents, + dispatcher=dispatcher, + name=name, + ) + + if isinstance(token, str): + str_token = token + else: + str_token = token.token + + assert client.session == session + assert client.conn_data == (conn_data or ConnectionData(token=str_token)) + assert client.intents == intents + assert isinstance(client.heartbeat, Heartbeat) + if dispatcher: + assert client.dispatcher is dispatcher + else: + assert isinstance(client.dispatcher, EventDispatcher) + assert not client.is_started + assert client.name == name + assert client._ws is None + assert not client._need_restart.is_set() + assert isinstance(client._opcode_handlers, MappingProxyType) + assert len(client._opcode_handlers) + + +async def test_connect_when_already_started(gw_client: GatewayClient) -> None: + """Test connecting when the client is already started.""" + gw_client.is_started = True + with pytest.raises(RuntimeError, match='Client is already started'): + await gw_client.connect() + + +async def test_connect_when_not_started(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test connecting when the client is not started.""" + gw_client.is_started = False + + mock_connect = mocker.patch.object(gw_client, '_connect', return_value=asyncio.Future()) + mock_connect.return_value.set_result(None) + + await gw_client.connect() + + assert gw_client.is_started + mock_connect.assert_called_once() + + +async def test_close_when_not_started_and_no_ws(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test closing when the client is not started and no websocket.""" + gw_client.is_started = False + gw_client._ws = None + mock_stop = mocker.patch.object(gw_client.heartbeat, 'stop') + await gw_client.close() + assert not gw_client.is_started + assert not gw_client._need_restart.is_set() + mock_stop.assert_not_called() + + +async def test_close_when_started_and_no_ws(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test closing when the client is started and no websocket.""" + gw_client.is_started = True + gw_client._ws = None + mock_stop = mocker.patch.object(gw_client.heartbeat, 'stop') + await gw_client.close() + assert not gw_client.is_started + assert gw_client._need_restart.is_set() + mock_stop.assert_called_once() + + +async def test_close_when_started_and_ws(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test closing when the client is started and websocket exists.""" + gw_client.is_started = True + gw_client._ws = Mock() + mock_stop = mocker.patch.object(gw_client.heartbeat, 'stop') + mock_close = mocker.patch.object(gw_client._ws, 'close', return_value=asyncio.Future()) + mock_close.return_value.set_result(None) + await gw_client.close() + assert not gw_client.is_started + assert gw_client._need_restart.is_set() + mock_stop.assert_called_once() + mock_close.assert_called_once() + + +async def test_send_command_no_ws(gw_client: GatewayClient) -> None: + """Test sending a command when the websocket is not connected.""" + gw_client._ws = None + with pytest.raises(RuntimeError, match='Client is not connected'): + await gw_client.send_command(GatewayCommandOpcode.HEARTBEAT, {}) + + +async def test_send_command_with_ws(gw_client: GatewayClient) -> None: + """Test sending a command when the websocket is connected.""" + mock_ws = AsyncMock() + gw_client._ws = mock_ws + opcode = GatewayCommandOpcode.HEARTBEAT + data = {'test': 'data'} + await gw_client.send_command(opcode, data) + mock_ws.send_json.assert_called_once_with({'op': opcode, 'd': data}) + + +async def test_reconnect_no_ws(gw_client: GatewayClient) -> None: + """Test reconnecting when the websocket is not connected.""" + gw_client._ws = None + with pytest.raises(RuntimeError, match='Client is not started'): + gw_client.reconnect() + + +async def test_reconnect_with_ws(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test reconnecting when the websocket is connected.""" + gw_client._ws = Mock() + mock_stop = mocker.patch.object(gw_client.heartbeat, 'stop') + gw_client.reconnect() + assert gw_client._need_restart.is_set() + mock_stop.assert_called_once() + + +async def test_identify(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test identifying with the gateway.""" + mock_send_command = mocker.patch.object(gw_client, 'send_command', return_value=asyncio.Future()) + mock_send_command.return_value.set_result(None) + command_data = IdentifyCommand( + token='token', # noqa: S106 + properties={}, # type: ignore + compress=False, + large_threshold=250, + ) + await gw_client.identify(command_data) + payload = command_data.model_dump(mode='json', exclude_none=True) + mock_send_command.assert_called_once_with(GatewayCommandOpcode.IDENTIFY, payload) + + +async def test_send_resume_no_ws(gw_client: GatewayClient) -> None: + """Test sending a resume command when the websocket is not connected.""" + gw_client._ws = None + command_data = ResumeCommand( + token='token', # noqa: S106 + session_id='session_id', + seq=1, + ) + with pytest.raises(RuntimeError, match='Client is not connected'): + await gw_client.send_resume(command_data) + + +async def test_send_resume_with_ws(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test sending a resume command when the websocket is connected.""" + mock_ws = AsyncMock() + gw_client._ws = mock_ws + command_data = ResumeCommand( + token='token', # noqa: S106 + session_id='session_id', + seq=1, + ) + mock_send_command = mocker.patch.object(gw_client, 'send_command', return_value=asyncio.Future()) + mock_send_command.return_value.set_result(None) + await gw_client.send_resume(command_data) + mock_send_command.assert_called_once_with(GatewayCommandOpcode.RESUME, command_data.model_dump(mode='json')) + + +async def test_update_presence(gw_client: GatewayClient) -> None: + """Test updating the client's presence.""" + mock_presence_data = Mock(spec=PresenceUpdateData) + # can't set model_dump attribute for pydantic models + mock_presence_data.model_dump = Mock(return_value={'test': 'data'}) + + with patch.object(gw_client, 'send_command', new_callable=AsyncMock) as mock_send_command: + await gw_client.update_presence(mock_presence_data) + + mock_send_command.assert_called_once_with(GatewayCommandOpcode.PRESENCE_UPDATE, {'test': 'data'}) + + +async def test_send_heartbeat_no_ws(gw_client: GatewayClient) -> None: + """Test send_heartbeat when the client is not connected.""" + gw_client._ws = None + with pytest.raises(RuntimeError, match='Client is not connected'): + await gw_client.send_heartbeat(1) + + +async def test_send_heartbeat_with_ws(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test send_heartbeat when the client is connected.""" + mock_ws = mocker.Mock() + mock_send_json = mocker.AsyncMock() + mock_ws.send_json = mock_send_json + gw_client._ws = mock_ws + seq = 1 + await gw_client.send_heartbeat(seq) + mock_send_json.assert_called_once_with({'op': GatewayCommandOpcode.HEARTBEAT, 'd': seq}) + + +async def test__connect_when_not_started(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test _connect when the client is not started.""" + gw_client.is_started = False + mock_ws_connect = mocker.patch.object(gw_client.session, 'ws_connect', new_callable=AsyncMock) + await gw_client._connect() + mock_ws_connect.assert_not_called() + + +async def test__connect_when_started_and_connection_closed_immediately( + gw_client: GatewayClient, + mocker: MockFixture, +) -> None: + """Test _connect when the client is started and the connection is closed immediately.""" + gw_client.is_started = True + mock_ws = AsyncMock() + mock_ws.__aenter__.return_value = mock_ws + mock_ws_connect = mocker.patch.object(gw_client.session, 'ws_connect', return_value=mock_ws) + + def _stop(_ws: object) -> None: + gw_client.is_started = False + + mocker.patch.object(gw_client, '_ws_recv_loop', side_effect=_stop) + await gw_client._connect() + mock_ws_connect.assert_called_once() + + +async def test__connect_when_started_and_connection_closed_with_error( + gw_client: GatewayClient, + mocker: MockFixture, +) -> None: + """Test _connect when the client is started and the connection is closed with an error.""" + gw_client.is_started = True + mock_ws = AsyncMock() + mock_ws.__aenter__.return_value = mock_ws + mock_ws_connect = mocker.patch.object(gw_client.session, 'ws_connect', return_value=mock_ws) + + first_raise = True + + def _raise(_ws: object) -> None: + nonlocal first_raise + if first_raise: + first_raise = False + raise ConnectionClosedError + + gw_client.is_started = False + raise ConnectionClosedError + + mock_ws_recv_loop = mocker.patch.object(gw_client, '_ws_recv_loop', side_effect=_raise) + + await gw_client._connect() + mock_ws_connect.assert_called() + mock_ws_recv_loop.assert_called() + + +async def test__handle_heartbeat_ack(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test handling the heartbeat acknowledgement.""" + mock_handle_heartbeat_ack = mocker.patch.object(gw_client.heartbeat, 'handle_heartbeat_ack', new_callable=AsyncMock) + await gw_client._handle_heartbeat_ack(Mock()) + mock_handle_heartbeat_ack.assert_called_once() + + +async def test__ws_recv_loop_not_started(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test _ws_recv_loop when the client is not started.""" + gw_client.is_started = False + mock_get_message = mocker.patch.object(gw_client, '_get_message', return_value=asyncio.Future()) + mock_get_message.return_value.set_result(None) + mock_handle_message = mocker.patch.object(gw_client, '_handle_message') + + await gw_client._ws_recv_loop(Mock()) + + mock_get_message.assert_not_called() + mock_handle_message.assert_not_called() + + +async def test__ws_recv_loop_need_restart(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test _ws_recv_loop when _need_restart is set.""" + gw_client.is_started = True + gw_client._need_restart.set() + mock_get_message = mocker.patch.object(gw_client, '_get_message', return_value=asyncio.Future()) + mock_get_message.return_value.set_result(None) + mock_handle_message = mocker.patch.object(gw_client, '_handle_message') + + await gw_client._ws_recv_loop(Mock()) + + mock_get_message.assert_not_called() + mock_handle_message.assert_not_called() + + +async def test__ws_recv_loop_when_not_started(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test _ws_recv_loop when the client is not started.""" + mocker.patch.object(gw_client, '_get_message', new_callable=AsyncMock) + mocker.patch.object(gw_client, '_handle_message', new_callable=AsyncMock) + gw_client.is_started = False + + await gw_client._ws_recv_loop(Mock()) + + gw_client._get_message.assert_not_called() # type: ignore + gw_client._handle_message.assert_not_called() # type: ignore + + +async def test__ws_recv_loop_when_need_restart(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test _ws_recv_loop when the client needs to restart.""" + mocker.patch.object(gw_client, '_get_message', new_callable=AsyncMock) + mocker.patch.object(gw_client, '_handle_message', new_callable=AsyncMock) + gw_client.is_started = True + gw_client._need_restart.set() + + await gw_client._ws_recv_loop(Mock()) + + gw_client._get_message.assert_not_called() # type: ignore + gw_client._handle_message.assert_not_called() # type: ignore + + +async def test_websocket_receive_loop_message_waiting(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test _ws_recv_loop when a message is received.""" + + async def _get_message_mock(_ws_resp: object) -> FallbackGatewayMessage: + # Add a delay before returning a message + # This will ensure that the loop continues to the message branch + await asyncio.sleep(1) + return FallbackGatewayMessage(op=100) # type: ignore + + mocker.patch.object(gw_client, '_get_message', new=_get_message_mock) + + async def _mock_handler_message(_message: object) -> None: + gw_client.is_started = False + + mocker.patch.object(gw_client, '_handle_message', new=_mock_handler_message) + gw_client.is_started = True + + await gw_client._ws_recv_loop(Mock()) + assert not gw_client._need_restart.is_set() + + +async def test_ws_recv_loop_restart_waiting(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test _ws_recv_loop when waiting for restart.""" + # dummy message which will not be completed + message_future = asyncio.Future() + + async def _coro(_ws_resp: object) -> None: + await message_future + + mocker.patch.object(gw_client, '_get_message', new=_coro) + + # Mock the handler to stop the client + async def _mock_handler_message(_message: object) -> None: + gw_client.is_started = False + + mocker.patch.object(gw_client, '_handle_message', new=_mock_handler_message) + gw_client.is_started = True + + async def _set_need_restart() -> None: + # Add a delay before setting the need_restart flag + # This will ensure that the loop continues to the restart branch + await asyncio.sleep(1) + return gw_client._need_restart.set() + + # After not so long, the need_restart flag will be set + # This will continue the loop to the restart branch + await asyncio.gather( + gw_client._ws_recv_loop(Mock()), + _set_need_restart(), + ) + + assert gw_client._need_restart.is_set() + assert message_future.cancelled() + + +async def test__get_message_text(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test _get_message when the message type is TEXT.""" + ws = AsyncMock() + message = Mock() + message.type = aiohttp.WSMsgType.TEXT + message.json.return_value = {'op': 43214, 'd': {'foo': 'bar'}} + ws.receive.return_value = message + + result = await gw_client._get_message(ws) + + assert result == FallbackGatewayMessage(op=43214, d={'foo': 'bar'}) # type: ignore + + +@pytest.mark.parametrize('msg_type', [aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED]) +async def test__get_message_close_types( + gw_client: GatewayClient, + msg_type: aiohttp.WSMsgType, +) -> None: + """Test _get_message when the message type is CLOSE, CLOSING, or CLOSED.""" + ws = AsyncMock() + message = Mock() + message.type = msg_type + ws.receive.return_value = message + + with pytest.raises(ConnectionClosedError): + await gw_client._get_message(ws) + + +async def test__get_message_unhandled_type(gw_client: GatewayClient) -> None: + """Test _get_message when the message type is not handled.""" + ws = AsyncMock() + message = Mock() + message.type = aiohttp.WSMsgType.BINARY # An unhandled type + ws.receive.return_value = message + + result = await gw_client._get_message(ws) + + assert result is None + + +async def test__handle_message_dispatch(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test handling a dispatch event.""" + mock_logger = mocker.patch.object(gw_client, 'logger') + message = DispatchMessage(op=GatewayMessageOpcode.DISPATCH, t='TestEvent', d={}, s=1) + await gw_client._handle_message(message) + mock_logger.info.assert_called_once_with('Dispatching event: %s', message.event_name) + + +async def test__handle_message_non_dispatch(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test handling a non-dispatch event.""" + mock_logger = mocker.patch.object(gw_client, 'logger') + gw_client._ws = AsyncMock() + message = HelloMessage( + op=GatewayMessageOpcode.HELLO, + d=HelloMessageData(heartbeat_interval=1000), + ) + await gw_client._handle_message(message) + mock_logger.info.assert_called_once_with('Received message: %s', message.opcode.name) + + +async def test__handle_message_with_opcode_handler(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test handling a message with an opcode handler.""" + opcode = GatewayMessageOpcode.HEARTBEAT_ACK + handler = gw_client._opcode_handlers[opcode] + mock_handler = mocker.patch.object(handler, 'handle') + message = DatalessMessage(op=opcode) + await gw_client._handle_message(message) + mock_handler.assert_called_once_with(message) + + +async def test__handle_message_without_opcode_handler(gw_client: GatewayClient, mocker: MockFixture) -> None: + """Test handling a message without an opcode handler.""" + mock_logger = mocker.patch.object(gw_client, 'logger') + + opcode = GatewayMessageOpcode(9999) # Non-existent opcode + assert opcode.name == 'UNKNOWN' + + message = FallbackGatewayMessage(op=opcode, data={}) + await gw_client._handle_message(message) + mock_logger.warning.assert_called_once_with('Unhandled opcode: %s', opcode) diff --git a/tests/gateway/test_connection_data.py b/tests/gateway/test_connection_data.py new file mode 100644 index 0000000..e003f58 --- /dev/null +++ b/tests/gateway/test_connection_data.py @@ -0,0 +1,35 @@ +import pytest +from yarl import URL + +from asyncord.gateway.client.client import ConnectionData +from asyncord.urls import GATEWAY_URL + + +@pytest.mark.parametrize('token', ['token', '']) +@pytest.mark.parametrize('resume_url', ['ws://localhost', '']) +@pytest.mark.parametrize('session_id', ['session_id', None]) +@pytest.mark.parametrize('seq', [1, 0]) +def test_can_resume(token: str, resume_url: str, session_id: str | None, seq: int) -> None: + """Test checking if the connection can be resumed.""" + conn_data = ConnectionData( + token=token, + resume_url=URL(resume_url), + session_id=session_id, + seq=seq, + ) + assert conn_data.can_resume is bool(resume_url and session_id and seq) + + +def test_reset() -> None: + """Test resetting connection data.""" + conn_data = ConnectionData( + token='token', # noqa: S106 + resume_url=URL('ws://localhost'), + session_id='session_id', + seq=1, + ) + conn_data.reset() + + assert conn_data.resume_url == GATEWAY_URL + assert conn_data.session_id is None + assert conn_data.seq == 0 diff --git a/tests/gateway/test_dispatcher.py b/tests/gateway/test_dispatcher.py index ae251e1..1ff792d 100644 --- a/tests/gateway/test_dispatcher.py +++ b/tests/gateway/test_dispatcher.py @@ -1,3 +1,4 @@ +import logging from unittest import mock import pytest @@ -74,7 +75,7 @@ async def handler(_: mock.Mock()) -> None: # type: ignore pass with pytest.raises(TypeError): - dispatcher.add_handler(str, handler) + dispatcher.add_handler(str, handler) # type: ignore def test_add_argument(dispatcher: EventDispatcher) -> None: @@ -87,7 +88,7 @@ def test_add_argument(dispatcher: EventDispatcher) -> None: async def handler(_: CustomEvent, arg1: str) -> None: pass - dispatcher.add_handler(GatewayEvent, handler) + dispatcher.add_handler(GatewayEvent, handler) # type: ignore assert 'arg1' not in dispatcher._args assert 'arg1' not in dispatcher._cached_args[handler] @@ -117,12 +118,74 @@ async def test_dispatch(dispatcher: EventDispatcher) -> None: handler.assert_called_once_with(event, arg1='value1') +async def test_dispatch_calls_correct_handler(dispatcher: EventDispatcher) -> None: + """Test that dispatch calls the correct handler for a given event.""" + event = CustomEvent() + handler_called = False + + async def handler(event: CustomEvent) -> None: + nonlocal handler_called + handler_called = True + + dispatcher.add_handler(CustomEvent, handler) + await dispatcher.dispatch(event) + + assert handler_called, 'Handler was not called' + + +async def test_dispatch_does_not_call_incorrect_handler(dispatcher: EventDispatcher) -> None: + """Test that dispatch does not call handlers for different events.""" + handler_called = False + + async def handler(event: CustomEvent2) -> None: + nonlocal handler_called + handler_called = True + + dispatcher.add_handler(CustomEvent2, handler) + await dispatcher.dispatch(CustomEvent()) + + assert not handler_called, 'Handler was incorrectly called' + + +async def test_dispatch_passes_arguments_to_handler(dispatcher: EventDispatcher) -> None: + """Test that dispatch passes the correct arguments to the handler.""" + event = CustomEvent() + received_args = None + + async def handler(event: CustomEvent, arg1: str) -> None: + nonlocal received_args + received_args = arg1 + + dispatcher.add_handler(CustomEvent, handler) + dispatcher.add_argument('arg1', 'value1') + await dispatcher.dispatch(event) + + assert received_args == 'value1', 'Handler did not receive correct arguments' + + +async def test_dispatch_logs_exception( + dispatcher: EventDispatcher, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that dispatch logs an exception if one is raised in the handler.""" + event = CustomEvent() + + async def handler(event: CustomEvent) -> None: + raise Exception('Test exception') + + dispatcher.add_handler(CustomEvent, handler) + with caplog.at_level(logging.ERROR): + await dispatcher.dispatch(event) + + assert 'Unhandled exception in event handler' in caplog.text, 'Exception was not logged' + + async def test_dispatch_with_no_handlers(dispatcher: EventDispatcher) -> None: """Test dispatching an event with no handlers. It should not raise any errors. """ - await dispatcher.dispatch(CustomEvent) + await dispatcher.dispatch(CustomEvent()) async def test_dispatch_with_multiple_handlers(dispatcher: EventDispatcher) -> None: diff --git a/tests/gateway/test_heartbeat.py b/tests/gateway/test_heartbeat.py new file mode 100644 index 0000000..0359881 --- /dev/null +++ b/tests/gateway/test_heartbeat.py @@ -0,0 +1,203 @@ +import datetime +import logging +from collections.abc import AsyncGenerator +from unittest.mock import Mock + +import pytest +from pytest_mock import MockerFixture + +from asyncord.gateway.client.client import ConnectionData, GatewayClient +from asyncord.gateway.client.heartbeat import Heartbeat + + +@pytest.fixture() +def gw_client() -> GatewayClient: + """Return a mock client.""" + return Mock(spec=GatewayClient) + + +@pytest.fixture() +def conn_data() -> ConnectionData: + """Return a mock connection data.""" + return Mock(spec=ConnectionData) + + +@pytest.fixture() +async def heartbeat( + gw_client: GatewayClient, + conn_data: ConnectionData, +) -> AsyncGenerator[Heartbeat, None]: + """Return a mock heartbeat.""" + heratbeat = Heartbeat(client=gw_client, conn_data=conn_data) + yield heratbeat + heratbeat.stop() + + +def test_heartbeat_init( + heartbeat: Heartbeat, + gw_client: GatewayClient, + conn_data: ConnectionData, +) -> None: + """Test initializing the heartbeat.""" + assert heartbeat.client is gw_client + assert heartbeat.conn_data is conn_data + assert heartbeat._interval.total_seconds() == 0 + assert heartbeat._task is None + + +async def test_handle_heartbeat_ack( + heartbeat: Heartbeat, +) -> None: + """Test handling a heartbeat ack.""" + await heartbeat.handle_heartbeat_ack() + assert heartbeat._ack_event.is_set() + + +def test_run_stop_cycle( + heartbeat: Heartbeat, +) -> None: + """Test running and stopping the heartbeat.""" + heartbeat.run(1000) + assert heartbeat._interval.total_seconds() == 1 + assert heartbeat._task is not None + + heartbeat.stop() + assert heartbeat._task is None + assert not heartbeat._ack_event.is_set() + assert heartbeat._interval.total_seconds() == 0 + + +def test_is_running( + heartbeat: Heartbeat, +) -> None: + """Test is_running property.""" + assert not heartbeat.is_running + + heartbeat.run(1000) + assert heartbeat.is_running + + heartbeat.stop() + assert not heartbeat.is_running + + +def test_jittered_sleep_duration( + heartbeat: Heartbeat, +) -> None: + """Test jittered sleep duration.""" + interval = 5 + heartbeat._interval = datetime.timedelta(seconds=interval) + + sleep_duration = heartbeat._jittered_sleep_duration + assert isinstance(sleep_duration, float) + assert interval * 0.3 <= sleep_duration <= interval * 0.9 + + +async def test_run( + heartbeat: Heartbeat, + mocker: MockerFixture, +) -> None: + """Test the _run method.""" + heartbeat._task = Mock() + + def _stop_loop() -> None: + heartbeat._task = None + + mock_wait_heartbeat_ack = mocker.patch.object(heartbeat, '_wait_heartbeat_ack', side_effect=_stop_loop) + mocker.patch('asyncio.sleep', return_value=None) + + await heartbeat._run( + interval=datetime.timedelta(seconds=1), + ) + + mock_wait_heartbeat_ack.assert_called_once() + + +async def test_run_with_exception( + heartbeat: Heartbeat, + mocker: MockerFixture, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the _run method when an exception is raised.""" + heartbeat._task = Mock() + mock_wait_heartbeat_ack = mocker.patch.object( + heartbeat, + '_wait_heartbeat_ack', + side_effect=Exception('Some exception occurred'), + ) + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + + interval = datetime.timedelta(seconds=1) + with caplog.at_level(logging.ERROR): + await heartbeat._run(interval) + + mock_wait_heartbeat_ack.assert_called_once() + mock_sleep.assert_called_once() + + assert 'An unexpected error occurred: Some exception occurred' in caplog.text + + +async def test_run_with_timeout( + heartbeat: Heartbeat, + mocker: MockerFixture, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the _run method when a timeout occurs.""" + heartbeat._task = Mock() + mock_wait_heartbeat_ack = mocker.patch.object( + heartbeat, + '_wait_heartbeat_ack', + side_effect=TimeoutError(), + ) + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + + interval = datetime.timedelta(seconds=1) + with caplog.at_level(logging.ERROR): + await heartbeat._run(interval) + + mock_wait_heartbeat_ack.assert_called_once() + mock_sleep.assert_called_once() + + assert 'Heartbeat ack not received in time. Reconnecting...' in caplog.text + + +async def test_wait_heartbeat_ack_received( + heartbeat: Heartbeat, + conn_data: ConnectionData, + mocker: MockerFixture, +) -> None: + """Test the _wait_heartbeat_ack method when the ack is received.""" + conn_data.seq = 1 + heartbeat._ack_event = Mock() # suppress the RuntimeWarning + + mock_send_heartbeat = mocker.patch.object(heartbeat.client, 'send_heartbeat', return_value=None) + + # if the ack is received, the method should return without raising an error + # this mock emilates the ack being received + mock_wait_for = mocker.patch('asyncio.wait_for', return_value=None) + + await heartbeat._wait_heartbeat_ack() + + mock_send_heartbeat.assert_called_once() + mock_wait_for.assert_called_once() + + +async def test_wait_heartbeat_ack_no_ack_received( + heartbeat: Heartbeat, + conn_data: ConnectionData, + mocker: MockerFixture, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the _wait_heartbeat_ack method when no ack is received.""" + conn_data.seq = 1 + heartbeat._ack_event = Mock() # suppress the RuntimeWarning + + mock_send_heartbeat = mocker.patch.object(heartbeat.client, 'send_heartbeat', return_value=None) + mock_wait_for = mocker.patch('asyncio.wait_for', side_effect=TimeoutError()) + + with caplog.at_level(logging.ERROR): + await heartbeat._wait_heartbeat_ack() + + assert mock_send_heartbeat.call_count == 100 + assert mock_wait_for.call_count == 100 + + assert 'ack not received after 100 attempts' in caplog.text diff --git a/tests/gateway/test_heartbeat_factory.py b/tests/gateway/test_heartbeat_factory.py new file mode 100644 index 0000000..30e345b --- /dev/null +++ b/tests/gateway/test_heartbeat_factory.py @@ -0,0 +1,89 @@ +import asyncio +import threading +from collections.abc import AsyncGenerator +from unittest.mock import Mock + +import pytest +from pytest_mock import MockFixture + +from asyncord.gateway.client.client import ConnectionData, GatewayClient +from asyncord.gateway.client.heartbeat import HeartbeatFactory + + +@pytest.fixture() +def gw_client() -> GatewayClient: + """Return a mock client.""" + return Mock(spec=GatewayClient) + + +@pytest.fixture() +def conn_data() -> ConnectionData: + """Return a mock connection data.""" + return Mock(spec=ConnectionData) + + +@pytest.fixture() +async def factory() -> AsyncGenerator[HeartbeatFactory, None]: + """Return a heartbeat factory.""" + factory = HeartbeatFactory() + yield factory + factory.stop() + + +def test_heartbeat_factory_initialization(factory: HeartbeatFactory) -> None: + """Test that the heartbeat factory initializes correctly.""" + assert isinstance(factory.loop, asyncio.AbstractEventLoop) + assert isinstance(factory.thread, threading.Thread) + + +def test_heartbeat_factory_create(factory: HeartbeatFactory) -> None: + """Test that the heartbeat factory creates a heartbeat correctly.""" + client = Mock(spec=GatewayClient) + conn_data = Mock(spec=ConnectionData) + heartbeat = factory.create(client, conn_data) + + assert heartbeat.client is client + assert heartbeat.conn_data is conn_data + assert heartbeat._loop is factory.loop + + +def test_heartbeat_factory_cycles(factory: HeartbeatFactory) -> None: + """Test that the heartbeat process cycles.""" + factory.start() + assert factory.is_running + assert factory.thread.is_alive() + + factory.stop() + assert not factory.is_running + assert not factory.thread.is_alive() + + +def test_multiple_heartbeats_loop_sharing(factory: HeartbeatFactory) -> None: + """Test that multiple heartbeats share the same thread.""" + session = Mock(spec=GatewayClient) + conn_data = Mock(spec=ConnectionData) + heartbeat1 = factory.create(session, conn_data) + heartbeat2 = factory.create(session, conn_data) + assert heartbeat1._loop is heartbeat2._loop + + +def test_heartbeat_continues_after_one_stops(factory: HeartbeatFactory, mocker: MockFixture) -> None: + """Test that a heartbeat continues after another stops.""" + mock_run_coroutine = mocker.patch('asyncord.gateway.client.heartbeat.asyncio.run_coroutine_threadsafe') + mocker.patch('asyncord.gateway.client.heartbeat.Heartbeat._run', new=Mock()) + session = Mock(spec=GatewayClient) + conn_data = Mock(spec=ConnectionData) + heartbeat1 = factory.create(session, conn_data) + heartbeat2 = factory.create(session, conn_data) + + heartbeat1.run(10) + heartbeat2.run(10) + factory.start() + + try: + heartbeat1.stop() + assert not heartbeat1.is_running + assert heartbeat2.is_running + assert mock_run_coroutine.call_count == 2 + finally: + factory.stop() diff --git a/tests/gateway/test_opcode_handlers.py b/tests/gateway/test_opcode_handlers.py new file mode 100644 index 0000000..8f8ee45 --- /dev/null +++ b/tests/gateway/test_opcode_handlers.py @@ -0,0 +1,152 @@ +import logging +from unittest.mock import AsyncMock, Mock + +import pytest + +from asyncord.gateway.client.client import ConnectionData +from asyncord.gateway.client.opcode_handlers import ( + DispatchHandler, + HeartbeatAckHandler, + HelloHandler, + InvalidSessionHandler, + ReconnectHandler, +) +from asyncord.gateway.commands import IdentifyCommand, ResumeCommand +from asyncord.gateway.events.base import ReadyEvent +from asyncord.gateway.message import DispatchMessage + + +@pytest.fixture() +def client() -> Mock: + """Return a mock client.""" + client = AsyncMock() + client.conn_data = ConnectionData( + token='token', # noqa: S106 + seq=0, + session_id='session_id', + ) + client.reconnect = Mock() + return client + + +async def test_dispatch_unhandled_event( + client: Mock, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test dispatching an unhandled event.""" + handler = DispatchHandler(client, logging.getLogger('asyncord.gateway.client.opcode_handlers')) + message = DispatchMessage(t='unhandled_event', d={}, s=1) # type: ignore + with caplog.at_level(logging.WARNING): + await handler.handle(message) + assert 'Unhandled event: unhandled_event' in caplog.text + + +async def test_dispatch_ready_event(client: Mock) -> None: + """Test dispatching a ready event.""" + handler = DispatchHandler(client, logging.getLogger('asyncord.gateway.client.opcode_handlers')) + event_data = { + 'v': 9, + 'user': { + 'id': '1234567890', + 'username': 'example_user', + 'global_name': 'example_user#1234', + 'discriminator': '1234', + 'avatar': 'example_avatar', + }, + 'guilds': [ + { + 'id': '123', + 'unavailable': True, + }, + { + 'id': '123', + 'unavailable': False, + }, + ], + 'session_id': 'example_session_id', + 'resume_gateway_url': 'example_gateway_url', + 'shard': { + 'shard_id': 0, + 'num_shards': 1, + }, + 'application': { + 'id': '1234567890', + 'flags': 0, + }, + } + + message = DispatchMessage(t=ReadyEvent.__event_name__, d=event_data, s=1) # type: ignore + await handler.handle(message) + assert client.conn_data.seq == 1 + assert client.conn_data.session_id == 'example_session_id' + client.dispatcher.dispatch.assert_called_once() + + +async def test_reconnect_handler_handle(client: Mock) -> None: + """Test handling the RECONNECT opcode.""" + handler = ReconnectHandler(client, Mock()) + await handler.handle(Mock()) + client.reconnect.assert_called_once() + + +async def test_invalid_session_handler_handle(client: Mock) -> None: + """Test handling the INVALID_SESSION opcode.""" + client.conn_data = Mock() + handler = InvalidSessionHandler(client, Mock()) + message = Mock(data=True) + await handler.handle(message) + client.reconnect.assert_called_once() + client.conn_data.reset.assert_not_called() + + +async def test_invalid_session_handler_handle_no_data(client: Mock) -> None: + """Test handling the INVALID_SESSION opcode.""" + client.conn_data = Mock() + handler = InvalidSessionHandler(client, Mock()) + message = Mock(data=False) + await handler.handle(message) + client.reconnect.assert_called_once() + client.conn_data.reset.assert_called_once() + + +async def test_hello_handler_handle_can_resume(client: Mock) -> None: + """Test handling the HELLO opcode.""" + client.heartbeat = Mock() + client.hearbeat.run = AsyncMock() + client.conn_data.seq = 1 + handler = HelloHandler(client, Mock()) + + await handler.handle(Mock()) + client.heartbeat.run.assert_called_once() + client.send_resume.assert_called_once_with( + ResumeCommand( + token=client.conn_data.token, + session_id=client.conn_data.session_id, + seq=client.conn_data.seq, + ), + ) + client.identify.assert_not_called() + + +async def test_hello_handler_handle_cannot_resume(client: Mock) -> None: + """Test handling the HELLO opcode.""" + client.heartbeat = Mock() + client.hearbeat.run = AsyncMock() + handler = HelloHandler(client, Mock()) + + await handler.handle(Mock()) + client.heartbeat.run.assert_called_once() + client.send_resume.assert_not_called() + client.identify.assert_called_once_with( + IdentifyCommand( + token=client.conn_data.token, + intents=client.intents, + ), + ) + + +async def test_heartbeat_ack_handler_handle(client: Mock) -> None: + """Test handling the HEARTBEAT_ACK opcode.""" + handler = HeartbeatAckHandler(client, Mock()) + await handler.handle(Mock()) + client.heartbeat.handle_heartbeat_ack.assert_called_once() diff --git a/tests/integration/client/componenets/test_action_row.py b/tests/integration/client/componenets/test_action_row.py new file mode 100644 index 0000000..2f0bb09 --- /dev/null +++ b/tests/integration/client/componenets/test_action_row.py @@ -0,0 +1,60 @@ +from collections.abc import Sequence + +import pytest + +from asyncord.client.messages.models.requests.components import ( + ActionRow, + Button, + ButtonStyle, + SelectMenu, + SelectMenuOption, +) + + +def test_wrap_component_to_list_in_action_row() -> None: + """Test that components are wrapped in an ActionRow.""" + request = ActionRow([Button(custom_id='button_0', label='Button', style=ButtonStyle.PRIMARY)]) + + assert isinstance(request.components, Sequence) + assert len(request.components) == 1 + assert isinstance(request.components[0], Button) + assert request.components[0].custom_id == 'button_0' + + +def test_components_cannot_be_empty() -> None: + """Test that components cannot be empty.""" + with pytest.raises(ValueError, match='Value should have at least 1 item'): + ActionRow([]) + + +def test_action_row_can_have_max_5_components() -> None: + """Test that an ActionRow can have a maximum of 5 components.""" + # fmt: on + with pytest.raises(ValueError, match='Value should have at most 5 items'): + # fmt: off + ActionRow([ + Button( + custom_id=f'button_{i}', + label=f'Button {i}', + style=ButtonStyle.PRIMARY, + ) + for i in range(6) + ]) + # fmt: on + + +def test_dont_create_message_with_button_and_select_menu() -> None: + """Test that an ActionRow containing a select menu cannot also contain buttons.""" + exc_text = 'ActionRow containing a select menu cannot also contain buttons' + with pytest.raises(ValueError, match=exc_text): + ActionRow( + components=[ + Button(custom_id='custom_id', label='Button'), + SelectMenu( + custom_id='custom', + options=[ + SelectMenuOption(label='Option 1', value='option_1'), + ], + ), + ], + ) diff --git a/tests/integration/client/messages/test_components.py b/tests/integration/client/componenets/test_creation.py similarity index 64% rename from tests/integration/client/messages/test_components.py rename to tests/integration/client/componenets/test_creation.py index 92f6bab..ae1c5f0 100644 --- a/tests/integration/client/messages/test_components.py +++ b/tests/integration/client/componenets/test_creation.py @@ -11,18 +11,53 @@ ActionRow, Button, ButtonStyle, - Component, - SelectMenu, - SelectMenuOption, + ComponentType, ) from asyncord.client.messages.models.requests.messages import CreateMessageRequest, UpdateMessageRequest from asyncord.client.messages.resources import MessageResource from asyncord.client.threads.models.requests import ThreadMessage +class _Container(Protocol): + components: Sequence[ComponentType] | ComponentType | None + + def __call__(self, components: Sequence[ComponentType] | ComponentType | None) -> Self: # type: ignore + """Initialize the container with components.""" + + +@pytest.mark.parametrize( + 'container', + [ + CreateMessageRequest, + UpdateMessageRequest, + InteractionRespMessageRequest, + InteractionRespUpdateMessageRequest, + ThreadMessage, + ], +) +def test_wrap_component_to_list_and_action_row(container: _Container) -> None: + """Test that components are wrapped in an ActionRow.""" + request = container( + components=Button( + custom_id='button_0', + label='Button', + style=ButtonStyle.PRIMARY, + ), + ) + + assert isinstance(request.components, Sequence) + assert len(request.components) == 1 + assert isinstance(request.components[0], ActionRow) + + assert isinstance(request.components[0].components, Sequence) + assert len(request.components[0].components) == 1 + assert not isinstance(request.components[0].components[0], ActionRow) + assert request.components[0].components[0].custom_id == 'button_0' + + async def test_create_message_with_buttons(messages_res: MessageResource) -> None: """Test creating a message with buttons.""" - components: Sequence[Component] = [ + components: Sequence[ComponentType] = [ ActionRow( components=[ Button( @@ -49,7 +84,7 @@ async def test_create_message_with_buttons(messages_res: MessageResource) -> Non label='Link', style=ButtonStyle.LINK, url='https://discord.com', - ), # type: ignore + ), ], ), ] @@ -68,37 +103,26 @@ async def test_create_message_with_buttons(messages_res: MessageResource) -> Non await messages_res.delete(message.id) -def test_dont_create_message_with_button_and_select_menu() -> None: - """Test that an ActionRow containing a select menu cannot also contain buttons.""" - exc_text = 'ActionRow containing a select menu cannot also contain buttons' - with pytest.raises(ValueError, match=exc_text): - ActionRow( - components=[ - Button(custom_id='custom_id', label='Button'), - SelectMenu( - custom_id='custom', - options=[SelectMenuOption(label='Option 1', value='option_1')], # type: ignore - ), - ], - ) - - def test_components_can_be_max_5() -> None: """Test that components can be a maximum of 5.""" # fmt: off components = [ - ActionRow(components=[ - Button(custom_id=f'button_{i}', label=f'Button {i}', style=ButtonStyle.PRIMARY), + ActionRow([ + Button( + custom_id=f'button_{i}', + label=f'Button {i}', + style=ButtonStyle.PRIMARY, + ), ]) for i in range(6) ] # fmt: on with pytest.raises(ValueError, match='Components must have 5 or fewer action rows'): - CreateMessageRequest(components=components) # type: ignore + CreateMessageRequest(components=components) def test_wrap_components_in_action_row() -> None: - """Test that components are wrapped in an ActionRow.""" + """Test that components are wrapped in an ActionRow implicitly.""" # fmt: off components = [ Button(custom_id=f'button_{i}', label=f'Button {i}', style=ButtonStyle.PRIMARY) @@ -106,7 +130,7 @@ def test_wrap_components_in_action_row() -> None: ] # fmt: on - request = CreateMessageRequest(components=components) # type: ignore + request = CreateMessageRequest(components=components) assert isinstance(request.components, Sequence) assert len(request.components) == 1 @@ -119,71 +143,3 @@ def test_wrap_components_in_action_row() -> None: assert request.components[0].components[0].custom_id == 'button_0' assert isinstance(request.components[0].components[4], Button) assert request.components[0].components[4].custom_id == 'button_4' - - -class _Container(Protocol): - components: Sequence[Component] | Component | None - - def __call__(self, components: Sequence[Component] | Component | None) -> Self: - """Initialize the container with components.""" - ... - - -@pytest.mark.parametrize( - 'container', - [ - CreateMessageRequest, - UpdateMessageRequest, - InteractionRespMessageRequest, - InteractionRespUpdateMessageRequest, - ThreadMessage, - ], -) -def test_wrap_component_to_list_and_action_row(container: _Container) -> None: - """Test that components are wrapped in an ActionRow.""" - request = container( - components=Button( - custom_id='button_0', - label='Button', - style=ButtonStyle.PRIMARY, - ), - ) - - assert isinstance(request.components, Sequence) - assert len(request.components) == 1 - assert isinstance(request.components[0], ActionRow) - - assert isinstance(request.components[0].components, Sequence) - assert len(request.components[0].components) == 1 - assert not isinstance(request.components[0].components[0], ActionRow) - assert request.components[0].components[0].custom_id == 'button_0' - - -def test_wrap_component_to_list_in_action_row() -> None: - """Test that components are wrapped in an ActionRow.""" - request = ActionRow(components=[Button(custom_id='button_0', label='Button', style=ButtonStyle.PRIMARY)]) - - assert isinstance(request.components, Sequence) - assert len(request.components) == 1 - assert isinstance(request.components[0], Button) - assert request.components[0].custom_id == 'button_0' - - -def test_components_cannot_be_empty() -> None: - """Test that components cannot be empty.""" - with pytest.raises(ValueError, match='Value should have at least 1 item'): - ActionRow(components=[]) - - -def test_action_row_can_have_max_5_components() -> None: - """Test that an ActionRow can have a maximum of 5 components.""" - # fmt: on - with pytest.raises(ValueError, match='Value should have at most 5 items'): - # fmt: off - ActionRow( - components=[ - Button(custom_id=f'button_{i}', label=f'Button {i}', style=ButtonStyle.PRIMARY) - for i in range(6) - ], - ) - # fmt: on diff --git a/tests/integration/client/componenets/test_emoji.py b/tests/integration/client/componenets/test_emoji.py new file mode 100644 index 0000000..10ca92a --- /dev/null +++ b/tests/integration/client/componenets/test_emoji.py @@ -0,0 +1,17 @@ +import pytest + +from asyncord.client.messages.models.requests.components import ( + ComponentEmoji, +) + + +def test_emoji_fail_with_both_name_and_id() -> None: + """Test that an emoji can contain a name or an id.""" + with pytest.raises(ValueError, match='Only one of'): + ComponentEmoji(name='emoji_name', id=1241) + + +def test_emoji_fail_with_no_name_and_id() -> None: + """Test that an emoji must contain a name or an id.""" + with pytest.raises(ValueError, match='least one of'): + ComponentEmoji() diff --git a/tests/integration/client/conftest.py b/tests/integration/client/conftest.py new file mode 100644 index 0000000..1c2c443 --- /dev/null +++ b/tests/integration/client/conftest.py @@ -0,0 +1,32 @@ +import datetime +from collections.abc import AsyncGenerator + +import pytest + +from asyncord.client.scheduled_events.models.common import ( + EventEntityType, + EventPrivacyLevel, +) +from asyncord.client.scheduled_events.models.requests import ( + CreateScheduledEventRequest, + EventEntityMetadata, +) +from asyncord.client.scheduled_events.models.responses import ScheduledEventResponse +from asyncord.client.scheduled_events.resources import ScheduledEventsResource + + +@pytest.fixture() +async def event(events_res: ScheduledEventsResource) -> AsyncGenerator[ScheduledEventResponse, None]: + """Fixture that creates a scheduled event and deletes it after the test.""" + creation_data = CreateScheduledEventRequest( + entity_type=EventEntityType.EXTERNAL, + name='Test Event', + description='This is a test event.', + entity_metadata=EventEntityMetadata(location='https://example.com'), + privacy_level=EventPrivacyLevel.GUILD_ONLY, + scheduled_start_time=datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=1), + scheduled_end_time=datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), + ) + event = await events_res.create(creation_data) + yield event + await events_res.delete(event.id) diff --git a/tests/integration/client/messages/test_attachments.py b/tests/integration/client/messages/test_attachments.py index d594b8c..a5b61ea 100644 --- a/tests/integration/client/messages/test_attachments.py +++ b/tests/integration/client/messages/test_attachments.py @@ -10,12 +10,12 @@ EmbedThumbnail, ) from asyncord.client.messages.models.requests.messages import ( - Attachment, CreateMessageRequest, UpdateMessageRequest, ) from asyncord.client.messages.models.responses.messages import MessageResponse from asyncord.client.messages.resources import MessageResource +from asyncord.client.models.attachments import Attachment TEST_FILE_NAMES = ['test_image_1.png', 'test_image_2.png'] diff --git a/tests/integration/client/messages/test_messages.py b/tests/integration/client/messages/test_messages.py index 4173a03..91e8298 100644 --- a/tests/integration/client/messages/test_messages.py +++ b/tests/integration/client/messages/test_messages.py @@ -1,49 +1,121 @@ +import contextlib +import itertools import random import string -from typing import Literal +import warnings +from collections.abc import AsyncGenerator import pytest +from asyncord.client.channels.models.responses import ChannelResponse +from asyncord.client.channels.resources import ChannelResource +from asyncord.client.http.errors import NotFoundError from asyncord.client.messages.models.requests.messages import CreateMessageRequest, UpdateMessageRequest from asyncord.client.messages.models.responses.messages import MessageResponse from asyncord.client.messages.resources import MessageResource -from asyncord.snowflake import SnowflakeInputType from tests.conftest import IntegrationTestData +@pytest.fixture() +async def multiple_messages(messages_res: MessageResource) -> AsyncGenerator[list[MessageResponse], None]: + """Create multiple messages for testing bulk delete. + + In general, all messages must be deleted after the test is done. But we check it + in the teardown fixture to make sure that the messages are deleted even if the test fails. + """ + messages = [] + for _ in range(3): + message = await messages_res.create(CreateMessageRequest(content='test')) + messages.append(message) + + yield messages + + for message in messages: + try: + with contextlib.suppress(NotFoundError) as err: + await messages_res.delete(message.id) + except Exception as err: + warnings.warn(f'Error deleting message: {err}', stacklevel=2) + + @pytest.mark.parametrize( - ('around', 'before', 'after', 'limit'), + 'limit', + [None, 3], +) +@pytest.mark.parametrize( + ('filter_name', 'filter_value'), [ - ('message_id', None, None, 3), - (None, 'message_id', None, 1), - (None, None, 'message_id', 1), + ('around', 'message_id'), + ('before', 'message_id'), + ('after', 'message_id'), ], ) -async def test_get_channel_messages( +async def test_get_channel_messages_with_filter( + filter_name: str, + filter_value: str, + limit: int, messages_res: MessageResource, integration_data: IntegrationTestData, - around: SnowflakeInputType | Literal['message_id'], - before: SnowflakeInputType | Literal['message_id'], - after: SnowflakeInputType | Literal['message_id'], +) -> None: + """Test getting messages from a channel with individual filters and a limit.""" + filter_value = getattr(integration_data, filter_value) + + messages = await messages_res.get(**{filter_name: filter_value, 'limit': limit}) + + assert len(messages) > 0 + if limit: + assert len(messages) <= limit + + message_ids = [message.id for message in messages] + if filter_name == 'around': + assert filter_value in message_ids + else: + assert filter_value not in message_ids + + +# Three filter combination +_three_filter_combination = [{'around': 'message_id', 'before': 'message_id', 'after': 'message_id'}] + + +@pytest.mark.parametrize( + 'limit', + [None, 3], +) # limit should not affect the test, but it's good to test it +@pytest.mark.parametrize( + 'filters', + # Generate all unique combinations of filters without repetitions + [ + {combo[0]: 'message_id', combo[1]: 'message_id', 'none': None} + for combo in itertools.combinations(['around', 'before', 'after'], 2) + ] + + _three_filter_combination, +) +async def test_get_channel_messages_multiple_filters_error( limit: int, + filters: dict, + messages_res: MessageResource, + integration_data: IntegrationTestData, ) -> None: - """Test getting messages from a channel.""" - if around: - around = integration_data.message_id - if before: - before = integration_data.message_id - if after: - after = integration_data.message_id - - messages = await messages_res.get( - around=around, - before=before, - after=after, - limit=limit, - ) - assert len(messages) == limit - if around: - assert messages[1].id == around + """Test that using more than one filter together raises a ValueError.""" + # Convert filter values from string to actual message ID from integration_data + + prepared_filters = {} + for key in filters: # noqa: PLC0206 + if filters[key] is None: + continue + prepared_filters[key] = getattr(integration_data, filters[key]) + + with pytest.raises(ValueError, match='Only one of around, before, after can be specified'): + await messages_res.get(**prepared_filters, limit=limit) + + +async def test_error_on_multiple_message_filters(messages_res: MessageResource) -> None: + """Test that an error is raised when multiple filters are used.""" + with pytest.raises( + ValueError, + match='Only one of around, before, after can be specified', + ): + await messages_res.get(around=1, before=1) async def test_create_and_delete_simple_message(messages_res: MessageResource) -> None: @@ -94,3 +166,29 @@ async def test_message_pin_unpin_flow( # Check that the message is no longer pinned pins_after_unpin = await messages_res.get_pinned_messages(integration_data.channel_id) assert message.id not in [msg.id for msg in pins_after_unpin] + + +async def test_bulk_delete_messages( + multiple_messages: list[MessageResponse], + messages_res: MessageResource, +) -> None: + """Test bulk deleting messages.""" + message_ids = [message.id for message in multiple_messages] + await messages_res.bulk_delete(message_ids) + + +async def test_crosspost_message( + announcement_channel: ChannelResponse, + channel_res: ChannelResource, +) -> None: + """Test crossposting a message.""" + messages_res = channel_res.messages(announcement_channel.id) + message = await messages_res.create(CreateMessageRequest(content='test')) + + crossposted_message = await messages_res.crosspost_message(message.id) + + messages = await messages_res.get(around=message.id, limit=10) + assert crossposted_message.id in [msg.id for msg in messages] + + assert crossposted_message.id == message.id + assert crossposted_message.content == message.content diff --git a/tests/integration/client/test_applications.py b/tests/integration/client/test_applications.py index e523dc7..5a7c381 100644 --- a/tests/integration/client/test_applications.py +++ b/tests/integration/client/test_applications.py @@ -2,50 +2,97 @@ import pytest -from asyncord.client.applications.resources import ApplicationResource +from asyncord.client.applications.models.common import ApplicationFlag from asyncord.client.applications.models.requests import ( UpdateApplicationRequest, - UpdateApplicationRoleConnectionMetadataRequest, ) -from asyncord.client.http.errors import ClientError +from asyncord.client.applications.resources import ApplicationResource from tests.conftest import IntegrationTestData -@pytest.mark.skip(reason='Dangerous operation. Needs manual control.') -async def test_update_application( - app_managment: ApplicationResource, - integration_data: IntegrationTestData, -) -> None: +async def test_update_application(applications_res: ApplicationResource) -> None: """Test update application. This test is skipped by default because I have not enough friends to test this. """ - app = await app_managment.update_application( + app = await applications_res.get_application() + presaved_description = app.description + presaved_tags = app.tags + + new_description = 'This is a test description' + new_tags = {'test', 'tags'} + app = await applications_res.update_application( + UpdateApplicationRequest( + description=new_description, + tags=new_tags, + ), + ) + try: + assert app.description == new_description + assert app.tags == new_tags + + finally: + # Reset the description + app = await applications_res.update_application( UpdateApplicationRequest( - description='This is a test description.', - ) + description=presaved_description, + tags=presaved_tags, + ), ) - - assert app.description == 'This is a test description.' + + assert app.description == presaved_description + assert app.tags == presaved_tags -async def get_application( - app_managment: ApplicationResource, +async def test_get_application( + applications_res: ApplicationResource, integration_data: IntegrationTestData, -) -> None: +) -> None: """Test get application.""" - app = await app_managment.get_application() + app = await applications_res.get_application() assert app.id == integration_data.app_id -async def get_application_role_metadata( - app_managment: ApplicationResource, +async def test_get_application_role_metadata( + applications_res: ApplicationResource, integration_data: IntegrationTestData, -) -> None: +) -> None: """Test get application role connection metadata.""" - metadata_records = await app_managment.get_application_role_connection_metadata_records( - integration_data.app_id + metadata_records = await applications_res.get_application_role_connection_metadata_records( + integration_data.app_id, ) - assert metadata_records \ No newline at end of file + assert isinstance(metadata_records, list) + + +@pytest.mark.parametrize( + ('flags', 'should_fail'), + [ + pytest.param(None, False, id='no_flags'), + pytest.param(ApplicationFlag.GATEWAY_PRESENCE_LIMITED, False, id='valid_flag'), + pytest.param( + ApplicationFlag.GATEWAY_MESSAGE_CONTENT_LIMITED | ApplicationFlag.GATEWAY_PRESENCE_LIMITED, + False, + id='valid_flags', + ), + pytest.param(ApplicationFlag(ApplicationFlag.EMBEDDED), True, id='invalid_flag'), + pytest.param(ApplicationFlag.EMBEDDED | ApplicationFlag.APPLICATION_COMMAND_BADGE, True, id='invalid_flags'), + pytest.param( + ApplicationFlag.GATEWAY_PRESENCE_LIMITED | ApplicationFlag.EMBEDDED, + True, + id='invalid_with_valid_flag', + ), + ], +) +async def test_application_flag_validatior(flags: ApplicationFlag | None, should_fail: bool) -> None: + """Test application flag validator.""" + if should_fail: + with pytest.raises(ValueError, match='Invalid flag'): + UpdateApplicationRequest( + flags=flags, + ) + else: + assert UpdateApplicationRequest( + flags=flags, + ) diff --git a/tests/integration/client/test_bans.py b/tests/integration/client/test_bans.py index 26cf667..8be2d42 100644 --- a/tests/integration/client/test_bans.py +++ b/tests/integration/client/test_bans.py @@ -3,12 +3,13 @@ import pytest from asyncord.client.bans.resources import BanResource +from asyncord.client.http import errors from asyncord.client.http.errors import ClientError from tests.conftest import IntegrationTestData -@pytest.mark.skip(reason='Dangerous operation. Needs manual control.') -async def test_ban_managment( +@pytest.mark.skip(reason='Dangerous operation. Needs manual control') +async def test_ban_cycle( ban_managment: BanResource, integration_data: IntegrationTestData, ) -> None: @@ -40,13 +41,64 @@ async def test_ban_managment( assert not await ban_managment.get_list() # test bulk ban and then unban + + +@pytest.mark.skip(reason='Dangerous operation. Needs manual control') +async def test_bulk_ban(ban_managment: BanResource, integration_data: IntegrationTestData) -> None: + """Test the bulk ban method.""" await ban_managment.bulk_ban([integration_data.user_to_ban]) bans = await ban_managment.get_list() - assert len(bans) == 1 + assert integration_data.user_to_ban in [ban.user.id for ban in bans] await ban_managment.unban(integration_data.user_to_ban) -async def test_get_ban_list(ban_managment: BanResource) -> None: +@pytest.mark.parametrize('delete_message_secs', [None, 1]) +async def test_ban_unknown_user(delete_message_secs: int, ban_managment: BanResource) -> None: + """Test banning an unknown user. + + It's easier to test this with a known user, but someones can be banned. + """ + with pytest.raises(errors.NotFoundError): + await ban_managment.ban( + user_id=12232323232, + delete_message_seconds=delete_message_secs, + ) + + +async def test_unban_unknown_user(ban_managment: BanResource) -> None: + """Test unbanning an unknown user.""" + with pytest.raises(errors.NotFoundError): + await ban_managment.unban(12232323232) + + +@pytest.mark.parametrize('delete_message_secs', [None, 1]) +async def test_empty_bulk_ban(delete_message_secs: int, ban_managment: BanResource) -> None: + """Test the bulk ban method with empty list.""" + await ban_managment.bulk_ban( + user_ids=[], + delete_message_seconds=delete_message_secs, + ) + + +async def test_get_ban_of_unknown_user(ban_managment: BanResource) -> None: + """Test getting a ban of an unknown user.""" + with pytest.raises(errors.NotFoundError): + await ban_managment.get(12232323232) + + +@pytest.mark.parametrize('after', [None, 12345]) +@pytest.mark.parametrize('before', [None, 12345]) +@pytest.mark.parametrize('limit', [None, 1]) +async def test_get_ban_list( + limit: int | None, + before: int | None, + after: int | None, + ban_managment: BanResource, +) -> None: """Test the get ban list method.""" - bans = await ban_managment.get_list() + bans = await ban_managment.get_list( + limit=limit, + before=before, + after=after, + ) assert isinstance(bans, list) diff --git a/tests/integration/client/test_channels.py b/tests/integration/client/test_channels.py index f825116..34fdcfa 100644 --- a/tests/integration/client/test_channels.py +++ b/tests/integration/client/test_channels.py @@ -1,7 +1,10 @@ +from collections.abc import AsyncGenerator + import pytest from asyncord.client.channels.models.common import ChannelType from asyncord.client.channels.models.requests.creation import ( + ChannelInviteRequest, CreateAnoncementChannelRequest, CreateCategoryChannelRequest, CreateChannelRequestType, @@ -11,9 +14,14 @@ CreateTextChannelRequest, CreateVoiceChannelRequest, ) -from asyncord.client.channels.models.requests.updating import UpdateChannelPositionRequest, UpdateTextChannelRequest +from asyncord.client.channels.models.requests.updating import ( + UpdateChannelPermissionsRequest, + UpdateChannelPositionRequest, + UpdateTextChannelRequest, +) +from asyncord.client.channels.models.responses import ChannelResponse from asyncord.client.channels.resources import ChannelResource -from asyncord.client.http.errors import ClientError +from asyncord.client.models.permissions import PermissionFlag from tests.conftest import IntegrationTestData CHANNEL_NAME = 'test' @@ -23,6 +31,22 @@ ) +@pytest.fixture() +async def channel( + channel_res: ChannelResource, + integration_data: IntegrationTestData, +) -> AsyncGenerator[ChannelResponse, None]: + """Create a channel for testing.""" + channel = await channel_res.create_channel( + guild_id=integration_data.guild_id, + channel_data=CreateTextChannelRequest(name='test'), + ) + + yield channel + + await channel_res.delete(channel.id) + + @pytest.mark.parametrize( 'channel_input', [ @@ -68,8 +92,6 @@ async def test_create_and_delete_channel( assert channel.name == CHANNEL_NAME await channel_res.delete(channel.id) - with pytest.raises(ClientError, match='Unknown Channel'): - await channel_res.get(channel.id) async def test_create_subchannel( @@ -84,27 +106,32 @@ async def test_create_subchannel( position=999, ), ) - - text_chan = await channel_res.create_channel( - integration_data.guild_id, - channel_data=CreateTextChannelRequest( - name='test text subchannel', - parent_id=category.id, - rate_limit_per_user=2, - ), # type: ignore - ) - voice_chan = await channel_res.create_channel( - integration_data.guild_id, - channel_data=CreateVoiceChannelRequest( - name='test voice subchannel', - parent_id=category.id, - bitrate=96000, - ), # type: ignore - ) - - # Delete channels - for channel_id in {text_chan.id, voice_chan.id, category.id}: - await channel_res.delete(channel_id) + text_chan_id = None + voice_chan_id = None + try: + text_chan = await channel_res.create_channel( + integration_data.guild_id, + channel_data=CreateTextChannelRequest( + name='test text subchannel', + parent_id=category.id, + rate_limit_per_user=2, + ), # type: ignore + ) + text_chan_id = text_chan.id + voice_chan = await channel_res.create_channel( + integration_data.guild_id, + channel_data=CreateVoiceChannelRequest( + name='test voice subchannel', + parent_id=category.id, + bitrate=96000, + ), # type: ignore + ) + voice_chan_id = voice_chan.id + assert text_chan.parent_id == category.id + finally: + # Delete channels + for channel_id in {text_chan_id, voice_chan_id, category.id}: + await channel_res.delete(channel_id) async def test_get_channel( @@ -127,6 +154,23 @@ async def test_get_channel_invites( assert isinstance(invites, list) +@pytest.mark.parametrize( + 'invite_data', + [ + None, + ChannelInviteRequest(max_age=60, max_uses=1), + ], +) +async def test_create_channel_invite( + invite_data: ChannelInviteRequest | None, + stage_channel: ChannelResponse, + channel_res: ChannelResource, +) -> None: + """Test creating a channel invite.""" + invite = await channel_res.create_channel_invite(stage_channel.id, invite_data) + assert invite.code + + async def test_trigger_typping_indicator( channel_res: ChannelResource, integration_data: IntegrationTestData, @@ -135,57 +179,80 @@ async def test_trigger_typping_indicator( await channel_res.trigger_typing_indicator(integration_data.channel_id) -@pytest.mark.limited() async def test_update_channel_position( + stage_channel: ChannelResponse, channel_res: ChannelResource, integration_data: IntegrationTestData, ) -> None: """Test updating channel position.""" - channel = await channel_res.get(integration_data.channel_id) - position = channel.position - await channel_res.update_channel_position( integration_data.guild_id, [ UpdateChannelPositionRequest( - id=integration_data.channel_id, - position=channel.position + 1, + id=stage_channel.id, + position=stage_channel.position + 1, # type: ignore ), ], ) - updated_channel = await channel_res.get(integration_data.channel_id) + updated_channel = await channel_res.get(stage_channel.id) + assert updated_channel.position == stage_channel.position + 1 # type: ignore - assert channel.position == updated_channel.position - 1 - await channel_res.update_channel_position( - integration_data.guild_id, - [ - UpdateChannelPositionRequest( - id=integration_data.channel_id, - position=position, - ), - ], - ) - - -@pytest.mark.limited() async def test_update_channel( + stage_channel: ChannelResponse, channel_res: ChannelResource, - integration_data: IntegrationTestData, ) -> None: """Test updating a channel.""" - preserved_name = (await channel_res.get(integration_data.channel_id)).name + assert stage_channel.name != 'test' channel = await channel_res.update( - integration_data.channel_id, + stage_channel.id, UpdateTextChannelRequest(name='test'), # type: ignore ) - assert channel.id == integration_data.channel_id + assert channel.name == 'test' - channel = await channel_res.update( - integration_data.channel_id, - UpdateTextChannelRequest(name=preserved_name), # type: ignore + +async def test_permissions_lifecycle( + stage_channel: ChannelResponse, + channel_res: ChannelResource, + integration_data: IntegrationTestData, +) -> None: + """Test full lifecycle of channel permissions.""" + allowed_permissions = PermissionFlag.VIEW_AUDIT_LOG | PermissionFlag.SEND_MESSAGES + await channel_res.update_permissions( + channel_id=stage_channel.id, + role_or_user_id=integration_data.role_id, + permission_data=UpdateChannelPermissionsRequest( + type='role', + allow=allowed_permissions, + deny=PermissionFlag.USE_APPLICATION_COMMANDS, + ), + ) + + update_channel = await channel_res.get(stage_channel.id) + assert update_channel.permission_overwrites + assert update_channel.permission_overwrites[0].allow == allowed_permissions + + await channel_res.delete_permission( + channel_id=stage_channel.id, + role_or_user_id=integration_data.role_id, + ) + + update_channel = await channel_res.get(stage_channel.id) + assert not update_channel.permission_overwrites + + +async def test_follow_announcement_channel( + announcement_channel: ChannelResponse, + channel: ChannelResponse, + channel_res: ChannelResource, +) -> None: + """Test following an announcement channel.""" + followed_chan_resp = await channel_res.follow_announcement_channel( + channel_id=announcement_channel.id, + target_channel_id=channel.id, ) - assert channel.name == preserved_name + assert followed_chan_resp.webhook_id + assert followed_chan_resp.channel_id == announcement_channel.id diff --git a/tests/integration/client/test_emoji.py b/tests/integration/client/test_emoji.py index 753f622..d668fc6 100644 --- a/tests/integration/client/test_emoji.py +++ b/tests/integration/client/test_emoji.py @@ -1,56 +1,56 @@ -from asyncord.client.emojis.resources import EmojiResource +from collections.abc import AsyncGenerator +from pathlib import Path + +import pytest from asyncord.client.emojis.models.requests import CreateEmojiRequest, UpdateEmojiRequest +from asyncord.client.emojis.models.responses import EmojiResponse +from asyncord.client.emojis.resources import EmojiResource +EMOJI_PATH = Path('tests/data/test_emoji.png') -from tests.conftest import IntegrationTestData -TEST_EMOJI = { - 'name': 'test_emoji', - 'path': 'test_emoji.png', -} +@pytest.fixture() +async def emoji(emoji_res: EmojiResource) -> AsyncGenerator[EmojiResponse, None]: + """Fixture to create a guild emojim and remove it after the test.""" + emoji = await emoji_res.create_guild_emoji( + CreateEmojiRequest( + name='test_emoji', + image=EMOJI_PATH, + ), + ) + yield emoji + await emoji_res.delete_guild_emoji(emoji.id) # type: ignore (id should be set after creation) -async def test_guild_emoji_lifecycle( +async def test_get_guild_emoji( + emoji: EmojiResponse, emoji_res: EmojiResource, ) -> None: - """Test the lifecycle (create, get, modify, delete) of a guild emoji.""" - with open(f'tests/data/{TEST_EMOJI['path']}', 'rb') as f: - emoji_data = f.read() + """Test getting a guild emoji.""" + emoji = await emoji_res.get_guild_emoji(emoji.id) # type: ignore + assert emoji.id == emoji.id + assert emoji.name == emoji.name - # Check initial state - initial_emojis = await emoji_res.get_guild_emojis() - assert TEST_EMOJI['name'] not in [emoji.name for emoji in initial_emojis] - # Create the emoji - emoji = await emoji_res.create_guild_emoji( - CreateEmojiRequest( - name=TEST_EMOJI['name'], - image=emoji_data, - ), - ) - assert emoji.name == TEST_EMOJI['name'] +async def test_get_guild_emojis( + emoji: EmojiResponse, + emoji_res: EmojiResource, +) -> None: + """Test getting all guild emojis.""" + emojis = await emoji_res.get_guild_emojis() + assert emoji.id in [emoji.id for emoji in emojis] - # Check that the emoji exists in the guild - emojis_after_creation = await emoji_res.get_guild_emojis() - assert TEST_EMOJI['name'] in [emoji.name for emoji in emojis_after_creation] - # Modify the emoji +async def test_update_guild_emoji( + emoji: EmojiResponse, + emoji_res: EmojiResource, +) -> None: + """Test updating a guild emoji.""" updated_emoji = await emoji_res.update_guild_emoji( - emoji.id, + emoji.id, # type: ignore UpdateEmojiRequest( - name=f'{TEST_EMOJI['name']}_updated', + name='updated_emoji', ), ) - assert TEST_EMOJI['name'] != updated_emoji.name - - # Check that the updated emoji exists in the guild - emojis_after_modification = await emoji_res.get_guild_emojis() - assert f'{TEST_EMOJI['name']}_updated' in [emoji.name for emoji in emojis_after_modification] - - # Delete the emoji - await emoji_res.delete_guild_emoji(emoji.id) - - # Check that the emoji no longer exists in the guild - emojis_after_deletion = await emoji_res.get_guild_emojis() - assert f'{TEST_EMOJI['name']}_updated' not in [emoji.name for emoji in emojis_after_deletion] + assert updated_emoji.name == 'updated_emoji' diff --git a/tests/integration/client/test_guild_templates.py b/tests/integration/client/test_guild_templates.py index eaadf9d..5a03c85 100644 --- a/tests/integration/client/test_guild_templates.py +++ b/tests/integration/client/test_guild_templates.py @@ -1,44 +1,76 @@ +from collections.abc import AsyncGenerator + +import pytest + from asyncord.client.guild_templates.models.requests import ( + CreateGuildFromTemplateRequest, CreateGuildTemplateRequest, UpdateGuildTemplateRequest, ) +from asyncord.client.guild_templates.models.responses import GuildTemplateResponse from asyncord.client.guild_templates.resources import GuildTemplatesResource -from tests.conftest import IntegrationTestData +from asyncord.client.guilds.resources import GuildResource -async def test_guild_template_cycle( +@pytest.fixture() +async def guild_templates( guild_templates_res: GuildTemplatesResource, - integration_data: IntegrationTestData, -) -> None: - """Test getting a guild template. - - Doesn't test the guild creation from template. - """ - created_template = await guild_templates_res.create_guild_template( +) -> AsyncGenerator[GuildTemplateResponse, None]: + """Create a guild template and delete it after the test.""" + template = await guild_templates_res.create_guild_template( CreateGuildTemplateRequest( - name='test-template', - description='test template description', + name='TestTemplate', + description='Test template description', ), ) + yield template + await guild_templates_res.delete_guild_template(template.code) + + +async def test_get_template( + guild_templates: GuildTemplateResponse, + guild_templates_res: GuildTemplatesResource, +) -> None: + """Test getting a guild template.""" + retrieved_template = await guild_templates_res.get_template(guild_templates.code) + assert retrieved_template.name == 'TestTemplate' - assert created_template.name == 'test-template' +async def test_update_template( + guild_templates: GuildTemplateResponse, + guild_templates_res: GuildTemplatesResource, +) -> None: + """Test updating a guild template.""" updated_template = await guild_templates_res.update_guild_template( - created_template.code, + guild_templates.code, UpdateGuildTemplateRequest( - name='updated-test-template', - description='updated test template description', + name='UpdatedTemplate', + description='Updated template description', ), ) + assert updated_template.name == 'UpdatedTemplate' - requested_template = await guild_templates_res.get_template( - created_template.code, - ) - assert requested_template.name == updated_template.name +async def test_sync_template( + guild_templates: GuildTemplateResponse, + guild_templates_res: GuildTemplatesResource, +) -> None: + """Test syncing a guild template.""" + assert await guild_templates_res.sync_guild_template(guild_templates.code) - templates = await guild_templates_res.get_guild_templates() - assert templates +async def test_create_guild_from_template( + guild_templates: GuildTemplateResponse, + guild_templates_res: GuildTemplatesResource, + guilds_res: GuildResource, +) -> None: + """Test creating a guild from a template.""" + guild = await guild_templates_res.create_guild_from_template( + guild_templates.code, + CreateGuildFromTemplateRequest(name='TestGuild'), + ) - await guild_templates_res.delete_guild_template(updated_template.code) + try: + assert guild.owner_id == guild_templates.creator_id + finally: + await guilds_res.delete(guild.id) diff --git a/tests/integration/client/test_guilds.py b/tests/integration/client/test_guilds.py index bb7aff1..7ed55e0 100644 --- a/tests/integration/client/test_guilds.py +++ b/tests/integration/client/test_guilds.py @@ -1,13 +1,22 @@ +import datetime import random import pytest +from asyncord.client.guilds.models.common import MFALevel, OnboardingMode, OnboardingPromptType, WidgetStyleOptions from asyncord.client.guilds.models.requests import ( CreateAutoModerationRuleRequest, CreateGuildRequest, + OnboardingPrompt, + OnboardingPromptOption, + PruneRequest, + UpdateAutoModerationRuleRequest, + UpdateOnboardingRequest, + UpdateWelcomeScreenRequest, UpdateWidgetSettingsRequest, ) from asyncord.client.guilds.resources import GuildResource +from asyncord.client.http import errors from asyncord.client.models.automoderation import ( AutoModerationRuleEventType, RuleAction, @@ -62,15 +71,42 @@ async def test_create_delete_guild( await guilds_res.delete(guild.id) +@pytest.mark.parametrize('days', [None, 1]) +@pytest.mark.parametrize('is_include_roles', [False, True]) async def test_get_prune_count( + days: int | None, + is_include_roles: bool, guilds_res: GuildResource, integration_data: IntegrationTestData, ) -> None: """Test getting the prune count.""" - prune_count = await guilds_res.get_prune_count(integration_data.guild_id) + if is_include_roles: + include_roles = [integration_data.role_to_prune] + else: + include_roles = None + prune_count = await guilds_res.get_prune_count(integration_data.guild_id, days=days, include_roles=include_roles) assert prune_count.pruned is not None +@pytest.mark.limited() +async def test_begin_prune( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test beginning a prune.""" + include_roles = [integration_data.role_to_prune] + + prune_obj = await guilds_res.begin_prune( + guild_id=integration_data.guild_id, + prune_data=PruneRequest( + days=1, + compute_prune_count=True, + include_roles=include_roles, + ), + ) + assert prune_obj.pruned is not None + + async def test_get_voice_regions( guilds_res: GuildResource, integration_data: IntegrationTestData, @@ -105,19 +141,41 @@ async def test_get_integrations( assert await guilds_res.get_integrations(integration_data.guild_id) +@pytest.mark.parametrize('is_user_id', [False, True]) +@pytest.mark.parametrize('limit', [None, 10]) async def test_get_audit_log( + is_user_id: bool, + limit: int | None, guilds_res: GuildResource, integration_data: IntegrationTestData, ) -> None: """Test getting the audit log.""" - assert await guilds_res.get_audit_log(integration_data.guild_id) + query_params = {} + if is_user_id is not None: + query_params['user_id'] = integration_data.user_id + + if limit is not None: + query_params['limit'] = limit + assert await guilds_res.get_audit_log(integration_data.guild_id, **query_params) + + +async def test_get_list_auto_moderation_rules( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test getting the auto moderation rules.""" + rules = await guilds_res.get_list_auto_moderation_rules( + integration_data.guild_id, + ) + assert isinstance(rules, list) -async def test_create_get_delete_auto_moderation_rule( +async def test_auto_moderation_rule_management( guilds_res: GuildResource, integration_data: IntegrationTestData, ) -> None: """Test creating an auto moderation rule.""" + # create rule rule = await guilds_res.create_auto_moderation_rule( integration_data.guild_id, CreateAutoModerationRuleRequest( @@ -135,29 +193,31 @@ async def test_create_get_delete_auto_moderation_rule( enabled=True, ), ) + assert rule.id - assert rule - - rule_response = await guilds_res.get_auto_moderation_rule( + # get rule + assert await guilds_res.get_auto_moderation_rule( integration_data.guild_id, rule.id, ) - assert rule.name == rule_response.name - - await guilds_res.delete_auto_moderation_rule( - integration_data.guild_id, - rule.id, + # update rule + updated_rule = await guilds_res.update_auto_moderation_rule( + guild_id=integration_data.guild_id, + rule_id=rule.id, + rule=UpdateAutoModerationRuleRequest( + name='Updated Rule', + enabled=False, + ), ) + assert updated_rule.name == 'Updated Rule' + assert not updated_rule.enabled -async def test_get_list_auto_moderation_rules( - guilds_res: GuildResource, - integration_data: IntegrationTestData, -) -> None: - """Test getting the auto moderation rules.""" - assert await guilds_res.get_list_auto_moderation_rules( + # delete rule + await guilds_res.delete_auto_moderation_rule( integration_data.guild_id, + rule.id, ) @@ -166,21 +226,35 @@ async def test_get_update_widget( integration_data: IntegrationTestData, ) -> None: """Test getting and updating the widget.""" - assert await guilds_res.get_widget_settings(integration_data.guild_id) - updated_widget_settings = await guilds_res.update_widget( integration_data.guild_id, widget_data=UpdateWidgetSettingsRequest( enabled=True, ), ) - assert updated_widget_settings.enabled + +async def test_get_widget( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test getting the widget.""" widget = await guilds_res.get_widget(integration_data.guild_id) - assert widget + assert widget.id == integration_data.guild_id - widget_image = await guilds_res.get_widget_image(integration_data.guild_id) + +@pytest.mark.parametrize('style', [None] + [style for style in WidgetStyleOptions]) +async def test_get_widget_image( + style: WidgetStyleOptions | None, + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test getting the widget image.""" + widget_image = await guilds_res.get_widget_image( + integration_data.guild_id, + style=style, + ) assert widget_image.image_data.startswith('data:') @@ -189,14 +263,172 @@ async def test_get_onboarding( integration_data: IntegrationTestData, ) -> None: """Test getting and updating the onboarding.""" - onboarding = await guilds_res.get_onboarding(integration_data.guild_id) - assert onboarding + assert await guilds_res.get_onboarding(integration_data.guild_id) + + +async def test_update_onboarding( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test updating the onboarding.""" + onboarding = await guilds_res.update_onboarding( + integration_data.guild_id, + UpdateOnboardingRequest( + prompts=[ + OnboardingPrompt( + type=OnboardingPromptType.DROPDOWN, + title='Select a channel to get started!', + options=[ + OnboardingPromptOption( + channel_ids=[integration_data.channel_id], + title='Welcome to the server!', + ), + ], + single_select=False, + required=True, + in_onboarding=True, + ), + OnboardingPrompt( + type=OnboardingPromptType.DROPDOWN, + title='Select a channel to get started 2!', + options=[ + OnboardingPromptOption( + role_ids={integration_data.role_id}, + title='Welcome to the server 2!', + ), + ], + single_select=False, + ), + ], + default_channel_ids=[], + enabled=False, + mode=OnboardingMode.ONBOARDING_DEFAULT, + ), + ) + assert onboarding.enabled is False + assert onboarding.mode == OnboardingMode.ONBOARDING_DEFAULT + assert len(onboarding.prompts) == 2 + assert onboarding.prompts[0].id + + onboarding = await guilds_res.update_onboarding( + integration_data.guild_id, + UpdateOnboardingRequest( + prompts=[], + default_channel_ids=[], + ), + ) + assert not onboarding.prompts + assert not onboarding.default_channel_ids + + +async def test_welcome_screen_management( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test updating and getting welcome screen.""" + assert await guilds_res.update_welcome_screen( + integration_data.guild_id, + UpdateWelcomeScreenRequest( + enabled=True, + description='Welcome to the server!', + ), + ) + assert await guilds_res.get_welcome_screen(integration_data.guild_id) + + assert await guilds_res.update_welcome_screen( + integration_data.guild_id, + UpdateWelcomeScreenRequest( + enabled=False, + description=None, + ), + ) + + +@pytest.mark.skip('Should be tested when bot connected to voice channel') +@pytest.mark.parametrize('suppress', [True, False]) +@pytest.mark.parametrize('with_request_to_speak_timestamp', [True, False]) +async def test_update_current_user_voice_state( + suppress: bool, + with_request_to_speak_timestamp: bool, + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test updating the current user's voice state.""" + if with_request_to_speak_timestamp: + request_to_speak_timestamp = datetime.datetime.now(datetime.UTC) + else: + request_to_speak_timestamp = None + + await guilds_res.update_current_user_voice_state( + integration_data.guild_id, + suppress=suppress, + request_to_speak_timestamp=request_to_speak_timestamp, + ) + + await guilds_res.update_current_user_voice_state( + integration_data.guild_id, + suppress=False, + ) + + +async def test_update_current_user_voice_state_without_connection( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test updating the current user's voice state.""" + with pytest.raises(errors.ClientError, match='Unknown Voice State'): + await guilds_res.update_current_user_voice_state(integration_data.guild_id, suppress=True) + + +async def test_update_user_voice_state( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test updating the current user's voice state.""" + with pytest.raises(errors.ClientError, match='Unknown Voice State'): + await guilds_res.update_user_voice_state( + integration_data.guild_id, + integration_data.user_id, + suppress=True, + ) + + +async def test_update_mfa_level( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test updating the MFA level. + + We just check that method and input in general valid. Bot has no enough + permissions to update MFA level. + """ + with pytest.raises(errors.ClientError, match='Missing Access'): + await guilds_res.update_mfa( + integration_data.guild_id, + level=MFALevel.NONE, + ) + + +async def test_delete_integration( + guilds_res: GuildResource, + integration_data: IntegrationTestData, +) -> None: + """Test deleting an integration.""" + with pytest.raises(errors.ClientError, match='Unknown Integration'): + assert await guilds_res.delete_integration( + integration_data.guild_id, + '34124324134', + ) -@pytest.mark.skip('Requires more rights than test bot has.') async def test_get_vanity_url( guilds_res: GuildResource, integration_data: IntegrationTestData, ) -> None: - """Test getting the vanity url.""" - assert await guilds_res.get_vanity_url(integration_data.guild_id) + """Test getting the vanity url. + + We just check that method and input in general valid. Bot has no enough + permissions to get vanity url. + """ + with pytest.raises(errors.ClientError, match='Missing Access'): + assert await guilds_res.get_vanity_url(integration_data.guild_id) diff --git a/tests/integration/client/test_invites.py b/tests/integration/client/test_invites.py index 8b152b0..f92e722 100644 --- a/tests/integration/client/test_invites.py +++ b/tests/integration/client/test_invites.py @@ -1,21 +1,64 @@ +from collections.abc import AsyncIterator + +import pytest + from asyncord.client.channels.resources import ChannelResource +from asyncord.client.guilds.models.responses import InviteResponse from asyncord.client.invites.resources import InvitesResource +from asyncord.client.scheduled_events.models.responses import ScheduledEventResponse from tests.conftest import IntegrationTestData -async def test_create_get_delete_channel_invite( +@pytest.fixture() +async def invite( channel_res: ChannelResource, invite_res: InvitesResource, integration_data: IntegrationTestData, -) -> None: - """Test creating channel invite, get invite, delete invite.""" +) -> AsyncIterator[InviteResponse]: + """Create a channel invite and delete it after the test.""" invite = await channel_res.create_channel_invite(integration_data.channel_id) - assert invite.code + yield invite + await invite_res.delete_invite(invite.code) + +async def test_get_invite( + invite: InviteResponse, + invite_res: InvitesResource, +) -> None: + """Test getting an invite.""" invite_response = await invite_res.get_invite(invite.code) + assert invite_response.code == invite.code + assert str(invite_response.url).endswith(invite.code) + + +async def test_get_invite_for_event( + event: ScheduledEventResponse, + invite: InviteResponse, + invite_res: InvitesResource, +) -> None: + """Test getting an invite for a guild scheduled event.""" + invite_response = await invite_res.get_invite(invite.code, guild_scheduled_event_id=event.id) + assert invite_response.code == invite.code + assert invite_response.guild_scheduled_event + assert invite_response.guild_scheduled_event.id == event.id + +async def test_get_invite_with_expiration( + invite: InviteResponse, + invite_res: InvitesResource, +) -> None: + """Test getting an invite with expiration.""" + invite_response = await invite_res.get_invite(invite.code, with_expiration=True) assert invite_response.code == invite.code + assert invite_response.expires_at - deleted_invite = await invite_res.delete_invite(invite.code) - assert deleted_invite.code == invite.code +async def test_get_invite_with_counts( + invite: InviteResponse, + invite_res: InvitesResource, +) -> None: + """Test getting an invite with counts.""" + invite_response = await invite_res.get_invite(invite.code, with_counts=True) + assert invite_response.code == invite.code + assert invite_response.approximate_presence_count + assert invite_response.approximate_member_count diff --git a/tests/integration/client/test_members.py b/tests/integration/client/test_members.py index 0fb6692..f83c687 100644 --- a/tests/integration/client/test_members.py +++ b/tests/integration/client/test_members.py @@ -1,6 +1,7 @@ import pytest -from asyncord.client.http.errors import ClientError +from asyncord.client.http import errors +from asyncord.client.members.models.requests import UpdateMemberRequest from asyncord.client.members.resources import MemberResource from tests.conftest import IntegrationTestData @@ -15,7 +16,6 @@ async def test_get_member( assert member.user.id == integration_data.member_id -@pytest.mark.skip(reason='Need OAUTH GUILD_MEMBERS permission to access members list') async def test_list_members(members_res: MemberResource) -> None: """Test getting a list of members.""" member_list = await members_res.get_list(limit=1) @@ -24,13 +24,12 @@ async def test_list_members(members_res: MemberResource) -> None: assert member_list[0].user.id -@pytest.mark.skip(reason='Need OAUTH GUILD_MEMBERS permission to access members list') async def test_search_members( members_res: MemberResource, integration_data: IntegrationTestData, ) -> None: """Test searching for a member.""" - member_list = await members_res.search('Cucaryamba') + member_list = await members_res.search(nick_or_name='Cuca') assert len(member_list) == 1 assert member_list[0].user assert member_list[0].user.id == integration_data.member_id @@ -39,33 +38,73 @@ async def test_search_members( async def test_update_current_member(members_res: MemberResource) -> None: """Test updating the current member.""" new_nickname = 'Cucaracha' - member = await members_res.update_current_member(new_nickname) + member = await members_res.update_current_member(nickname=new_nickname) assert member.nick == new_nickname # reset to default nickname - await members_res.update_current_member(None) + await members_res.update_current_member(nickname=None) -async def test_add_and_remove_role( +async def test_role_operations( members_res: MemberResource, integration_data: IntegrationTestData, ) -> None: """Test adding and removing a role.""" - member = await members_res.get(integration_data.member_id) - assert integration_data.role_id not in member.roles - # check role addition - await members_res.add_role(integration_data.member_id, integration_data.role_id) + await members_res.add_role( + user_id=integration_data.member_id, + role_id=integration_data.role_id, + ) member = await members_res.get(integration_data.member_id) assert integration_data.role_id in member.roles # check role removal - await members_res.remove_role(integration_data.member_id, integration_data.role_id) + await members_res.remove_role( + user_id=integration_data.member_id, + role_id=integration_data.role_id, + ) member = await members_res.get(integration_data.member_id) assert integration_data.role_id not in member.roles +@pytest.mark.parametrize( + 'user_id', + [ + '@me', + pytest.param( + 'from_config', + marks=pytest.mark.skip( + reason=("Function should work, but discord doesn't allow changing nicknames"), + ), + ), + ], +) +async def test_update_member( + user_id: str, + members_res: MemberResource, + integration_data: IntegrationTestData, +) -> None: + """Test updating a member.""" + if user_id == 'from_config': + user_id = integration_data.member_id + + new_nickname = 'Cucaracha' + member = await members_res.update( + user_id=user_id, + member_data=UpdateMemberRequest(nick=new_nickname), + ) + assert member.nick == new_nickname + + # reset to default nickname + member = await members_res.update( + user_id=user_id, + member_data=UpdateMemberRequest(nick=None), + ) + + assert member.nick is None + + @pytest.mark.skip(reason='Dangerous operation. Needs manual control.') async def test_kick( members_res: MemberResource, @@ -77,5 +116,8 @@ async def test_kick( assert member.user await members_res.kick(verum_user_id, 'test') - with pytest.raises(ClientError, match=r'\(\d*\) Unknown Member'): - member = await members_res.get(verum_user_id) + +async def test_kick_unkown_user(members_res: MemberResource) -> None: + """Test kicking a member.""" + with pytest.raises(errors.NotFoundError): + await members_res.kick('123451231', 'test') diff --git a/tests/integration/client/test_polls.py b/tests/integration/client/test_polls.py index d9542bc..eb6d315 100644 --- a/tests/integration/client/test_polls.py +++ b/tests/integration/client/test_polls.py @@ -1,45 +1,111 @@ """Contains tests for polls resource.""" +from collections.abc import AsyncGenerator + +import pytest + from asyncord.client.messages.models.requests.messages import CreateMessageRequest +from asyncord.client.messages.models.responses.messages import MessageResponse from asyncord.client.messages.resources import MessageResource from asyncord.client.polls.models.common import PollLayoutType -from asyncord.client.polls.models.requests import PollAnswer, PollMedia, PollRequest +from asyncord.client.polls.models.requests import Answer, Poll, PollEmoji from asyncord.client.polls.resources import PollsResource -# FIXME: The DISCORD API doesn't return answer_id. -# I don't know how to test PollsReqource.get_answer_voters() without it. -async def test_create_and_delete_poll( - messages_res: MessageResource, - polls_res: PollsResource, -) -> None: - """Test creating and deleting a message.""" - message = await messages_res.create( - CreateMessageRequest( - content='test', - poll=PollRequest( - question=PollMedia( - text='test?', +@pytest.fixture() +async def poll_message(messages_res: MessageResource) -> AsyncGenerator[MessageResponse, None]: + """Create a message with a poll. + + After the test, delete the message. + """ + message_with_poll = CreateMessageRequest( + content='test', + poll=Poll( + question='test?', + answers=[ + Answer( + text='test', ), - answers=[ - PollAnswer( - poll_media=PollMedia( - text='test', - ), - ), - PollAnswer( - poll_media=PollMedia( - text='test', - ), + Answer( + text='test', + emoji=PollEmoji( + name='👍', ), - ], - allow_multiselect=True, - layout_type=PollLayoutType.DEFAULT, - ), + ), + ], + allow_multiselect=True, + layout_type=PollLayoutType.DEFAULT, ), ) + message = await messages_res.create(message_with_poll) + yield message + await messages_res.delete(message.id) - await polls_res.end_poll(message.id) - assert message.poll - await messages_res.delete(message.id) +async def test_create_and_delete_message_with_poll(poll_message: MessageResponse) -> None: + """Test creating and deleting a message with a poll.""" + poll = poll_message.poll + + assert poll + assert poll.question == 'test?' + assert len(poll.answers) == 2 + + answers = poll.answers + assert answers[0].answer_id + assert answers[0].poll_media.text + + assert answers[1].answer_id + assert answers[1].poll_media.emoji + assert answers[1].poll_media.emoji.name == '👍' + + assert poll.expiry + + +@pytest.mark.parametrize('after', [None, 10]) +@pytest.mark.parametrize('limit', [None, 10]) +async def test_get_voters_for_answer( + after: int | None, + limit: int | None, + poll_message: MessageResponse, + polls_res: PollsResource, +) -> None: + """Test getting voters for an answer.""" + poll = poll_message.poll + assert poll + + answer_id = poll.answers[0].answer_id + + voters = await polls_res.get_answer_voters( + message_id=poll_message.id, + answer_id=answer_id, + after=after, + limit=limit, + ) + assert not voters.users + + +async def test_end_poll(poll_message: MessageResponse, polls_res: PollsResource) -> None: + """Test ending a poll.""" + poll = poll_message.poll + assert poll + + await polls_res.end_poll(message_id=poll_message.id) + + # We doesn't check finalization because it's not guaranteed that the poll + # will be finalized after ending. + # Read how to pool working at: https://discord.com/developers/docs/resources/poll#poll-results-object + + +async def test_poll_emoji_cant_contain_id_and_name() -> None: + """Test that a poll emoji can't contain an id and a name at the same time.""" + with pytest.raises(ValueError, match='Only one of id or name'): + PollEmoji( + name='👍', + id=123, + ) + + +async def test_poll_emoji_must_contain_id_or_name() -> None: + """Test that a poll emoji must contain an id or a name.""" + with pytest.raises(ValueError, match='Either id or name must be set'): + PollEmoji() diff --git a/tests/integration/client/test_reactions.py b/tests/integration/client/test_reactions.py index 206c1a7..d373fc7 100644 --- a/tests/integration/client/test_reactions.py +++ b/tests/integration/client/test_reactions.py @@ -1,79 +1,85 @@ +import pytest + +from asyncord.client.messages.models.responses.messages import MessageResponse +from asyncord.client.messages.resources import MessageResource from asyncord.client.reactions.resources import ReactionResource from tests.conftest import IntegrationTestData +TEST_EMOJI1 = '👍' +TEST_EMOJI2 = '👎' + + +@pytest.fixture() +async def reactions_res( + message: MessageResponse, + messages_res: MessageResource, +) -> ReactionResource: + """Get reactions resource for the message.""" + resource = messages_res.reactions(message.id) + await resource.add(TEST_EMOJI1) + await resource.add(TEST_EMOJI2) + return resource -async def test_add_and_get_reactions( + +async def test_get_reactions( reactions_res: ReactionResource, integration_data: IntegrationTestData, ) -> None: """Test adding and getting reactions.""" - test_emoji1 = '👍' - test_emoji2 = '👎' + assert (await reactions_res.get(TEST_EMOJI1))[0].id == integration_data.member_id - await reactions_res.add(test_emoji1) - await reactions_res.add(test_emoji2) - assert (await reactions_res.get(test_emoji1))[0].id == integration_data.member_id - assert (await reactions_res.get(test_emoji2))[0].id == integration_data.member_id +async def test_get_reactions_with_after_param( + reactions_res: ReactionResource, + integration_data: IntegrationTestData, +) -> None: + """Test adding and getting reactions with the after parameter.""" + reactions = await reactions_res.get(TEST_EMOJI1, after=integration_data.member_id) + # it should return an empty list because there are no reactions after the bot test member + assert not reactions -async def test_delete_own_reaction(reactions_res: ReactionResource) -> None: - """Test deleting own reaction.""" - test_emoji = '👍' - await reactions_res.add(test_emoji) - assert await reactions_res.get(test_emoji) +async def test_get_reactions_with_limit_param( + reactions_res: ReactionResource, +) -> None: + """Test adding and getting reactions with the limit parameter. - await reactions_res.delete_own_reaction(test_emoji) - assert not await reactions_res.get(test_emoji) + Dummy test to check if the limit parameter is sending in general. + """ + reactions = await reactions_res.get(TEST_EMOJI1, limit=1) + assert len(reactions) == 1 -async def test_delete_user_reaction(reactions_res: ReactionResource, integration_data: IntegrationTestData) -> None: - """Test deleting user reaction.""" - test_emoji = '👍' - await reactions_res.add(test_emoji) - assert await reactions_res.get(test_emoji) +@pytest.mark.parametrize('user_id', [None, 'member_id', '@me']) +async def test_delete_reaction( + user_id: str | None, + reactions_res: ReactionResource, + integration_data: IntegrationTestData, +) -> None: + """Test deleting own reaction.""" + if user_id == 'member_id': + user_id = integration_data.member_id + + await reactions_res.delete(TEST_EMOJI1, user_id=user_id) - await reactions_res.delete(test_emoji, integration_data.member_id) - assert not await reactions_res.get(test_emoji) + assert not await reactions_res.get(TEST_EMOJI1) async def test_delete_all_reactions(reactions_res: ReactionResource) -> None: """Test deleting all reactions.""" - test_emoji1 = '👍' - test_emoji2 = '👎' - - await reactions_res.add(test_emoji1) - await reactions_res.add(test_emoji2) - assert await reactions_res.get(test_emoji1) - assert await reactions_res.get(test_emoji2) - await reactions_res.delete() - assert not await reactions_res.get(test_emoji1) - assert not await reactions_res.get(test_emoji2) + assert not await reactions_res.get(TEST_EMOJI1) + assert not await reactions_res.get(TEST_EMOJI2) async def test_delete_all_reactions_for_emoji(reactions_res: ReactionResource) -> None: """Test deleting all reactions for an emoji.""" - test_emoji1 = '👍' - test_emoji2 = '👎' - - await reactions_res.add(test_emoji1) - await reactions_res.add(test_emoji2) - assert await reactions_res.get(test_emoji1) - assert await reactions_res.get(test_emoji2) + await reactions_res.delete(TEST_EMOJI1) + assert not await reactions_res.get(TEST_EMOJI1) + assert await reactions_res.get(TEST_EMOJI2) - await reactions_res.delete(test_emoji1) - assert not await reactions_res.get(test_emoji1) - assert await reactions_res.get(test_emoji2) - - -async def test_add_and_delete( - reactions_res: ReactionResource, - integration_data: IntegrationTestData, -) -> None: - """Test adding and deleting reactions.""" - await reactions_res.add(integration_data.custom_emoji) - assert await reactions_res.get(integration_data.custom_emoji) - await reactions_res.delete(integration_data.custom_emoji) - assert not await reactions_res.get(integration_data.custom_emoji) +async def test_delete_fail_on_user_id_without_emoji(reactions_res: ReactionResource) -> None: + """Test deleting a reaction with a user id but no emoji.""" + with pytest.raises(ValueError, match='Cannot delete a reaction for a user without an emoji.'): + await reactions_res.delete(user_id='123') diff --git a/tests/integration/client/test_roles.py b/tests/integration/client/test_roles.py index 9af2d44..f6ebe69 100644 --- a/tests/integration/client/test_roles.py +++ b/tests/integration/client/test_roles.py @@ -1,12 +1,17 @@ -from asyncord.client.roles.models.requests import CreateRoleRequest +from collections.abc import AsyncGenerator + +import pytest + +from asyncord.client.roles.models.requests import CreateRoleRequest, RolePositionRequest, UpdateRoleRequest +from asyncord.client.roles.models.responses import RoleResponse from asyncord.client.roles.resources import RoleResource -async def test_create_and_delete_role(roles_res: RoleResource) -> None: - """Test creating and deleting a role.""" - new_role_name = 'TestRole' - role = await roles_res.create(CreateRoleRequest(name=new_role_name)) - assert role.name == new_role_name +@pytest.fixture() +async def role(roles_res: RoleResource) -> AsyncGenerator[RoleResponse, None]: + """Create a role and delete it after the test.""" + role = await roles_res.create(CreateRoleRequest(name='TestRole')) + yield role await roles_res.delete(role.id) @@ -16,3 +21,27 @@ async def test_get_role_list(roles_res: RoleResource) -> None: assert role_list assert role_list[0].id assert role_list[0].name + + +async def test_change_role_positions(role: RoleResponse, roles_res: RoleResource) -> None: + """Test changing the position of a role.""" + roles = await roles_res.change_role_positions( + [ + RolePositionRequest( + id=role.id, + position=2, + ), + ], + ) + + updated_role = next(filter(lambda r: r.id == role.id, roles)) + assert updated_role.position == 2 + + +async def test_update_role(role: RoleResponse, roles_res: RoleResource) -> None: + """Test updating a role.""" + role = await roles_res.update_role( + role_id=role.id, + role_data=UpdateRoleRequest(name='UpdatedRole'), + ) + assert role.name == 'UpdatedRole' diff --git a/tests/integration/client/test_scheduled_events.py b/tests/integration/client/test_scheduled_events.py index e93debf..4babbe1 100644 --- a/tests/integration/client/test_scheduled_events.py +++ b/tests/integration/client/test_scheduled_events.py @@ -1,5 +1,6 @@ +import contextlib import datetime -from collections.abc import AsyncGenerator +from typing import Any import pytest @@ -17,23 +18,6 @@ from tests.conftest import IntegrationTestData -@pytest.fixture() -async def event(events_res: ScheduledEventsResource) -> AsyncGenerator[ScheduledEventResponse, None]: - """Fixture that creates a scheduled event and deletes it after the test.""" - creation_data = CreateScheduledEventRequest( - entity_type=EventEntityType.EXTERNAL, - name='Test Event', - description='This is a test event.', - entity_metadata=EventEntityMetadata(location='https://example.com'), - privacy_level=EventPrivacyLevel.GUILD_ONLY, - scheduled_start_time=datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=1), - scheduled_end_time=datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=1), - ) - event = await events_res.create(creation_data) - yield event - await events_res.delete(event.id) - - @pytest.mark.parametrize('event_type', [EventEntityType.EXTERNAL, EventEntityType.VOICE]) async def test_create_event( events_res: ScheduledEventsResource, @@ -109,3 +93,145 @@ async def test_get_event_users( """Test getting a list of users who have signed up for a scheduled event.""" users = await events_res.get_event_users(event.id) assert isinstance(users, list) + + +@pytest.mark.parametrize( + 'event_type', + [pytest.param(event_type, id=event_type.name) for event_type in EventEntityType], +) +@pytest.mark.parametrize( + 'model_fields', + [ + pytest.param( + {}, + id='without_fields', + ), + pytest.param( + { + 'entity_metadata': {'location': 'https://example.com'}, + }, + id='with_entity_metadata', + ), + pytest.param( + { + 'scheduled_end_time': datetime.datetime.now(datetime.UTC), + }, + id='with_end_time', + ), + pytest.param( + { + 'entity_metadata': {'location': 'https://example.com'}, + 'scheduled_end_time': datetime.datetime.now(datetime.UTC), + }, + id='with_all_necessary_fields', + ), + ], +) +@pytest.mark.parametrize('channel_id', [None, 1234567890]) +async def test_envent_type_validation_on_creation( + event_type: EventEntityType, + model_fields: dict[str, Any], + channel_id: int | None, +) -> None: + """Test entity type validation on creation.""" + entity_metada: dict[str, Any] | None = model_fields.get('entity_metadata') + # fmt: off + has_fields = bool( + entity_metada + and model_fields.get('scheduled_end_time') + and entity_metada.get('location'), + ) + # fmt: on + + err_context = contextlib.nullcontext() + + if event_type is EventEntityType.EXTERNAL: + if not has_fields: + err_context = pytest.raises( + ValueError, + match='EXTERNAL type requires the fields', + ) + + elif not channel_id: + err_context = pytest.raises(ValueError, match='`channel_id` must be set if') + + with err_context: + CreateScheduledEventRequest( + name='Test Event', + privacy_level=EventPrivacyLevel.GUILD_ONLY, + scheduled_start_time=datetime.datetime.now(datetime.UTC), + entity_type=event_type, + channel_id=channel_id, + **model_fields, + ) + + +@pytest.mark.parametrize( + 'event_type', + [pytest.param(event_type, id=event_type.name) for event_type in EventEntityType] + [None], +) +@pytest.mark.parametrize( + 'model_fields', + [ + pytest.param( + {}, + id='without_fields', + ), + pytest.param( + { + 'entity_metadata': {'location': 'https://example.com'}, + }, + id='with_entity_metadata', + ), + pytest.param( + { + 'scheduled_end_time': datetime.datetime.now(datetime.UTC), + }, + id='with_end_time', + ), + pytest.param( + { + 'entity_metadata': {'location': 'https://example.com'}, + 'scheduled_end_time': datetime.datetime.now(datetime.UTC), + }, + id='with_all_necessary_fields', + ), + ], +) +@pytest.mark.parametrize('channel_id', [None, 1234567890]) +async def test_envent_type_validation_on_updating( + event_type: EventEntityType | None, + model_fields: dict[str, Any], + channel_id: int | None, +) -> None: + """Test entity type validation on updating. + + It differs from the creation test because event type can be None for updating. + """ + entity_metada: dict[str, Any] | None = model_fields.get('entity_metadata') + # fmt: off + has_fields = bool( + entity_metada + and model_fields.get('scheduled_end_time') + and entity_metada.get('location'), + ) + # fmt: on + + err_context = contextlib.nullcontext() + + if event_type is EventEntityType.EXTERNAL: + if not has_fields: + err_context = pytest.raises( + ValueError, + match='EXTERNAL type requires the fields', + ) + + elif event_type and not channel_id: + err_context = pytest.raises(ValueError, match='`channel_id` must be set if') + + with err_context: + UpdateScheduledEventRequest( + entity_type=event_type, + channel_id=channel_id, + **model_fields, + ) diff --git a/tests/integration/client/test_stages_instances.py b/tests/integration/client/test_stages_instances.py index cee34a0..ede76db 100644 --- a/tests/integration/client/test_stages_instances.py +++ b/tests/integration/client/test_stages_instances.py @@ -1,34 +1,41 @@ +from collections.abc import AsyncGenerator + +import pytest + +from asyncord.client.channels.models.responses import ChannelResponse from asyncord.client.stage_instances.models.requests import CreateStageInstanceRequest, UpdateStageInstanceRequest +from asyncord.client.stage_instances.models.responses import StageInstanceResponse from asyncord.client.stage_instances.resources import StageInstancesResource -from tests.conftest import IntegrationTestData -async def test_stage_instance_cycle( +@pytest.fixture() +async def stage( + stage_channel: ChannelResponse, stage_instances_res: StageInstancesResource, - integration_data: IntegrationTestData, -) -> None: - """Test creating, getting, updating, deleting stage instance.""" - created_stage = await stage_instances_res.create_stage_instance( +) -> AsyncGenerator[StageInstanceResponse, None]: + """Fixture for creating stage instance and deleting it after test.""" + stage = await stage_instances_res.create_stage_instance( CreateStageInstanceRequest( - channel_id=integration_data.stage_id, + channel_id=stage_channel.id, topic='Test topic', ), ) + yield stage + await stage_instances_res.delete(stage_channel.id) - requested_stage = await stage_instances_res.get_stage_instance( - integration_data.stage_id, - ) - assert created_stage.id == requested_stage.id +async def test_stage_instance_lifecycle( + stage_channel: ChannelResponse, + stage: StageInstanceResponse, + stage_instances_res: StageInstancesResource, +) -> None: + """Test full lifecycle of stage instance.""" + requested_stage = await stage_instances_res.get(stage_channel.id) - updated_stage = await stage_instances_res.update_stage_instance( - integration_data.stage_id, - UpdateStageInstanceRequest( - topic='Updated topic', - ), - ) - assert updated_stage.topic != created_stage.topic + assert stage.id == requested_stage.id - await stage_instances_res.delete_stage_instance( - integration_data.stage_id, + updated_stage = await stage_instances_res.update( + stage_channel.id, + UpdateStageInstanceRequest(topic='Updated topic'), ) + assert updated_stage.topic == 'Updated topic' diff --git a/tests/integration/client/test_stickers.py b/tests/integration/client/test_stickers.py index f25615e..aae1b4c 100644 --- a/tests/integration/client/test_stickers.py +++ b/tests/integration/client/test_stickers.py @@ -1,5 +1,9 @@ +from collections.abc import AsyncGenerator, Callable, Iterable from pathlib import Path +import pytest + +from asyncord.client.models.stickers import Sticker from asyncord.client.stickers.models.requests import ( CreateGuildStickerRequest, UpdateGuildStickerRequest, @@ -10,43 +14,113 @@ TEST_STICKER = Path('tests/data/test_sticker.png') -async def test_sticker_cycle( +@pytest.fixture() +async def sticker( stickers_res: StickersResource, integration_data: IntegrationTestData, -) -> None: - """Test create, get, update, delete sticker.""" - created_sticker = await stickers_res.create_guild_sticker( +) -> AsyncGenerator[Sticker, None]: + """Create a sticker and delete it after the test.""" + sticker = await stickers_res.create_guild_sticker( integration_data.guild_id, CreateGuildStickerRequest( - name='test_sticker', - description='test_sticker_description', - tags='test_sticker_tags', + name='TestSticker', + description='Test sticker description', + tags='test sticker tags', image_data=TEST_STICKER, ), ) - try: - sticker = await stickers_res.get_guild_sticker( - integration_data.guild_id, - created_sticker.id, - ) - - assert sticker.id == created_sticker.id - assert sticker.name == 'test_sticker' - - updated_sticker = await stickers_res.update_guild_sticker( - integration_data.guild_id, - created_sticker.id, - UpdateGuildStickerRequest( - name='test_sticker_updated', - description='test_sticker_description_updated', - tags='test_sticker_tags_updated', - ), - ) - - assert updated_sticker.name != created_sticker.name - - finally: - await stickers_res.delete_guild_sticker( - integration_data.guild_id, - created_sticker.id, - ) + yield sticker + await stickers_res.delete_guild_sticker( + integration_data.guild_id, + sticker.id, + ) + + +async def test_get_sticker(sticker: Sticker, stickers_res: StickersResource) -> None: + """Test getting a sticker.""" + assert await stickers_res.get_sticker(sticker.id) + + +async def test_get_sticker_pack_list(stickers_res: StickersResource) -> None: + """Test getting a list of sticker packs.""" + assert await stickers_res.get_sticker_pack_list() + + +async def test_get_guild_stickers_list( + stickers_res: StickersResource, + integration_data: IntegrationTestData, +) -> None: + """Test getting a list of stickers in a guild.""" + stickers = await stickers_res.get_guild_stickers_list(integration_data.guild_id) + assert isinstance(stickers, list) + + +async def test_get_guild_sticker( + sticker: Sticker, + stickers_res: StickersResource, + integration_data: IntegrationTestData, +) -> None: + """Test getting a sticker in a guild.""" + guild_sticker = await stickers_res.get_guild_sticker( + integration_data.guild_id, + sticker.id, + ) + assert guild_sticker.id == sticker.id + + +async def test_update_guild_sticker( + sticker: Sticker, + stickers_res: StickersResource, + integration_data: IntegrationTestData, +) -> None: + """Test updating a sticker in a guild.""" + updated_sticker = await stickers_res.update_guild_sticker( + integration_data.guild_id, + sticker.id, + UpdateGuildStickerRequest( + name='UpdatedSticker', + description='Updated sticker description', + tags='updated sticker tags', + ), + ) + assert updated_sticker.name == 'UpdatedSticker' + assert updated_sticker.description == 'Updated sticker description' + assert updated_sticker.tags == 'updated sticker tags' + + +@pytest.mark.parametrize( + 'iterable_type', + [ + set, + list, + tuple, + iter, + ], +) +def test_tags_can_be_iterable(iterable_type: Callable[[Iterable[str]], None]) -> None: + """Test tags can be an iterable.""" + tags = iterable_type(['updated', 'sticker', 'tags']) + model = UpdateGuildStickerRequest(tags=tags) + assert model.tags == {'updated', 'sticker', 'tags'} + + +@pytest.mark.parametrize( + 'tags', + [ + 'updated, sticker, tags', + 'updated , sticker, tags', + ' updated, sticker , \ntags ', + ], +) +def test_tags_can_be_str(tags: str) -> None: + """Test tags can be a string with comma-separated values.""" + model = UpdateGuildStickerRequest( + tags=tags, + ) + assert model.tags == {'updated', 'sticker', 'tags'} + + +def test_tags_too_long() -> None: + """Test tags cannot be too long.""" + with pytest.raises(ValueError, match='length must be less than'): + UpdateGuildStickerRequest(tags='a' * 201) diff --git a/tests/integration/client/test_users.py b/tests/integration/client/test_users.py index 9eb9de2..047432f 100644 --- a/tests/integration/client/test_users.py +++ b/tests/integration/client/test_users.py @@ -1,7 +1,9 @@ from pathlib import Path +from typing import Literal import pytest +from asyncord.client.http import errors from asyncord.client.users.models.requests import UpdateUserRequest from asyncord.client.users.resources import UserResource from tests.conftest import IntegrationTestData @@ -37,14 +39,44 @@ async def test_update_current_user(users_res: UserResource) -> None: assert await users_res.update_user(user_data) -async def test_get_guilds(users_res: UserResource) -> None: +@pytest.mark.parametrize('after', [None, 'from_config']) +@pytest.mark.parametrize('before', [None, 'from_config']) +@pytest.mark.parametrize('limit', [None, 1]) +async def test_get_guilds( + limit: int | None, + before: Literal['from_config'] | int | None, + after: Literal['from_config'] | int | None, + users_res: UserResource, + integration_data: IntegrationTestData, +) -> None: """Test getting the current user's guilds.""" - guilds = await users_res.get_guilds() - assert len(guilds) - assert guilds[0].id + if before == 'from_config': + before = int(integration_data.guild_id) + + if after == 'from_config': + after = int(integration_data.guild_id) + + guilds = await users_res.get_guilds( + limit=limit, + before=before, + after=after, + ) + assert isinstance(guilds, list) + + +async def test_create_group_dm( + users_res: UserResource, + integration_data: IntegrationTestData, +) -> None: + """Test creating a group DM. + + Just a smoke test to test some models. + """ + with pytest.raises(errors.ClientError, match='CHANNEL_RECIPIENT_REQUIRED'): + await users_res.create_group_dm([integration_data.user_id]) -@pytest.mark.skip(reason='Skip the test because bots cannot use this endpoint') +@pytest.mark.skip(reason='Bot cannot use this endpoint') async def test_get_current_user_guild_member( users_res: UserResource, integration_data: IntegrationTestData, diff --git a/tests/integration/client/test_webhooks.py b/tests/integration/client/test_webhooks.py index 1175707..806cf7b 100644 --- a/tests/integration/client/test_webhooks.py +++ b/tests/integration/client/test_webhooks.py @@ -1,107 +1,165 @@ -"""Contains webhook tests.""" +import pytest -from asyncord.client.messages.models.requests.embeds import ( - Embed, -) +from asyncord.client.http.errors import NotFoundError +from asyncord.client.messages.models.responses.messages import MessageResponse +from asyncord.client.rest import RestClient from asyncord.client.webhooks.models.requests import ( - CreateWebhookRequest, ExecuteWebhookRequest, UpdateWebhookMessageRequest, UpdateWebhookRequest, ) +from asyncord.client.webhooks.models.responces import WebhookResponse from asyncord.client.webhooks.resources import WebhooksResource from tests.conftest import IntegrationTestData -# FIXME: Add more tests. + +@pytest.fixture(scope='module') +async def module_webhook_res(module_client: RestClient) -> WebhooksResource: + """Get webhooks resource for the module.""" + return module_client.webhooks -async def test_webhook_cycle( +@pytest.fixture(scope='module') +async def webhook_message( + webhook: WebhookResponse, + module_webhook_res: WebhooksResource, +) -> MessageResponse: + """Create a webhook message.""" + return await module_webhook_res.execute_webhook( + webhook_id=webhook.id, + token=webhook.token, # type: ignore + execution_data=ExecuteWebhookRequest(content='Hello World'), + ) + + +async def test_get_channel_webhooks( webhooks_res: WebhooksResource, integration_data: IntegrationTestData, ) -> None: - """Test. - - Create webhook. - Get channel webhooks. - Get guild webhooks. + """Test getting all webhooks in a channel.""" + webhooks = await webhooks_res.get_channel_webhooks(integration_data.channel_id) + assert isinstance(webhooks, list) - Update webhook. - Execute webhook(send message). +async def test_get_guild_webhooks( + webhooks_res: WebhooksResource, + integration_data: IntegrationTestData, +) -> None: + """Test getting all webhooks in a guild.""" + webhooks = await webhooks_res.get_guild_webhooks(integration_data.guild_id) + assert isinstance(webhooks, list) - Get webhook message. - Update webhook message. - Delete webhook message. - Delete webhook. - """ - webhook = await webhooks_res.create_webhook( - integration_data.channel_id, - create_data=CreateWebhookRequest( - name='Test Webhook', - avatar=None, - ), +@pytest.mark.parametrize('with_token', [False, True]) +async def test_get_webhook( + with_token: bool, + webhooks_res: WebhooksResource, + webhook: WebhookResponse, +) -> None: + """Test getting a webhook by its id.""" + if with_token: + token = webhook.token # type: ignore + else: + token = None + new_webhook_obj = await webhooks_res.get_webhook(webhook.id, token=token) + assert new_webhook_obj.id == webhook.id + + +async def test_update_with_token_and_channel_forbidden(webhooks_res: WebhooksResource) -> None: + """Test updating a webhook with token and channel_id.""" + with pytest.raises(ValueError, match='`channel_id` cannot be set'): + await webhooks_res.update_webhook( + webhook_id='webhook_id', + update_data=UpdateWebhookRequest(name='Updated Webhook', channel_id=123), + token='token', # noqa: S106 + ) + + +async def test_update_webhook( + webhook: WebhookResponse, + module_webhook_res: WebhooksResource, +) -> None: + """Test updating a webhook.""" + update_data = UpdateWebhookRequest(name='Updated Webhook') + webhook = await module_webhook_res.update_webhook( + webhook_id=webhook.id, + update_data=update_data, ) - assert webhook - assert webhook.token + assert webhook.name == 'Updated Webhook' - webhook_channel_resp = await webhooks_res.get_channel_webhooks( - integration_data.channel_id, - ) - webhook_guild_resp = await webhooks_res.get_guild_webhooks( - integration_data.guild_id, +@pytest.mark.parametrize( + 'wait', + [False, True], +) +async def test_execute_webhook( + wait: bool, + webhook: WebhookResponse, + module_webhook_res: WebhooksResource, +) -> None: + """Test executing a webhook.""" + message = await module_webhook_res.execute_webhook( + webhook_id=webhook.id, + token=webhook.token, # type: ignore + execution_data=ExecuteWebhookRequest(content='Hello World'), + wait=wait, ) + assert bool(message) is wait + + if not message: + return + + assert message.content == 'Hello World' - updated_webhook = await webhooks_res.update_webhook( - webhook.id, - update_data=UpdateWebhookRequest( - name='Updated Test Webhook', - avatar=None, - ), - ) - created_message = await webhooks_res.execute_webhook( +async def test_get_webhook_message( + webhook_message: MessageResponse, + webhook: WebhookResponse, + module_webhook_res: WebhooksResource, +) -> None: + """Test getting a webhook message.""" + token: str = webhook.token # type: ignore + retrieved_message = await module_webhook_res.get_webhook_message( webhook_id=webhook.id, - webhook_token=webhook.token, - execute_data=ExecuteWebhookRequest( - embeds=[ - Embed( - title='Webhook Test', - description='This is a test webhook', - ), - ], - ), - wait=True, + token=token, + message_id=webhook_message.id, ) + assert webhook_message.id == retrieved_message.id - assert created_message - message = await webhooks_res.get_webhook_message( - webhook.id, - webhook.token, - created_message.id, +async def test_update_webhook_message( + webhook_message: MessageResponse, + webhook: WebhookResponse, + module_webhook_res: WebhooksResource, +) -> None: + """Test updating a webhook message.""" + token: str = webhook.token # type: ignore + updated_message = await module_webhook_res.update_webhook_message( + webhook_id=webhook.id, + token=token, + message_id=webhook_message.id, + update_data=UpdateWebhookMessageRequest(content='updated'), ) + assert updated_message.content == 'updated' - updated_message = await webhooks_res.update_webhook_message( - webhook.id, - webhook.token, - created_message.id, - UpdateWebhookMessageRequest( - embeds=[ - Embed( - title='Updated Webhook Test', - description='This is an updated test webhook', - ), - ], - ), - ) - await webhooks_res.delete_webhook_message(webhook.id, webhook.token, created_message.id) - await webhooks_res.delete_webhook(webhook.id) +async def test_delete_webhook_message( + webhook_message: MessageResponse, + webhook: WebhookResponse, + module_webhook_res: WebhooksResource, +) -> None: + """Test deleting a webhook message.""" + token: str = webhook.token # type: ignore - assert message - assert updated_webhook.name != webhook.name - assert updated_message - assert webhook_channel_resp is not None - assert webhook_guild_resp is not None + await module_webhook_res.delete_webhook_message( + webhook_id=webhook.id, + token=token, + message_id=webhook_message.id, + ) + # Assert deletion by trying to fetch the deleted message + with pytest.raises(NotFoundError): + await module_webhook_res.get_webhook_message( + webhook_id=webhook.id, + token=token, + message_id=webhook_message.id, + ) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 1eded9d..a5132e8 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -4,6 +4,12 @@ from asyncord.client.applications.resources import ApplicationResource from asyncord.client.bans.resources import BanResource +from asyncord.client.channels.models.requests.creation import ( + CreateAnoncementChannelRequest, + CreateStageChannelRequest, + CreateTextChannelRequest, +) +from asyncord.client.channels.models.responses import ChannelResponse from asyncord.client.channels.resources import ChannelResource from asyncord.client.commands.resources import CommandResource from asyncord.client.emojis.resources import EmojiResource @@ -16,7 +22,6 @@ from asyncord.client.messages.models.responses.messages import MessageResponse from asyncord.client.messages.resources import MessageResource from asyncord.client.polls.resources import PollsResource -from asyncord.client.reactions.resources import ReactionResource from asyncord.client.rest import RestClient from asyncord.client.roles.resources import RoleResource from asyncord.client.scheduled_events.resources import ScheduledEventsResource @@ -27,6 +32,8 @@ from asyncord.client.threads.models.responses import ThreadResponse from asyncord.client.threads.resources import ThreadResource from asyncord.client.users.resources import UserResource +from asyncord.client.webhooks.models.requests import CreateWebhookRequest +from asyncord.client.webhooks.models.responces import WebhookResponse from asyncord.client.webhooks.resources import WebhooksResource from tests.conftest import IntegrationTestData @@ -163,15 +170,6 @@ async def events_res( return guilds_res.events(integration_data.guild_id) -@pytest.fixture() -async def reactions_res( - message: MessageResponse, - messages_res: MessageResource, -) -> ReactionResource: - """Get reactions resource for the message.""" - return messages_res.reactions(message.id) - - @pytest.fixture() async def commands_res( client: RestClient, @@ -220,3 +218,77 @@ async def thread(thread_res: ThreadResource) -> AsyncGenerator[ThreadResponse, N ) yield thread await thread_res.delete(thread.id) + + +@pytest.fixture() +async def channel( + channel_res: ChannelResource, + integration_data: IntegrationTestData, +) -> AsyncGenerator[ChannelResponse, None]: + """Fixture for creating channel and deleting it after test.""" + channel = await channel_res.create_channel( + guild_id=integration_data.guild_id, + channel_data=CreateTextChannelRequest(name='test-channel'), + ) + + yield channel + await channel_res.delete(channel.id) + + +@pytest.fixture() +async def stage_channel( + channel_res: ChannelResource, + integration_data: IntegrationTestData, +) -> AsyncGenerator[ChannelResponse, None]: + """Fixture for creating channel and deleting it after test.""" + channel = await channel_res.create_channel( + integration_data.guild_id, + CreateStageChannelRequest(name='Test stage channel'), + ) + yield channel + await channel_res.delete(channel.id) + + +@pytest.fixture() +async def announcement_channel( + channel_res: ChannelResource, + integration_data: IntegrationTestData, +) -> AsyncGenerator[ChannelResponse, None]: + """Get the announcement channel.""" + channel = await channel_res.create_channel( + guild_id=integration_data.guild_id, + channel_data=CreateAnoncementChannelRequest(name='test-announcement'), + ) + + yield channel + + await channel_res.delete(channel.id) + + +@pytest.fixture(scope='module') +async def module_client(token: str) -> RestClient: + """Get new rest client.""" + return RestClient( + token, + ratelimit_strategy=BackoffRateLimitStrategy( + max_retries=10, + min_wait_time=2, + max_wait_time=60, + ), + ) + + +@pytest.fixture(scope='module') +async def webhook( + module_client: RestClient, + integration_data: IntegrationTestData, +) -> AsyncGenerator[WebhookResponse, None]: + """Get new webhook for module.""" + webhook = await module_client.webhooks.create_webhook( + integration_data.channel_id, + CreateWebhookRequest(name='Test Webhook'), + ) + + yield webhook + + await module_client.webhooks.delete_webhook(webhook.id) diff --git a/tests/test_base64_image.py b/tests/test_base64_image.py index f700071..0065528 100644 --- a/tests/test_base64_image.py +++ b/tests/test_base64_image.py @@ -27,24 +27,73 @@ def test_build_image(img_name: Path) -> None: assert len(enc_image.image_data) > 100 -@pytest.mark.parametrize('img_name', [Path(f'tests/data/{file_name}') for file_name in TEST_FILE_NAMES]) -def test_base64_images_in_models(img_name: Path) -> None: - """Test that images can be used in pydantic models.""" +@pytest.mark.parametrize('image_path', [Path(f'tests/data/{file_name}') for file_name in TEST_FILE_NAMES]) +def test_images_in_models_from_data(image_path: Path) -> None: + """Test image convertation_to_base64_image from binary data.""" class TestModel(BaseModel): image: Base64ImageInputType - with Path(img_name).open('rb') as file: + with image_path.open('rb') as file: image_data = file.read() model = TestModel(image=image_data) assert isinstance(model.image, Base64Image) +@pytest.mark.parametrize('image_path', [Path(f'tests/data/{file_name}') for file_name in TEST_FILE_NAMES]) +def test_images_in_models_from_pathes(image_path: Path) -> None: + """Test image convertation_to_base64_image from pathes.""" + + class TestModel(BaseModel): + image: Base64ImageInputType + + model = TestModel(image=image_path) + assert isinstance(model.image, Base64Image) + + +def test_base64_image_in_models_from_base64_image() -> None: + """Test that base64 image can be passed to the model.""" + + class TestModel(BaseModel): + image: Base64ImageInputType + + image = Base64Image('data:image/png;base64,123') + model = TestModel(image=image) + assert model.image == image + + def test_base64_image_error_in_models() -> None: - """Test that images can be used in pydantic models.""" + """Test that error is raised if image is not base64 encoded.""" class TestModel(BaseModel): image: Base64ImageInputType with pytest.raises(ValueError, match='Icon must be a base64 encoded image'): TestModel(image='not a base64 image') + + +def test_image_equal() -> None: + """Test that images can be compared.""" + image1 = Base64Image('data:image/png;base64,123') + image2 = Base64Image('data:image/png;base64,123') + image3 = Base64Image('data:image/png;base64,124') + + assert image1 == image2 + assert image1 != image3 + assert image2 != image3 + + +def test_image_hash() -> None: + """Test that images can be hashed.""" + image1 = Base64Image('data:image/png;base64,123') + image2 = Base64Image('data:image/png;base64,123') + image3 = Base64Image('data:image/png;base64,124') + + assert hash(image1) == hash(image2) + assert hash(image1) != hash(image3) + + +def test_image_to_str() -> None: + """Test that images can be converted to string.""" + image = Base64Image('data:image/png;base64,123') + assert str(image) == 'data:image/png;base64,123' diff --git a/tests/test_client_hub.py b/tests/test_client_hub.py new file mode 100644 index 0000000..d683895 --- /dev/null +++ b/tests/test_client_hub.py @@ -0,0 +1,69 @@ +import logging +from unittest.mock import AsyncMock, Mock + +import pytest +from pytest_mock import MockerFixture + +from asyncord.client_hub import ClientHub + + +async def test_create_hub_without_session() -> None: + """Test creating a hub without a session.""" + hub = ClientHub() + assert not hub._is_outer_session + assert hub.session + + +async def test_setup_single_client_group(mocker: MockerFixture) -> None: + """Test setup_single_client_group method.""" + mock_gather = mocker.patch('asyncio.gather', new=mocker.async_stub('gather')) + mock_client_group_class = mocker.patch('asyncord.client_hub.ClientGroup') + + async with ClientHub.setup_single_client_group(auth='token', session=Mock()) as client_group: + mock_client_group_class.assert_called_once() + mock_client_group = mock_client_group_class.return_value + assert client_group is mock_client_group + + mock_gather.assert_called_once() + mock_client_group.connect.assert_called_once() + + +async def test_setup_with_dispatcher(mocker: MockerFixture, caplog: pytest.LogCaptureFixture) -> None: + """Test setup method with a dispatcher.""" + mock_gather = mocker.patch('asyncio.gather', new=mocker.async_stub('gather')) + hub_context = ClientHub.setup_single_client_group(auth='token', session=Mock(), dispatcher=Mock()) + mocker.patch('asyncord.client_hub.ClientGroup') + with caplog.at_level(logging.WARNING): + async with hub_context: + pass + + mock_gather.assert_called_once() + assert 'dispatcher is passed' in caplog.text + + +@pytest.mark.skip(reason='Not implemented yet. https://github.com/pytest-dev/pytest/discussions/12540') +async def test_start_handles_exceptions(mocker: MockerFixture) -> None: + """Test start method handles exceptions.""" + # Setup ClientHub instance + hub = ClientHub() + hub.heartbeat_factory = AsyncMock() + + client1 = AsyncMock() + # Simulate KeyboardInterrupt during asyncio.gather + client1.connect.side_effect = KeyboardInterrupt + + hub.client_groups = { + 'client1': client1, + 'client2': AsyncMock(), + } + + # Mock logger + mock_logger = mocker.patch('asyncord.client_hub.logger') + + await hub.start() + + # Assertions + mock_logger.info.assert_any_call('Shutting down...') + hub.heartbeat_factory.start.assert_called_once() + for client in hub.client_groups.values(): + client.connect.assert_called() # type: ignore diff --git a/tests/test_color.py b/tests/test_color.py index 153d392..f60826e 100644 --- a/tests/test_color.py +++ b/tests/test_color.py @@ -13,6 +13,11 @@ def color() -> Color: return Color(0xAABBCC) +def test_rgb_repr() -> None: + """Test the string representation of an RGB object.""" + assert repr(RGB(170, 187, 204)) == 'RGB(170, 187, 204)' + + def test_create_color_from_constructor(color: Color) -> None: """Test creating a color from the constructor.""" assert Color(0xAABBCC).value == DECIMAL_VALUE @@ -67,6 +72,11 @@ def test_color_to_rgb(color: Color) -> None: assert color.to_rgb() == RGB(170, 187, 204) +def test_color_repr(color: Color) -> None: + """Test the string representation of a color.""" + assert repr(color) == f'Color({hex(DECIMAL_VALUE)})' + + @pytest.mark.parametrize( 'color_value', [ @@ -77,6 +87,7 @@ def test_color_to_rgb(color: Color) -> None: '#aabbcc', (170, 187, 204), RGB(170, 187, 204), + Color(DECIMAL_VALUE), ], ids=[ 'int', @@ -86,6 +97,7 @@ def test_color_to_rgb(color: Color) -> None: 'web hex str', 'rgb tuple', 'RGB class', + 'Color class', ], ) def test_color_in_models( diff --git a/tests/test_heartbeat.py b/tests/test_heartbeat.py deleted file mode 100644 index d26d4a8..0000000 --- a/tests/test_heartbeat.py +++ /dev/null @@ -1,76 +0,0 @@ -import asyncio -import threading -from unittest.mock import Mock - -from asyncord.gateway.client.client import ConnectionData, GatewayClient -from asyncord.gateway.client.heartbeat import Heartbeat, HeartbeatFactory - - -def test_heartbeat_factory_initialization() -> None: - """Test that the heartbeat factory initializes correctly.""" - factory = HeartbeatFactory() - assert isinstance(factory.loop, asyncio.AbstractEventLoop) - assert isinstance(factory.thread, threading.Thread) - - -def test_heartbeat_factory_create() -> None: - """Test that the heartbeat factory creates a heartbeat correctly.""" - factory = HeartbeatFactory() - client = Mock(spec=GatewayClient) - conn_data = Mock(spec=ConnectionData) - heartbeat = factory.create(client, conn_data) - assert isinstance(heartbeat, Heartbeat) - assert heartbeat.client == client - assert heartbeat.conn_data == conn_data - assert heartbeat._loop == factory.loop - - -def test_heartbeat_factory_start() -> None: - """Test that the heartbeat factory starts correctly.""" - factory = HeartbeatFactory() - factory.start() - assert factory.is_running - assert factory.thread.is_alive() - - -def test_heartbeat_factory_stop() -> None: - """Test that the heartbeat factory stops correctly.""" - factory = HeartbeatFactory() - factory.start() - factory.stop() - assert not factory.is_running - assert not factory.thread.is_alive() - - -def test_multiple_heartbeats_same_thread() -> None: - """Test that multiple heartbeats share the same thread.""" - factory = HeartbeatFactory() - session = Mock(spec=GatewayClient) - conn_data = Mock(spec=ConnectionData) - heartbeat1 = factory.create(session, conn_data) - heartbeat2 = factory.create(session, conn_data) - assert heartbeat1._loop == heartbeat2._loop - - -def test_multiple_heartbeats() -> None: - """Test that multiple heartbeats are different.""" - factory = HeartbeatFactory() - session = Mock(spec=GatewayClient) - conn_data = Mock(spec=ConnectionData) - heartbeat1 = factory.create(session, conn_data) - heartbeat2 = factory.create(session, conn_data) - assert heartbeat1 != heartbeat2 - - -def test_heartbeat_continues_after_one_stops() -> None: - """Test that a heartbeat continues after another stops.""" - factory = HeartbeatFactory() - session = Mock(spec=GatewayClient) - conn_data = Mock(spec=ConnectionData) - heartbeat1 = factory.create(session, conn_data) - heartbeat2 = factory.create(session, conn_data) - heartbeat1.run(10) - heartbeat2.run(10) - heartbeat1.stop() - assert not heartbeat1.is_running - assert heartbeat2.is_running diff --git a/tests/test_http_client.py b/tests/test_http_client.py index fbcf028..cc93804 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -2,7 +2,6 @@ from http import HTTPStatus from unittest.mock import ANY, AsyncMock, Mock -import aiohttp import pytest from multidict import CIMultiDict from pytest_mock import MockFixture @@ -54,15 +53,15 @@ ), ], ) -async def test_http_client_general_methods( - session: aiohttp.ClientSession | None, +async def test_http_methods( + session: Mock | None, method: str, url: str, headers: dict[str, str], payload: dict[str, str] | None, mocker: MockFixture, ) -> None: - """Complete test for the HTTP client general methods. + """Test calling HTTP methods. This test looks complex, but it's actually quite simple. It's not super necessary to test every single method of the HTTP client, but it's good to have a test that @@ -89,10 +88,10 @@ async def test_http_client_general_methods( method=method, url=url, headers=headers, - data=ANY, # Check data later. + data=ANY, # Check data later ) - # Some methods don't have payloads, so we need to check if the payload is correct. + # Some methods don't have payloads, so we need to check if the payload is correct if payload is not None: data_arg = request_mock.call_args.kwargs['data'] assert json.loads(data_arg._value) == payload @@ -129,6 +128,30 @@ async def test_extract_body_not_json() -> None: assert result == {} +async def test_add_middleware() -> None: + """Test adding middleware to the HTTP client.""" + client = HttpClient() + middleware = Mock() + assert middleware not in client.middlewares + + client.add_middleware(middleware) + assert middleware in client.middlewares + + +async def test_skip_middleware() -> None: + """Test skipping middleware in a request.""" + client = HttpClient(request_handler=AsyncMock()) + middleware = AsyncMock() + client.system_middlewares = [] + client.middlewares = [middleware] + + await client.request(Mock(), skip_middleware=True) + middleware.assert_not_called() + + await client.request(Mock()) + middleware.assert_called_once() + + async def test_apply_middleware(mocker: MockFixture) -> None: """Test applying middleware to a request. @@ -167,3 +190,9 @@ async def wrapper(request: Request, http_client: HttpClient, next_call: NextCall raw_request_mock.assert_called_once_with(request_mock) assert middleware1.call_order < middleware2.call_order < system_middleware.call_order + + +async def test_init_with_both_request_handler_and_session() -> None: + """Test initializing the HTTP client with both a request handler and a session.""" + with pytest.warns(UserWarning, match=r'.* should not provide both .*'): + HttpClient(request_handler=Mock(), session=Mock()) diff --git a/tests/test_hub.py b/tests/test_hub.py deleted file mode 100644 index 0ea580a..0000000 --- a/tests/test_hub.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -from pytest_mock import MockerFixture - -from asyncord.client_hub import ClientHub - - -async def test_setup_single_client_group(mocker: MockerFixture) -> None: - """Test setup_single_client_group method.""" - mock_gather = mocker.patch('asyncio.gather', new=mocker.async_stub('gather')) - mock_client_group_class = mocker.patch('asyncord.client_hub.ClientGroup') - - async with ClientHub.setup_single_client_group(auth='token') as client_group: - mock_client_group_class.assert_called_once() - mock_client_group = mock_client_group_class.return_value - assert client_group is mock_client_group - - mock_gather.assert_called_once() - mock_client_group.connect.assert_called_once() - - -async def test_create_client_group(mocker: MockerFixture) -> None: - """Test creation of a client group.""" - mock_client_group_class = mocker.patch('asyncord.client_hub.ClientGroup') - hub = ClientHub() - hub.create_client_group('group_name', 'token') - - mock_client_group_class.assert_called_once() - - with pytest.raises(ValueError, match=r'Client group group_name already exists'): - hub.create_client_group('group_name', auth='token') diff --git a/tests/test_strflag.py b/tests/test_strflag.py new file mode 100644 index 0000000..dc3989e --- /dev/null +++ b/tests/test_strflag.py @@ -0,0 +1,98 @@ +import pytest + +from asyncord.typedefs import StrFlag + + +class _TestStrFlag(StrFlag): + """Test string flag enum.""" + + TEST1 = 'test1' + TEST2 = 'test2' + TEST3 = 'test3' + + +@pytest.mark.parametrize( + 'flag', + [ + pytest.param('no_flag', id='unknown_flag'), + pytest.param({'no_flag'}, id='unknown_flag_set'), + pytest.param({'no_flag1', 'no_flag2'}, id='unknown_flag_set_multiple'), + pytest.param({'test1', 'no_flag'}, id='flag_set_with_existing_flag_and_unknown_flag'), + ], +) +def test_get_with_unknown_value(flag: str) -> None: + """Test getting the value of the scope with unknown value.""" + with pytest.raises(ValueError, match='has no member'): + _TestStrFlag(flag) + + +@pytest.mark.parametrize( + 'flag', + [ + pytest.param('test1', id='string_flag'), + pytest.param({'test1'}, id='flag_set'), + pytest.param({'test1', 'test2'}, id='flag_set_multiple'), + ], +) +def test_get_with_value(flag: str) -> None: + """Test getting the value of the scope.""" + assert _TestStrFlag(flag)._value_ == (flag if isinstance(flag, set) else {flag}) + + +def test_get_value() -> None: + """Test value of the scope.""" + assert _TestStrFlag['TEST1'] + with pytest.raises(KeyError): + _TestStrFlag['NO_FLAG'] + + +def test_str() -> None: + """Test string representation of the scope.""" + assert str(_TestStrFlag.TEST1) == 'test1' + assert str(_TestStrFlag({'test1', 'test2'})) == 'test1 test2' + + +def test_repr() -> None: + """Test representation of the scope.""" + assert repr(_TestStrFlag.TEST1) == '<_TestStrFlag.TEST1>' + assert repr(_TestStrFlag({'test1', 'test2'})) == '<_TestStrFlag.TEST1|TEST2>' + + +@pytest.mark.parametrize( + ('flag1', 'flag2', 'expected'), + [ + pytest.param( + _TestStrFlag.TEST1, + _TestStrFlag.TEST1, + _TestStrFlag.TEST1, + id='same_flag', + ), + pytest.param( + _TestStrFlag.TEST1, + _TestStrFlag.TEST2, + _TestStrFlag({'test1', 'test2'}), + id='different_flag', + ), + pytest.param( + _TestStrFlag({'test1', 'test2'}), + _TestStrFlag.TEST1, + _TestStrFlag({'test1', 'test2'}), + id='flag_set_and_same_flag', + ), + pytest.param( + _TestStrFlag({'test1', 'test2'}), + _TestStrFlag({'test2', 'test3'}), + _TestStrFlag({'test1', 'test2', 'test3'}), + id='flag_set_and_flag_set', + ), + pytest.param( + _TestStrFlag({'test1', 'test2'}), + _TestStrFlag({'test2', 'test3'}), + _TestStrFlag({'test1', 'test2', 'test3'}), + id='flag_set_and_flag_set_union', + ), + ], +) +def test_concatenation(flag1: _TestStrFlag, flag2: _TestStrFlag, expected: _TestStrFlag) -> None: + """Test the union of the two scopes.""" + assert (flag1 | flag2) == expected