Skip to content

Commit

Permalink
Rework of models and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vadim-su authored Jun 28, 2024
1 parent ac773cb commit 71c6c09
Show file tree
Hide file tree
Showing 105 changed files with 4,277 additions and 1,460 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
ASYNCORD_TEST_APP_ID: ${{ vars.ASYNCORD_TEST_APP_ID }}
ASYNCORD_TEST_ROLE_ID: ${{ vars.ASYNCORD_TEST_ROLE_ID }}
ASYNCORD_TEST_USER_TO_BAN: ${{ vars.ASYNCORD_TEST_USER_TO_BAN }}
ASYNCORD_TEST_STAGE_ID: ${{ vars.ASYNCORD_TEST_STAGE_ID }}
ASYNCORD_TEST_WORKERS: ${{ vars.ASYNCORD_TEST_WORKERS }}
ASYNCORD_TEST_ROLE_TO_PRUNE: ${{ vars.ASYNCORD_TEST_ROLE_TO_PRUNE }}

run: pdm run pytest -n $ASYNCORD_TEST_WORKERS
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.4
rev: v0.4.10
hooks:
# Run the linter.
- id: ruff
Expand Down
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
},
"editor.defaultFormatter": "charliermarsh.ruff"
},
}
}
15 changes: 9 additions & 6 deletions asyncord/base64_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def build(cls, image_data: bytes | str, image_type: str | None = None) -> Self:
encoded_image = base64.b64encode(image_data).decode()
return cls(f'data:{image_type};base64, {encoded_image}')

raise ValueError('Invalid image data type')

@classmethod
def from_file(cls, file_path: str | Path) -> Self:
"""Build Base64Image from file path.
Expand Down Expand Up @@ -92,13 +90,17 @@ def validate(cls, value: bytes | str | Self) -> Self:
Raises:
ValueError: If value is not
"""
if isinstance(value, cls):
return value

if isinstance(value, bytes | str):
return cls.build(value)

if isinstance(value, cls):
return value
if isinstance(value, Path):
return cls.from_file(value)

raise ValueError('Invalid value type')
# This should never happen because of the pydantic schema
raise ValueError('Invalid value type') # pragma: no cover

@classmethod
def __get_pydantic_core_schema__(
Expand All @@ -118,6 +120,7 @@ def __get_pydantic_core_schema__(
schema = core_schema.union_schema([
core_schema.bytes_schema(),
core_schema.str_schema(),
core_schema.is_instance_schema(Path),
core_schema.is_instance_schema(cls),
])

Expand All @@ -143,7 +146,7 @@ def __str__(self) -> str:
return self.image_data


Base64ImageInputType = Annotated[Base64Image | bytes | str, Base64Image]
Base64ImageInputType = Annotated[Base64Image | bytes | str | Path, Base64Image]
"""Base64Image input type for pydantic models.
Base64ImageInput must validate and convert other types to Base64Image.
Expand Down
128 changes: 49 additions & 79 deletions asyncord/client/applications/models/requests.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""This module contains models related to Discord applications."""

from enum import IntEnum
from typing import Annotated

from pydantic import BaseModel, Field, ValidationInfo, field_validator
from pydantic import BaseModel, Field, HttpUrl, field_validator

from asyncord.base64_image import Base64ImageInputType
from asyncord.client.applications.models.common import (
Expand All @@ -13,6 +14,22 @@
from asyncord.client.models.permissions import PermissionFlag
from asyncord.locale import LocaleInputType

__all__ = (
'ApplicationIntegrationType',
'ApplicationIntegrationTypeConfig',
'InstallParams',
'UpdateApplicationRequest',
'UpdateApplicationRoleConnectionMetadataRequest',
)

APPLICATION_ALLOWED_TYPES = (
ApplicationFlag.GATEWAY_PRESENCE_LIMITED
| ApplicationFlag.GATEWAY_GUILD_MEMBERS_LIMITED
| ApplicationFlag.GATEWAY_MESSAGE_CONTENT_LIMITED
)

"""Allowed application uodate flags."""


class InstallParams(BaseModel):
"""Application install parameters.
Expand Down Expand Up @@ -73,13 +90,7 @@ class UpdateApplicationRequest(BaseModel):
install_params: InstallParams | None = None
"""Settings for the app's default in-app authorization link, if enabled."""

integration_types_config: (
dict[
ApplicationIntegrationType,
ApplicationIntegrationTypeConfig,
]
| None
) = None
integration_types_config: dict[ApplicationIntegrationType, ApplicationIntegrationTypeConfig] | None = None
"""Default scopes and permissions for each supported installation context.
Value for each key is an integration type configuration object.
Expand All @@ -94,53 +105,40 @@ class UpdateApplicationRequest(BaseModel):
cover_image: Base64ImageInputType | None = None
"""Default rich presence invite cover image for the app."""

interactions_endpoint_url: str | None = None
interactions_endpoint_url: HttpUrl | None = None
"""Interactions endpoint URL for the app."""

tags: list[str] | None = None
tags: (
Annotated[
set[Annotated[str, Field(max_length=20)]],
Field(max_length=5),
]
| None
) = None
"""List of tags describing the content and functionality of the app.
Maximum of 5 tags.
Maximum of 20 characters per tag.
"""

@field_validator('tags')
@classmethod
def validate_tags(
cls,
tags: list[str] | None,
field_info: ValidationInfo,
) -> list[str] | None:
"""Ensures that the length of tags is less than or equal to 5.
And each tag is less than or equal to 20 characters.
"""
max_tags = 5
max_tag_length = 20

if tags is not None:
if len(tags) > max_tags:
raise ValueError('Maximum of 5 tags allowed.')
for tag in tags:
if len(tag) > max_tag_length:
raise ValueError('Maximum of 20 characters per tag allowed.')
return tags

@field_validator('flags')
@classmethod
def validate_flags(
cls,
tags: ApplicationFlag | None,
) -> ApplicationFlag:
"""Ensures that the flag is valid."""
if tags is not None:
if tags not in {
ApplicationFlag.GATEWAY_PRESENCE_LIMITED,
ApplicationFlag.GATEWAY_GUILD_MEMBERS_LIMITED,
ApplicationFlag.GATEWAY_MESSAGE_CONTENT_LIMITED,
}:
raise ValueError('Invalid flag.')
return tags
def validate_flags(cls, flags: ApplicationFlag | None) -> ApplicationFlag | None:
"""Ensures that the flag is one of the allowed types."""
if not flags:
return None

if (flags & APPLICATION_ALLOWED_TYPES) != flags:
err_msg = 'Invalid flag. Must be one of the following: ' + ', '.join(
[
'GATEWAY_PRESENCE_LIMITED',
'GATEWAY_GUILD_MEMBERS_LIMITED',
'GATEWAY_MESSAGE_CONTENT_LIMITED',
],
)
raise ValueError(err_msg)

return flags


class UpdateApplicationRoleConnectionMetadataRequest(BaseModel):
Expand All @@ -153,48 +151,20 @@ class UpdateApplicationRoleConnectionMetadataRequest(BaseModel):
type: ApplicationRoleConnectionMetadataType
"""Type of metadata value."""

key: str
key: Annotated[str, Field(min_length=1, max_length=50, pattern=r'^[a-z0-9_]{1,50}$')]
"""Dictionary key for the metadata field.
Must be a - z, 0 - 9, or _ characters;
Must be a - z, 0 - 9, or _ characters.
1 - 50 characters.
"""

name: str = Field(None, min_length=1, max_length=100)
"""Name of the metadata field.
(1 - 100 characters).
"""
name: Annotated[str, Field(min_length=1, max_length=100)]
"""Name of the metadata field."""

name_localizations: dict[LocaleInputType, str] | None = None
"""Translations of the name."""

description: str = Field(None, min_length=1, max_length=200)
"""Description of the metadata field.
(1 - 200 characters).
"""
description: Annotated[str, Field(min_length=1, max_length=200)]
"""Description of the metadata field."""
description_localizations: dict[LocaleInputType, str] | None = None
"""Translations of the description."""

@field_validator('key')
@classmethod
def validate_key(
cls,
key: str,
field_info: ValidationInfo,
) -> list[str] | None:
"""Ensures that the length of key is 1 - 50 characters.
And a - z, 0 - 9, or _ characters.
"""
max_length = 50
allowed_symbols = set('abcdefghijklmnopqrstuvwxyz0123456789_')

if not 1 <= len(key) <= max_length:
raise ValueError('Key length must be between 1 and 50 characters.')

if not set(key).issubset(allowed_symbols):
raise ValueError('Key must contain only a - z, 0 - 9, or _ characters.')

return key
16 changes: 15 additions & 1 deletion asyncord/client/applications/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
from asyncord.locale import LocaleInputType
from asyncord.snowflake import Snowflake

__all__ = (
'ApplicationCommandPermissionOut',
'ApplicationOut',
'ApplicationRoleConnectionMetadataOut',
'ApplicationUserOut',
'BotApplicationOut',
'GuildApplicationCommandPermissionsOut',
'InstallParamsOut',
'InviteCreateEventApplication',
'TeamMemberOut',
'TeamMemberUserOut',
'TeamOut',
)


class InstallParamsOut(BaseModel):
"""Application install parameters.
Expand Down Expand Up @@ -264,7 +278,7 @@ class ApplicationOut(BaseModel):
role_connections_verification_url: AnyHttpUrl | None = None
"""Application's default role connection verification url."""

tags: list[str] = Field(default_factory=list, max_length=5)
tags: set[str] = Field(default_factory=list, max_length=5)
"""Tags describing the content and functionality of the application.
Maximum of 5 tags.
Expand Down
24 changes: 16 additions & 8 deletions asyncord/client/applications/resources.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
"""This module contains the applications resource for the client."""

from asyncord.client.applications.models.requests import (
UpdateApplicationRequest,
UpdateApplicationRoleConnectionMetadataRequest,
)
from __future__ import annotations

from typing import TYPE_CHECKING

from asyncord.client.applications.models.responses import ApplicationOut, ApplicationRoleConnectionMetadataOut
from asyncord.client.commands.resources import CommandResource
from asyncord.client.resources import APIResource
from asyncord.snowflake import SnowflakeInputType
from asyncord.typedefs import list_model
from asyncord.typedefs import CURRENT_USER, list_model
from asyncord.urls import REST_API_URL

if TYPE_CHECKING:
from asyncord.client.applications.models.requests import (
UpdateApplicationRequest,
UpdateApplicationRoleConnectionMetadataRequest,
)
from asyncord.snowflake import SnowflakeInputType

__all__ = ('ApplicationResource',)


class ApplicationResource(APIResource):
"""Represents the applications resource for the client.
Expand Down Expand Up @@ -38,7 +46,7 @@ async def get_application(self) -> ApplicationOut:
Returns:
Application object.
"""
resp = await self._http_client.get(url=self.apps_url / '@me')
resp = await self._http_client.get(url=self.apps_url / CURRENT_USER)
return ApplicationOut.model_validate(resp.body)

async def update_application(
Expand All @@ -55,7 +63,7 @@ async def update_application(
"""
payload = application_data.model_dump(mode='json', exclude_unset=True)
resp = await self._http_client.patch(
url=self.apps_url / '@me',
url=self.apps_url / CURRENT_USER,
payload=payload,
)
return ApplicationOut.model_validate(resp.body)
Expand Down
6 changes: 6 additions & 0 deletions asyncord/client/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
from asyncord.client.users.models.responses import UserResponse
from asyncord.typedefs import StrFlag

__all__ = (
'AuthorizationInfoApplication',
'AuthorizationInfoResponse',
'OAuthScope',
)


@enum.unique
class OAuthScope(StrFlag):
Expand Down
9 changes: 7 additions & 2 deletions asyncord/client/auth/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
https://discord.com/developers/docs/topics/oauth2
"""

from __future__ import annotations

from asyncord.client.applications.models.responses import ApplicationOut
from asyncord.client.auth.models import AuthorizationInfoResponse
from asyncord.client.resources import APIResource
from asyncord.typedefs import CURRENT_USER
from asyncord.urls import REST_API_URL

__all__ = ('OAuthResource',)


class OAuthResource(APIResource):
"""Represents an OAuth2 resource.
Expand All @@ -31,7 +36,7 @@ async def get_current_application_info(self) -> ApplicationOut:
Reference:
https://discord.com/developers/docs/topics/oauth2#get-current-bot-application-information
"""
url = self.oauth_url / 'applications' / '@me'
url = self.oauth_url / 'applications' / CURRENT_USER
resp = await self._http_client.get(url=url)
return ApplicationOut.model_validate(resp.body)

Expand All @@ -41,6 +46,6 @@ async def get_current_authorization_info(self) -> AuthorizationInfoResponse:
Reference:
https://discord.com/developers/docs/topics/oauth2#get-current-authorization-information
"""
url = self.oauth_url / '@me'
url = self.oauth_url / CURRENT_USER
resp = await self._http_client.get(url=url)
return AuthorizationInfoResponse.model_validate(resp.body)
5 changes: 5 additions & 0 deletions asyncord/client/bans/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from asyncord.client.users.models.responses import UserResponse
from asyncord.snowflake import SnowflakeInputType

__all__ = (
'BanResponse',
'BulkBanResponse',
)


class BanResponse(BaseModel):
"""Ban object.
Expand Down
Loading

0 comments on commit 71c6c09

Please sign in to comment.