diff --git a/.github/workflows/pre-commit-updater.yml b/.github/workflows/pre-commit-updater.yml index ceee25164..5949a4b4f 100644 --- a/.github/workflows/pre-commit-updater.yml +++ b/.github/workflows/pre-commit-updater.yml @@ -8,7 +8,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v4.3.0 + uses: actions/setup-python@v4.6.0 with: python-version: '3.10' - name: Install pre-commit @@ -16,7 +16,7 @@ jobs: - name: Run pre-commit autoupdate run: pre-commit autoupdate - name: Create Pull Request - uses: peter-evans/create-pull-request@v5.0.0 + uses: peter-evans/create-pull-request@v5.0.1 with: token: ${{ secrets.GITHUB_TOKEN }} branch: update/pre-commit-autoupdate diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index 21f6c7d9e..a76af628b 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v3.3.0 - name: Set up Python 3.10 - uses: actions/setup-python@v4.5.0 + uses: actions/setup-python@v4.6.0 with: python-version: "3.10" - name: Install build @@ -21,7 +21,7 @@ jobs: run: >- python3 -m build - name: Publish release to PyPI - uses: pypa/gh-action-pypi-publish@v1.8.5 + uses: pypa/gh-action-pypi-publish@v1.8.6 with: user: __token__ password: ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 44f43e801..462a39393 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,7 @@ jobs: - name: Check out code from GitHub uses: actions/checkout@v3.3.0 - name: Set up Python - uses: actions/setup-python@v4.5.0 + uses: actions/setup-python@v4.6.0 with: python-version: "3.11" - name: Install dependencies @@ -41,7 +41,7 @@ jobs: - name: Check out code from GitHub uses: actions/checkout@v3.3.0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4.5.0 + uses: actions/setup-python@v4.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b72722517..19a7fd7dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - --branch=main - id: debug-statements - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.0.262' + rev: 'v0.0.265' hooks: - id: ruff - repo: https://github.com/psf/black diff --git a/music_assistant/common/models/config_entries.py b/music_assistant/common/models/config_entries.py index f16223d10..ab2bdd2d2 100644 --- a/music_assistant/common/models/config_entries.py +++ b/music_assistant/common/models/config_entries.py @@ -20,8 +20,8 @@ CONF_LOG_LEVEL, CONF_OUTPUT_CHANNELS, CONF_OUTPUT_CODEC, - CONF_VOLUME_NORMALISATION, - CONF_VOLUME_NORMALISATION_TARGET, + CONF_VOLUME_NORMALIZATION, + CONF_VOLUME_NORMALIZATION_TARGET, SECURE_STRING_SUBSTITUTE, ) @@ -116,12 +116,14 @@ def parse_value( if expected_type == int and isinstance(value, float): self.value = int(value) return self.value - if expected_type == int and isinstance(value, str) and value.isnumeric(): - self.value = int(value) - return self.value - if expected_type == float and isinstance(value, str) and value.isnumeric(): - self.value = float(value) - return self.value + for val_type in (int, float): + # convert int/float from string + if expected_type == val_type and isinstance(value, str): + try: + self.value = val_type(value) + return self.value + except ValueError: + pass if self.type in UI_ONLY: self.value = self.default_value return self.value @@ -327,8 +329,8 @@ class PlayerConfig(Config): advanced=True, ) -CONF_ENTRY_VOLUME_NORMALISATION = ConfigEntry( - key=CONF_VOLUME_NORMALISATION, +CONF_ENTRY_VOLUME_NORMALIZATION = ConfigEntry( + key=CONF_VOLUME_NORMALIZATION, type=ConfigEntryType.BOOLEAN, label="Enable volume normalization (EBU-R128 based)", default_value=True, @@ -336,14 +338,14 @@ class PlayerConfig(Config): "standard without affecting dynamic range", ) -CONF_ENTRY_VOLUME_NORMALISATION_TARGET = ConfigEntry( - key=CONF_VOLUME_NORMALISATION_TARGET, +CONF_ENTRY_VOLUME_NORMALIZATION_TARGET = ConfigEntry( + key=CONF_VOLUME_NORMALIZATION_TARGET, type=ConfigEntryType.INTEGER, range=(-30, 0), default_value=-14, - label="Target level for volume normalisation", + label="Target level for volume normalization", description="Adjust average (perceived) loudness to this target level, " "default is -14 LUFS", - depends_on=CONF_VOLUME_NORMALISATION, + depends_on=CONF_VOLUME_NORMALIZATION, advanced=True, ) @@ -407,9 +409,9 @@ class PlayerConfig(Config): ) DEFAULT_PLAYER_CONFIG_ENTRIES = ( - CONF_ENTRY_VOLUME_NORMALISATION, + CONF_ENTRY_VOLUME_NORMALIZATION, CONF_ENTRY_FLOW_MODE, - CONF_ENTRY_VOLUME_NORMALISATION_TARGET, + CONF_ENTRY_VOLUME_NORMALIZATION_TARGET, CONF_ENTRY_EQ_BASS, CONF_ENTRY_EQ_MID, CONF_ENTRY_EQ_TREBLE, diff --git a/music_assistant/common/models/media_items.py b/music_assistant/common/models/media_items.py index 83ebf87ed..207f87d0b 100755 --- a/music_assistant/common/models/media_items.py +++ b/music_assistant/common/models/media_items.py @@ -55,9 +55,13 @@ def quality(self) -> int: score += 1 return int(score) - def __hash__(self): + def __hash__(self) -> int: """Return custom hash.""" - return hash((self.provider_domain, self.item_id)) + return hash((self.provider_instance, self.item_id)) + + def __eq__(self, other: ProviderMapping) -> bool: + """Check equality of two items.""" + return self.provider_instance == other.provider_instance and self.item_id == other.item_id @dataclass(frozen=True) @@ -67,10 +71,14 @@ class MediaItemLink(DataClassDictMixin): type: LinkType url: str - def __hash__(self): + def __hash__(self) -> int: """Return custom hash.""" return hash(self.type) + def __eq__(self, other: MediaItemLink) -> bool: + """Check equality of two items.""" + return self.url == other.url + @dataclass(frozen=True) class MediaItemImage(DataClassDictMixin): @@ -82,9 +90,13 @@ class MediaItemImage(DataClassDictMixin): # if the path is just a plain (remotely accessible) URL, set it to 'url' provider: str = "url" - def __hash__(self): + def __hash__(self) -> int: """Return custom hash.""" - return hash(self.type.value, self.path) + return hash((self.type.value, self.path)) + + def __eq__(self, other: MediaItemImage) -> bool: + """Check equality of two items.""" + return self.__hash__() == other.__hash__() @dataclass(frozen=True) @@ -96,10 +108,14 @@ class MediaItemChapter(DataClassDictMixin): position_end: float | None = None title: str | None = None - def __hash__(self): + def __hash__(self) -> int: """Return custom hash.""" return hash(self.chapter_id) + def __eq__(self, other: MediaItemChapter) -> bool: + """Check equality of two items.""" + return self.chapter_id == other.chapter_id + @dataclass class MediaItemMetadata(DataClassDictMixin): @@ -255,10 +271,6 @@ def add_provider_mapping(self, prov_mapping: ProviderMapping) -> None: } self.provider_mappings.add(prov_mapping) - def __hash__(self): - """Return custom hash.""" - return hash((self.media_type, self.provider, self.item_id)) - @dataclass class ItemMapping(DataClassDictMixin): @@ -280,10 +292,6 @@ def from_item(cls, item: MediaItem): result.available = item.available return result - def __hash__(self): - """Return custom hash.""" - return hash((self.media_type, self.provider, self.item_id)) - def __post_init__(self): """Call after init.""" if not self.uri: @@ -291,6 +299,10 @@ def __post_init__(self): if not self.sort_name: self.sort_name = create_sort_name(self.name) + def __hash__(self) -> int: + """Return custom hash.""" + return hash((self.media_type.value, self.provider, self.item_id)) + @dataclass class Artist(MediaItem): @@ -299,10 +311,6 @@ class Artist(MediaItem): media_type: MediaType = MediaType.ARTIST musicbrainz_id: str | None = None - def __hash__(self): - """Return custom hash.""" - return hash((self.provider, self.item_id)) - @dataclass class Album(MediaItem): @@ -316,10 +324,6 @@ class Album(MediaItem): barcode: set[str] = field(default_factory=set) musicbrainz_id: str | None = None # release group id - def __hash__(self): - """Return custom hash.""" - return hash((self.provider, self.item_id)) - @dataclass class DbAlbum(Album): diff --git a/music_assistant/constants.py b/music_assistant/constants.py index 707caf914..a76d0a380 100755 --- a/music_assistant/constants.py +++ b/music_assistant/constants.py @@ -3,7 +3,7 @@ import pathlib from typing import Final -__version__: Final[str] = "2.0.0b29" +__version__: Final[str] = "2.0.0b30" SCHEMA_VERSION: Final[int] = 22 @@ -35,8 +35,8 @@ CONF_PATH: Final[str] = "path" CONF_USERNAME: Final[str] = "username" CONF_PASSWORD: Final[str] = "password" -CONF_VOLUME_NORMALISATION: Final[str] = "volume_normalisation" -CONF_VOLUME_NORMALISATION_TARGET: Final[str] = "volume_normalisation_target" +CONF_VOLUME_NORMALIZATION: Final[str] = "volume_normalization" +CONF_VOLUME_NORMALIZATION_TARGET: Final[str] = "volume_normalization_target" CONF_MAX_SAMPLE_RATE: Final[str] = "max_sample_rate" CONF_EQ_BASS: Final[str] = "eq_bass" CONF_EQ_MID: Final[str] = "eq_mid" @@ -51,8 +51,6 @@ # config default values DEFAULT_HOST: Final[str] = "0.0.0.0" DEFAULT_PORT: Final[int] = 8095 -DEFAULT_DB_LIBRARY: Final[str] = "sqlite:///[storage_path]/library.db" -DEFAULT_DB_CACHE: Final[str] = "sqlite:///[storage_path]/cache.db" # common db tables DB_TABLE_TRACK_LOUDNESS: Final[str] = "track_loudness" diff --git a/music_assistant/server/controllers/cache.py b/music_assistant/server/controllers/cache.py index 6a9909a1b..14afdff6b 100644 --- a/music_assistant/server/controllers/cache.py +++ b/music_assistant/server/controllers/cache.py @@ -5,16 +5,15 @@ import functools import json import logging +import os import time from collections import OrderedDict from collections.abc import Iterator, MutableMapping from typing import TYPE_CHECKING, Any from music_assistant.constants import ( - CONF_DB_CACHE, DB_TABLE_CACHE, DB_TABLE_SETTINGS, - DEFAULT_DB_CACHE, ROOT_LOGGER_NAME, SCHEMA_VERSION, ) @@ -43,6 +42,7 @@ async def setup(self) -> None: async def close(self) -> None: """Cleanup on exit.""" + await self.database.close() async def get(self, cache_key: str, checksum: str | None = None, default=None): """Get object from cache and return the results. @@ -120,9 +120,9 @@ async def auto_cleanup(self): async def _setup_database(self): """Initialize database.""" - db_url: str = self.mass.config.get(CONF_DB_CACHE, DEFAULT_DB_CACHE) - db_url = db_url.replace("[storage_path]", self.mass.storage_path) - self.database = DatabaseConnection(db_url) + db_path = os.path.join(self.mass.storage_path, "cache.db") + self.database = DatabaseConnection(db_path) + await self.database.setup() # always create db tables if they don't exist to prevent errors trying to access them later await self.__create_database_tables() diff --git a/music_assistant/server/controllers/config.py b/music_assistant/server/controllers/config.py index a78e00bd4..b44e49619 100644 --- a/music_assistant/server/controllers/config.py +++ b/music_assistant/server/controllers/config.py @@ -326,7 +326,7 @@ def save_player_config( data=config, ) # signal update to the player manager - with suppress(PlayerUnavailableError): + with suppress(PlayerUnavailableError, AttributeError): player = self.mass.players.get(config.player_id) player.enabled = config.enabled self.mass.players.update(config.player_id, force_update=True) diff --git a/music_assistant/server/controllers/media/albums.py b/music_assistant/server/controllers/media/albums.py index b4223322c..e03b3b256 100644 --- a/music_assistant/server/controllers/media/albums.py +++ b/music_assistant/server/controllers/media/albums.py @@ -96,34 +96,22 @@ async def add(self, item: Album, skip_metadata_lookup: bool = False) -> Album: # grab additional metadata if not skip_metadata_lookup: await self.mass.metadata.get_album_metadata(item) - async with self._db_add_lock: - # use the lock to prevent a race condition of the same item being added twice - existing = await self.get_db_item_by_prov_id(item.item_id, item.provider) - if existing: - db_item = await self._update_db_item(existing.item_id, item) + if item.provider == "database": + db_item = await self._update_db_item(item.item_id, item) else: - db_item = await self._add_db_item(item) + # use the lock to prevent a race condition of the same item being added twice + async with self._db_add_lock: + db_item = await self._add_db_item(item) # also fetch the same album on all providers if not skip_metadata_lookup: await self._match(db_item) - # return final db_item after all match/metadata actions - db_item = await self.get_db_item(db_item.item_id) - # preload album tracks in db + # preload album tracks listing (do not load them in the db) for prov_mapping in db_item.provider_mappings: - for track in await self._get_provider_album_tracks( + await self._get_provider_album_tracks( prov_mapping.item_id, prov_mapping.provider_instance - ): - if not await self.mass.music.tracks.get_db_item_by_prov_id( - track.item_id, track.provider - ): - track.album = db_item - await self.mass.music.tracks.add(track, skip_metadata_lookup=True) - self.mass.signal_event( - EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED, - db_item.uri, - db_item, - ) - return db_item + ) + # return final db_item after all match/metadata actions + return await self.get_db_item(db_item.item_id) async def update(self, item_id: str | int, update: Album, overwrite: bool = False) -> Album: """Update existing record in the database.""" @@ -202,53 +190,63 @@ async def _add_db_item(self, item: Album) -> Album: """Add a new record to the database.""" assert item.provider_mappings, "Item is missing provider mapping(s)" assert item.artists, f"Album {item.name} is missing artists" - cur_item = None + # safety guard: check for existing item first - # use the lock to prevent a race condition of the same item being added twice - async with self._db_add_lock: - # always try to grab existing item by musicbrainz_id - if item.musicbrainz_id: - match = {"musicbrainz_id": item.musicbrainz_id} - cur_item = await self.mass.music.database.get_row(self.db_table, match) - # try barcode/upc - if not cur_item and item.barcode: - for barcode in item.barcode: - if search_result := await self.mass.music.database.search( - self.db_table, barcode, "barcode" - ): - cur_item = Album.from_db_row(search_result[0]) - break - if not cur_item: - # fallback to search and match - for row in await self.mass.music.database.search(self.db_table, item.name): - row_album = Album.from_db_row(row) - if compare_album(row_album, item): - cur_item = row_album - break - if cur_item: - # update existing + if cur_item := await self.get_db_item_by_prov_id(item.item_id, item.provider): + # existing item found: update it return await self._update_db_item(cur_item.item_id, item) + if item.musicbrainz_id: + match = {"musicbrainz_id": item.musicbrainz_id} + if db_row := await self.mass.music.database.get_row(self.db_table, match): + cur_item = Album.from_db_row(db_row) + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) + # try barcode/upc + if not cur_item and item.barcode: + for barcode in item.barcode: + if search_result := await self.mass.music.database.search( + self.db_table, barcode, "barcode" + ): + cur_item = Album.from_db_row(search_result[0]) + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) + # fallback to search and match + match = {"sort_name": item.sort_name} + for row in await self.mass.music.database.get_rows(self.db_table, match): + row_album = Album.from_db_row(row) + if compare_album(row_album, item): + cur_item = row_album + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) # insert new item album_artists = await self._get_artist_mappings(item, cur_item) sort_artist = album_artists[0].sort_name if album_artists else "" - async with self._db_add_lock: - new_item = await self.mass.music.database.insert( - self.db_table, - { - **item.to_db_row(), - "artists": serialize_to_json(album_artists) or None, - "sort_artist": sort_artist, - "timestamp_added": int(utc_timestamp()), - "timestamp_modified": int(utc_timestamp()), - }, - ) - item_id = new_item["item_id"] + new_item = await self.mass.music.database.insert( + self.db_table, + { + **item.to_db_row(), + "artists": serialize_to_json(album_artists) or None, + "sort_artist": sort_artist, + "timestamp_added": int(utc_timestamp()), + "timestamp_modified": int(utc_timestamp()), + }, + ) + db_id = new_item["item_id"] # update/set provider_mappings table - await self._set_provider_mappings(item_id, item.provider_mappings) + await self._set_provider_mappings(db_id, item.provider_mappings) self.logger.debug("added %s to database", item.name) - # return created object - return await self.get_db_item(item_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_ADDED, + db_item.uri, + db_item, + ) + # return the full item we just added + return db_item async def _update_db_item( self, item_id: str | int, item: Album | ItemMapping, overwrite: bool = False @@ -287,7 +285,17 @@ async def _update_db_item( # update/set provider_mappings table await self._set_provider_mappings(db_id, provider_mappings) self.logger.debug("updated %s in database: %s", item.name, db_id) - return await self.get_db_item(db_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_UPDATED, + db_item.uri, + db_item, + ) + # return the full item we just updated + return db_item async def _get_provider_album_tracks( self, item_id: str, provider_instance_id_or_domain: str diff --git a/music_assistant/server/controllers/media/artists.py b/music_assistant/server/controllers/media/artists.py index aaf68c5c1..ed8939a36 100644 --- a/music_assistant/server/controllers/media/artists.py +++ b/music_assistant/server/controllers/media/artists.py @@ -60,24 +60,17 @@ async def add(self, item: Artist | ItemMapping, skip_metadata_lookup: bool = Fal # grab musicbrainz id and additional metadata if not skip_metadata_lookup: await self.mass.metadata.get_artist_metadata(item) - async with self._db_add_lock: - # use the lock to prevent a race condition of the same item being added twice - existing = await self.get_db_item_by_prov_id(item.item_id, item.provider) - if existing: - db_item = await self._update_db_item(existing.item_id, item) + if item.provider == "database": + db_item = await self._update_db_item(item.item_id, item) else: - db_item = await self._add_db_item(item) + # use the lock to prevent a race condition of the same item being added twice + async with self._db_add_lock: + db_item = await self._add_db_item(item) # also fetch same artist on all providers if not skip_metadata_lookup: await self.match_artist(db_item) # return final db_item after all match/metadata actions - db_item = await self.get_db_item(db_item.item_id) - self.mass.signal_event( - EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED, - db_item.uri, - db_item, - ) - return db_item + return await self.get_db_item(db_item.item_id) async def update(self, item_id: str | int, update: Artist, overwrite: bool = False) -> Artist: """Update existing record in the database.""" @@ -292,43 +285,54 @@ async def _add_db_item(self, item: Artist | ItemMapping) -> Artist: if item.musicbrainz_id == VARIOUS_ARTISTS_ID: item.name = VARIOUS_ARTISTS # safety guard: check for existing item first - # use the lock to prevent a race condition of the same item being added twice - async with self._db_add_lock: - # always try to grab existing item by musicbrainz_id - cur_item = None - if musicbrainz_id := getattr(item, "musicbrainz_id", None): - match = {"musicbrainz_id": musicbrainz_id} - cur_item = await self.mass.music.database.get_row(self.db_table, match) - if not cur_item: - # fallback to exact name match - # NOTE: we match an artist by name which could theoretically lead to collisions - # but the chance is so small it is not worth the additional overhead of grabbing - # the musicbrainz id upfront - match = {"sort_name": item.sort_name} - for row in await self.mass.music.database.get_rows(self.db_table, match): - row_artist = Artist.from_db_row(row) - if row_artist.sort_name == item.sort_name: - cur_item = row_artist - break - if cur_item: - # update existing + if isinstance(item, ItemMapping) and ( + cur_item := await self.get_db_item_by_prov_id(item.item_id, item.provider) + ): + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) + if cur_item := await self.get_db_item_by_prov_mappings(item.provider_mappings): return await self._update_db_item(cur_item.item_id, item) + if musicbrainz_id := getattr(item, "musicbrainz_id", None): + match = {"musicbrainz_id": musicbrainz_id} + if db_row := await self.mass.music.database.get_row(self.db_table, match): + # existing item found: update it + cur_item = Artist.from_db_row(db_row) + return await self._update_db_item(cur_item.item_id, item) + # fallback to exact name match + # NOTE: we match an artist by name which could theoretically lead to collisions + # but the chance is so small it is not worth the additional overhead of grabbing + # the musicbrainz id upfront + match = {"sort_name": item.sort_name} + for row in await self.mass.music.database.get_rows(self.db_table, match): + row_artist = Artist.from_db_row(row) + if row_artist.sort_name == item.sort_name: + cur_item = row_artist + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) - # insert item + # no existing item matched: insert item item.timestamp_added = int(utc_timestamp()) item.timestamp_modified = int(utc_timestamp()) # edge case: item is an ItemMapping, # try to construct (a half baken) Artist object from it if isinstance(item, ItemMapping): item = Artist.from_dict(item.to_dict()) - async with self._db_add_lock: - new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row()) - item_id = new_item["item_id"] + new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row()) + db_id = new_item["item_id"] # update/set provider_mappings table - await self._set_provider_mappings(item_id, item.provider_mappings) + await self._set_provider_mappings(db_id, item.provider_mappings) self.logger.debug("added %s to database", item.name) - # return created object - return await self.get_db_item(item_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_ADDED, + db_item.uri, + db_item, + ) + # return the full item we just added + return db_item async def _update_db_item( self, item_id: str | int, item: Artist | ItemMapping, overwrite: bool = False @@ -361,7 +365,17 @@ async def _update_db_item( # update/set provider_mappings table await self._set_provider_mappings(db_id, provider_mappings) self.logger.debug("updated %s in database: %s", item.name, db_id) - return await self.get_db_item(db_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_UPDATED, + db_item.uri, + db_item, + ) + # return the full item we just updated + return db_item async def _get_provider_dynamic_tracks( self, diff --git a/music_assistant/server/controllers/media/base.py b/music_assistant/server/controllers/media/base.py index ed03125c3..ad0cbcc8e 100644 --- a/music_assistant/server/controllers/media/base.py +++ b/music_assistant/server/controllers/media/base.py @@ -1,7 +1,6 @@ """Base (ABC) MediaType specific controller.""" from __future__ import annotations -import asyncio import logging from abc import ABCMeta, abstractmethod from collections.abc import AsyncGenerator @@ -36,7 +35,6 @@ class MediaControllerBase(Generic[ItemCls], metaclass=ABCMeta): media_type: MediaType item_cls: MediaItemType db_table: str - _db_add_lock = asyncio.Lock() def __init__(self, mass: MusicAssistant): """Initialize class.""" @@ -340,6 +338,27 @@ async def get_db_item_by_prov_id( return item return None + async def get_db_item_by_prov_mappings( + self, + provider_mappings: list[ProviderMapping], + ) -> ItemCls | None: + """Get the database item for the given provider_instance.""" + # always prefer provider instance first + for mapping in provider_mappings: + for item in await self.get_db_items_by_prov_id( + mapping.provider_instance, + provider_item_ids=(mapping.item_id,), + ): + return item + # check by domain too + for mapping in provider_mappings: + for item in await self.get_db_items_by_prov_id( + mapping.provider_domain, + provider_item_ids=(mapping.item_id,), + ): + return item + return None + async def get_db_items_by_prov_id( self, provider_instance_id_or_domain: str, @@ -392,8 +411,7 @@ async def set_db_library(self, item_id: str | int, in_library: bool) -> None: """Set the in-library bool on a database item.""" db_id = int(item_id) # ensure integer match = {"item_id": db_id} - async with self._db_add_lock: - await self.mass.music.database.update(self.db_table, match, {"in_library": in_library}) + await self.mass.music.database.update(self.db_table, match, {"in_library": in_library}) db_item = await self.get_db_item(db_id) self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, db_item.uri, db_item) @@ -429,7 +447,7 @@ async def get_provider_item( return fallback_item raise MediaNotFoundError( f"{self.media_type.value}://{item_id} not " - "found on provider {provider_instance_id_or_domain}" + f"found on provider {provider_instance_id_or_domain}" ) async def remove_prov_mapping(self, item_id: str | int, provider_instance_id: str) -> None: @@ -442,15 +460,14 @@ async def remove_prov_mapping(self, item_id: str | int, provider_instance_id: st return # update provider_mappings table - async with self._db_add_lock: - await self.mass.music.database.delete( - DB_TABLE_PROVIDER_MAPPINGS, - { - "media_type": self.media_type.value, - "item_id": db_id, - "provider_instance": provider_instance_id, - }, - ) + await self.mass.music.database.delete( + DB_TABLE_PROVIDER_MAPPINGS, + { + "media_type": self.media_type.value, + "item_id": db_id, + "provider_instance": provider_instance_id, + }, + ) # update the item in db (provider_mappings column only) db_item.provider_mappings = { @@ -458,12 +475,11 @@ async def remove_prov_mapping(self, item_id: str | int, provider_instance_id: st } match = {"item_id": db_id} if db_item.provider_mappings: - async with self._db_add_lock: - await self.mass.music.database.update( - self.db_table, - match, - {"provider_mappings": serialize_to_json(db_item.provider_mappings)}, - ) + await self.mass.music.database.update( + self.db_table, + match, + {"provider_mappings": serialize_to_json(db_item.provider_mappings)}, + ) self.logger.debug("removed provider %s from item id %s", provider_instance_id, db_id) self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, db_item.uri, db_item) else: @@ -511,24 +527,40 @@ async def _set_provider_mappings( ) -> None: """Update the provider_items table for the media item.""" db_id = int(item_id) # ensure integer - # clear all records first - async with self._db_add_lock: - await self.mass.music.database.delete( - DB_TABLE_PROVIDER_MAPPINGS, - {"media_type": self.media_type.value, "item_id": db_id}, + # get current mappings (if any) + cur_mappings = set() + match = {"media_type": self.media_type.value, "item_id": db_id} + for db_row in await self.mass.music.database.get_rows(DB_TABLE_PROVIDER_MAPPINGS, match): + cur_mappings.add( + ProviderMapping( + item_id=db_row["provider_item_id"], + provider_domain=db_row["provider_domain"], + provider_instance=db_row["provider_instance"], + ) ) - # add entries - for provider_mapping in provider_mappings: - await self.mass.music.database.insert_or_replace( + # delete removed mappings + for prov_mapping in cur_mappings: + if prov_mapping not in set(provider_mappings): + await self.mass.music.database.delete( DB_TABLE_PROVIDER_MAPPINGS, { - "media_type": self.media_type.value, - "item_id": db_id, - "provider_domain": provider_mapping.provider_domain, - "provider_instance": provider_mapping.provider_instance, - "provider_item_id": provider_mapping.item_id, + **match, + "provider_domain": prov_mapping.provider_domain, + "provider_instance": prov_mapping.provider_instance, + "provider_item_id": prov_mapping.item_id, }, ) + # add entries + for provider_mapping in provider_mappings: + await self.mass.music.database.insert_or_replace( + DB_TABLE_PROVIDER_MAPPINGS, + { + **match, + "provider_domain": provider_mapping.provider_domain, + "provider_instance": provider_mapping.provider_instance, + "provider_item_id": provider_mapping.item_id, + }, + ) def _get_provider_mappings( self, diff --git a/music_assistant/server/controllers/media/playlists.py b/music_assistant/server/controllers/media/playlists.py index 21b97e3c1..ae213accf 100644 --- a/music_assistant/server/controllers/media/playlists.py +++ b/music_assistant/server/controllers/media/playlists.py @@ -45,23 +45,19 @@ def __init__(self, *args, **kwargs): async def add(self, item: Playlist, skip_metadata_lookup: bool = False) -> Playlist: """Add playlist to local db and return the new database item.""" - if not skip_metadata_lookup: - await self.mass.metadata.get_playlist_metadata(item) + if item.provider == "database": + db_item = await self._update_db_item(item.item_id, item) + else: + # use the lock to prevent a race condition of the same item being added twice + async with self._db_add_lock: + db_item = await self._add_db_item(item) # preload playlist tracks listing (do not load them in the db) - async for track in self.tracks(item.item_id, item.provider): + async for _ in self.tracks(item.item_id, item.provider): pass - async with self._db_add_lock: - # use the lock to prevent a race condition of the same item being added twice - existing = await self.get_db_item_by_prov_id(item.item_id, item.provider) - if existing: - db_item = await self._update_db_item(existing.item_id, item) - else: - db_item = await self._add_db_item(item) - self.mass.signal_event( - EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED, - db_item.uri, - db_item, - ) + # metadata lookup we need to do after adding it to the db + if not skip_metadata_lookup: + await self.mass.metadata.get_playlist_metadata(db_item) + db_item = await self._update_db_item(db_item.item_id, db_item) return db_item async def update(self, item_id: int, update: Playlist, overwrite: bool = False) -> Playlist: @@ -204,26 +200,35 @@ async def remove_playlist_tracks( async def _add_db_item(self, item: Playlist) -> Playlist: """Add a new record to the database.""" assert item.provider_mappings, "Item is missing provider mapping(s)" - cur_item = None # safety guard: check for existing item first - # use the lock to prevent a race condition of the same item being added twice - async with self._db_add_lock: - match = {"sort_name": item.sort_name, "owner": item.owner} - cur_item = await self.mass.music.database.get_row(self.db_table, match) - if cur_item: - # update existing - return await self._update_db_item(cur_item["item_id"], item) + if cur_item := await self.get_db_item_by_prov_mappings(item.provider_mappings): + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) + # try name matching + match = {"name": item.name, "owner": item.owner} + if db_row := await self.mass.music.database.get_row(self.db_table, match): + cur_item = Playlist.from_db_row(db_row) + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) # insert new item item.timestamp_added = int(utc_timestamp()) item.timestamp_modified = int(utc_timestamp()) - async with self._db_add_lock: - new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row()) - item_id = new_item["item_id"] + new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row()) + db_id = new_item["item_id"] # update/set provider_mappings table - await self._set_provider_mappings(item_id, item.provider_mappings) + await self._set_provider_mappings(db_id, item.provider_mappings) self.logger.debug("added %s to database", item.name) - # return created object - return await self.get_db_item(item_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_ADDED, + db_item.uri, + db_item, + ) + # return the full item we just added + return db_item async def _update_db_item( self, item_id: str | int, item: Playlist, overwrite: bool = False @@ -250,7 +255,17 @@ async def _update_db_item( # update/set provider_mappings table await self._set_provider_mappings(db_id, provider_mappings) self.logger.debug("updated %s in database: %s", item.name, db_id) - return await self.get_db_item(db_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_UPDATED, + db_item.uri, + db_item, + ) + # return the full item we just updated + return db_item async def _get_provider_playlist_tracks( self, diff --git a/music_assistant/server/controllers/media/radio.py b/music_assistant/server/controllers/media/radio.py index ef85561ba..8adc68b03 100644 --- a/music_assistant/server/controllers/media/radio.py +++ b/music_assistant/server/controllers/media/radio.py @@ -61,16 +61,12 @@ async def add(self, item: Radio, skip_metadata_lookup: bool = False) -> Radio: """Add radio to local db and return the new database item.""" if not skip_metadata_lookup: await self.mass.metadata.get_radio_metadata(item) - existing = await self.get_db_item_by_prov_id(item.item_id, item.provider) - if existing: - db_item = await self._update_db_item(existing.item_id, item) + if item.provider == "database": + db_item = await self._update_db_item(item.item_id, item) else: - db_item = await self._add_db_item(item) - self.mass.signal_event( - EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED, - db_item.uri, - db_item, - ) + # use the lock to prevent a race condition of the same item being added twice + async with self._db_add_lock: + db_item = await self._add_db_item(item) return db_item async def update(self, item_id: str | int, update: Radio, overwrite: bool = False) -> Radio: @@ -82,24 +78,34 @@ async def _add_db_item(self, item: Radio) -> Radio: assert item.provider_mappings, "Item is missing provider mapping(s)" cur_item = None # safety guard: check for existing item first - # use the lock to prevent a race condition of the same item being added twice - async with self._db_add_lock: - match = {"name": item.name} - cur_item = await self.mass.music.database.get_row(self.db_table, match) - if cur_item: - # update existing - return await self._update_db_item(cur_item["item_id"], item) + if cur_item := await self.get_db_item_by_prov_id(item.item_id, item.provider): + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) + # try name matching + match = {"name": item.name} + if db_row := await self.mass.music.database.get_row(self.db_table, match): + cur_item = Radio.from_db_row(db_row) + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) # insert new item item.timestamp_added = int(utc_timestamp()) item.timestamp_modified = int(utc_timestamp()) - async with self._db_add_lock: - new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row()) - item_id = new_item["item_id"] + new_item = await self.mass.music.database.insert(self.db_table, item.to_db_row()) + db_id = new_item["item_id"] # update/set provider_mappings table - await self._set_provider_mappings(item_id, item.provider_mappings) + await self._set_provider_mappings(db_id, item.provider_mappings) self.logger.debug("added %s to database", item.name) - # return created object - return await self.get_db_item(item_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_ADDED, + db_item.uri, + db_item, + ) + # return the full item we just added + return db_item async def _update_db_item( self, item_id: str | int, item: Radio, overwrite: bool = False @@ -125,7 +131,17 @@ async def _update_db_item( # update/set provider_mappings table await self._set_provider_mappings(db_id, provider_mappings) self.logger.debug("updated %s in database: %s", item.name, db_id) - return await self.get_db_item(db_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_UPDATED, + db_item.uri, + db_item, + ) + # return the full item we just updated + return db_item async def _get_provider_dynamic_tracks( self, diff --git a/music_assistant/server/controllers/media/tracks.py b/music_assistant/server/controllers/media/tracks.py index 619842ccc..e7ddf1c88 100644 --- a/music_assistant/server/controllers/media/tracks.py +++ b/music_assistant/server/controllers/media/tracks.py @@ -131,24 +131,17 @@ async def add(self, item: Track, skip_metadata_lookup: bool = False) -> Track: # grab additional metadata if not skip_metadata_lookup: await self.mass.metadata.get_track_metadata(item) - async with self._db_add_lock: - # use the lock to prevent a race condition of the same item being added twice - existing = await self.get_db_item_by_prov_id(item.item_id, item.provider) - if existing: - db_item = await self._update_db_item(existing.item_id, item) + if item.provider == "database": + db_item = await self._update_db_item(item.item_id, item) else: - db_item = await self._add_db_item(item) + # use the lock to prevent a race condition of the same item being added twice + async with self._db_add_lock: + db_item = await self._add_db_item(item) # also fetch same track on all providers (will also get other quality versions) if not skip_metadata_lookup: await self._match(db_item) # return final db_item after all match/metadata actions - db_item = await self.get_db_item(db_item.item_id) - self.mass.signal_event( - EventType.MEDIA_ITEM_UPDATED if existing else EventType.MEDIA_ITEM_ADDED, - db_item.uri, - db_item, - ) - return db_item + return await self.get_db_item(db_item.item_id) async def update(self, item_id: str | int, update: Track, overwrite: bool = False) -> Track: """Update existing record in the database.""" @@ -285,57 +278,65 @@ async def _add_db_item(self, item: Track) -> Track: assert isinstance(item, Track), "Not a full Track object" assert item.artists, "Track is missing artist(s)" assert item.provider_mappings, "Track is missing provider mapping(s)" - cur_item = None - # safety guard: check for existing item first - # use the lock to prevent a race condition of the same item being added twice - async with self._db_add_lock: - # always try to grab existing item by external_id - if item.musicbrainz_id: - match = {"musicbrainz_id": item.musicbrainz_id} - cur_item = await self.mass.music.database.get_row(self.db_table, match) - for isrc in item.isrc: - if search_result := await self.mass.music.database.search( - self.db_table, isrc, "isrc" - ): - cur_item = Track.from_db_row(search_result[0]) - break - if not cur_item: - # fallback to matching - match = {"sort_name": item.sort_name} - for row in await self.mass.music.database.get_rows(self.db_table, match): - row_track = Track.from_db_row(row) - if compare_track(row_track, item): - cur_item = row_track - break - if cur_item: - # update existing + if cur_item := await self.get_db_item_by_prov_mappings(item.provider_mappings): + # existing item found: update it return await self._update_db_item(cur_item.item_id, item) + # try matching on musicbrainz_id + if item.musicbrainz_id: + match = {"musicbrainz_id": item.musicbrainz_id} + if db_row := await self.mass.music.database.get_row(self.db_table, match): + cur_item = Track.from_db_row(db_row) + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) + # try matching on isrc + for isrc in item.isrc: + if search_result := await self.mass.music.database.search(self.db_table, isrc, "isrc"): + cur_item = Track.from_db_row(search_result[0]) + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) + # fallback to compare matching + match = {"sort_name": item.sort_name} + for row in await self.mass.music.database.get_rows(self.db_table, match): + row_track = Track.from_db_row(row) + if compare_track(row_track, item): + cur_item = row_track + # existing item found: update it + return await self._update_db_item(cur_item.item_id, item) # no existing match found: insert new item track_artists = await self._get_artist_mappings(item) track_albums = await self._get_track_albums(item) sort_artist = track_artists[0].sort_name if track_artists else "" sort_album = track_albums[0].sort_name if track_albums else "" - async with self._db_add_lock: - new_item = await self.mass.music.database.insert( - self.db_table, - { - **item.to_db_row(), - "artists": serialize_to_json(track_artists), - "albums": serialize_to_json(track_albums), - "sort_artist": sort_artist, - "sort_album": sort_album, - "timestamp_added": int(utc_timestamp()), - "timestamp_modified": int(utc_timestamp()), - }, - ) - item_id = new_item["item_id"] + new_item = await self.mass.music.database.insert( + self.db_table, + { + **item.to_db_row(), + "artists": serialize_to_json(track_artists), + "albums": serialize_to_json(track_albums), + "sort_artist": sort_artist, + "sort_album": sort_album, + "timestamp_added": int(utc_timestamp()), + "timestamp_modified": int(utc_timestamp()), + }, + ) + db_id = new_item["item_id"] # update/set provider_mappings table - await self._set_provider_mappings(item_id, item.provider_mappings) + await self._set_provider_mappings(db_id, item.provider_mappings) # return created object - self.logger.debug("added %s to database: %s", item.name, item_id) - return await self.get_db_item(item_id) + self.logger.debug("added %s to database: %s", item.name, db_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_ADDED, + db_item.uri, + db_item, + ) + # return the full item we just added + return db_item async def _update_db_item( self, item_id: str | int, item: Track | ItemMapping, overwrite: bool = False @@ -368,7 +369,17 @@ async def _update_db_item( # update/set provider_mappings table await self._set_provider_mappings(db_id, provider_mappings) self.logger.debug("updated %s in database: %s", item.name, db_id) - return await self.get_db_item(db_id) + # get full created object + db_item = await self.get_db_item(db_id) + # only signal event if we're not running a sync (to prevent a floodstorm of events) + if not self.mass.music.get_running_sync_tasks(): + self.mass.signal_event( + EventType.MEDIA_ITEM_UPDATED, + db_item.uri, + db_item, + ) + # return the full item we just updated + return db_item async def _get_track_albums( self, diff --git a/music_assistant/server/controllers/music.py b/music_assistant/server/controllers/music.py index b74aebcb4..3b90d99eb 100755 --- a/music_assistant/server/controllers/music.py +++ b/music_assistant/server/controllers/music.py @@ -3,6 +3,7 @@ import asyncio import logging +import os import statistics from itertools import zip_longest from typing import TYPE_CHECKING @@ -14,7 +15,6 @@ from music_assistant.common.models.media_items import BrowseFolder, MediaItemType, SearchResults from music_assistant.common.models.provider import SyncTask from music_assistant.constants import ( - CONF_DB_LIBRARY, DB_TABLE_ALBUMS, DB_TABLE_ARTISTS, DB_TABLE_PLAYLISTS, @@ -24,7 +24,6 @@ DB_TABLE_SETTINGS, DB_TABLE_TRACK_LOUDNESS, DB_TABLE_TRACKS, - DEFAULT_DB_LIBRARY, ROOT_LOGGER_NAME, SCHEMA_VERSION, ) @@ -59,6 +58,7 @@ def __init__(self, mass: MusicAssistant): self.radio = RadioController(mass) self.playlists = PlaylistController(mass) self.in_progress_syncs: list[SyncTask] = [] + self._sync_lock = asyncio.Lock() async def setup(self): """Async initialize of module.""" @@ -68,6 +68,7 @@ async def setup(self): async def close(self) -> None: """Cleanup on exit.""" + await self.database.close() @property def providers(self) -> list[MusicProvider]: @@ -95,8 +96,6 @@ async def start_sync( if provider.instance_id not in providers: continue self._start_provider_sync(provider.instance_id, media_types) - # trigger metadata scan after provider sync completed - self.mass.metadata.start_scan() # reschedule task if needed def create_sync_task(): @@ -107,7 +106,7 @@ def create_sync_task(): @api_command("music/synctasks") def get_running_sync_tasks(self) -> list[SyncTask]: - """Return list with providers that are currently syncing.""" + """Return list with providers that are currently (scheduled for) syncing.""" return self.in_progress_syncs @api_command("music/search") @@ -536,9 +535,16 @@ def _start_provider_sync(self, provider_instance: str, media_types: tuple[MediaT ) return - # we keep track of running sync tasks provider = self.mass.get_provider(provider_instance) - task = self.mass.create_task(provider.sync_library(media_types)) + + async def run_sync() -> None: + # Wrap the provider sync into a lock to prevent + # race conditions when multiple propviders are syncing at the same time. + async with self._sync_lock: + await provider.sync_library(media_types) + + # we keep track of running sync tasks + task = self.mass.create_task(run_sync()) sync_spec = SyncTask( provider_domain=provider.domain, provider_instance=provider.instance_id, @@ -552,6 +558,8 @@ def _start_provider_sync(self, provider_instance: str, media_types: tuple[MediaT def on_sync_task_done(task: asyncio.Task): # noqa: ARG001 self.in_progress_syncs.remove(sync_spec) self.mass.signal_event(EventType.SYNC_TASKS_UPDATED, data=self.in_progress_syncs) + # trigger metadata scan after provider sync completed + self.mass.metadata.start_scan() task.add_done_callback(on_sync_task_done) @@ -575,9 +583,9 @@ async def cleanup_provider(self, provider_instance: str) -> None: async def _setup_database(self): """Initialize database.""" - db_url: str = self.mass.config.get(CONF_DB_LIBRARY, DEFAULT_DB_LIBRARY) - db_url = db_url.replace("[storage_path]", self.mass.storage_path) - self.database = DatabaseConnection(db_url) + db_path = os.path.join(self.mass.storage_path, "library.db") + self.database = DatabaseConnection(db_path) + await self.database.setup() # always create db tables if they don't exist to prevent errors trying to access them later await self.__create_database_tables() @@ -727,8 +735,7 @@ async def __create_database_tables(self) -> None: provider_domain TEXT NOT NULL, provider_instance TEXT NOT NULL, provider_item_id TEXT NOT NULL, - UNIQUE(media_type, item_id, provider_instance, - provider_item_id, provider_item_id) + UNIQUE(media_type, provider_instance, provider_item_id) );""" ) diff --git a/music_assistant/server/controllers/player_queues.py b/music_assistant/server/controllers/player_queues.py index 5c291137d..7a90bc8d3 100755 --- a/music_assistant/server/controllers/player_queues.py +++ b/music_assistant/server/controllers/player_queues.py @@ -443,10 +443,10 @@ async def resume(self, queue_id: str) -> None: # track is already played for > 90% - skip to next resume_item = next_item resume_pos = 0 - elif queue.current_index is not None and len(queue_items) > 0: + elif not resume_item and queue.current_index is not None and len(queue_items) > 0: resume_item = self.get_item(queue_id, queue.current_index) resume_pos = 0 - elif queue.current_index is None and len(queue_items) > 0: + elif not resume_item and queue.current_index is None and len(queue_items) > 0: # items available in queue but no previous track, start at 0 resume_item = self.get_item(queue_id, 0) resume_pos = 0 diff --git a/music_assistant/server/helpers/app_vars.py b/music_assistant/server/helpers/app_vars.py index 7073846da..b754cf181 100644 --- a/music_assistant/server/helpers/app_vars.py +++ b/music_assistant/server/helpers/app_vars.py @@ -2,4 +2,4 @@ # fmt: off # flake8: noqa # type: ignore -(lambda __g: [(lambda __mod: [[[None for __g['app_var'], app_var.__name__ in [(lambda index: (lambda __l: [[AV(aap(__l['var'].encode()).decode()) for __l['var'] in [(vars.split('acb2')[__l['index']][::(-1)])]][0] for __l['index'] in [(index)]][0])({}), 'app_var')]][0] for __g['vars'] in [('3YTNyUDOyQTOacb2=EmN5M2YjdzMhljYzYzYhlDMmFGNlVTOmNDZwMzNxYzNacb2=UDMzEGOyADO1QWO5kDNygTMlJGN5QzNzIWOmZTOiVmMacb2yMTNzITNacb2=UDZhJmMldTZ3QTY4IjZ3kTNxYjN0czNwI2YxkTM5MjN')]][0] for __g['aap'] in [(__mod.b64decode)]][0])(__import__('base64', __g, __g, ('b64decode',), 0)) for __g['AV'] in [((lambda b, d: d.get('__metaclass__', getattr(b[0], '__class__', type(b[0])))('AV', b, d))((str,), (lambda __l: [__l for __l['__repr__'], __l['__repr__'].__name__ in [(lambda self: (lambda __l: [__name__ for __l['self'] in [(self)]][0])({}), '__repr__')]][0])({'__module__': __name__})))]][0])(globals()) +(lambda __g: [(lambda __mod: [[[None for __g['app_var'], app_var.__name__ in [(lambda index: (lambda __l: [[AV(aap(__l['var'].encode()).decode()) for __l['var'] in [(vars.split('acb2')[__l['index']][::-1])]][0] for __l['index'] in [(index)]][0])({}), 'app_var')]][0] for __g['vars'] in [('3YTNyUDOyQTOacb2=EmN5M2YjdzMhljYzYzYhlDMmFGNlVTOmNDZwMzNxYzNacb2=UDMzEGOyADO1QWO5kDNygTMlJGN5QzNzIWOmZTOiVmMacb2yMTNzITNacb2=UDZhJmMldTZ3QTY4IjZ3kTNxYjN0czNwI2YxkTM5MjNacb2==QMh5WOmZnewM2d4UDblRzZacb20QzMwAjNacb2=QzNiRTO3EjMjFzMldjY3QTMwEDMwADMiNWZ5UWO3UWM')]][0] for __g['aap'] in [(__mod.b64decode)]][0])(__import__('base64', __g, __g, ('b64decode',), 0)) for __g['AV'] in [((lambda b, d: d.get('__metaclass__', getattr(b[0], '__class__', type(b[0])))('AV', b, d))((str,), (lambda __l: [__l for __l['__repr__'], __l['__repr__'].__name__ in [(lambda self: (lambda __l: [__name__ for __l['self'] in [(self)]][0])({}), '__repr__')]][0])({'__module__': __name__})))]][0])(globals()) diff --git a/music_assistant/server/helpers/audio.py b/music_assistant/server/helpers/audio.py index 111c1b0a8..3b81eb2ad 100644 --- a/music_assistant/server/helpers/audio.py +++ b/music_assistant/server/helpers/audio.py @@ -18,8 +18,8 @@ from music_assistant.common.models.errors import AudioError, MediaNotFoundError, MusicAssistantError from music_assistant.common.models.media_items import ContentType, MediaType, StreamDetails from music_assistant.constants import ( - CONF_VOLUME_NORMALISATION, - CONF_VOLUME_NORMALISATION_TARGET, + CONF_VOLUME_NORMALIZATION, + CONF_VOLUME_NORMALIZATION_TARGET, ROOT_LOGGER_NAME, ) from music_assistant.server.helpers.process import AsyncProcess, check_output @@ -295,11 +295,11 @@ async def get_gain_correct( ) -> tuple[float | None, float | None]: """Get gain correction for given queue / track combination.""" player_settings = mass.config.get_player_config(streamdetails.queue_id) - if not player_settings or not player_settings.get_value(CONF_VOLUME_NORMALISATION): + if not player_settings or not player_settings.get_value(CONF_VOLUME_NORMALIZATION): return (None, None) if streamdetails.gain_correct is not None: return (streamdetails.loudness, streamdetails.gain_correct) - target_gain = player_settings.get_value(CONF_VOLUME_NORMALISATION_TARGET) + target_gain = player_settings.get_value(CONF_VOLUME_NORMALIZATION_TARGET) track_loudness = await mass.music.get_track_loudness( streamdetails.item_id, streamdetails.provider ) @@ -746,7 +746,7 @@ async def _get_ffmpeg_args( "Please install ffmpeg on your OS to enable playback.", ) - major_version = int(version.split(".")[0]) + major_version = int("".join(char for char in version.split(".")[0] if not char.isalpha())) # generic args generic_args = [ @@ -755,6 +755,8 @@ async def _get_ffmpeg_args( "-loglevel", "warning" if LOGGER.isEnabledFor(logging.DEBUG) else "quiet", "-ignore_unknown", + "-protocol_whitelist", + "file,http,https,tcp,tls,crypto,pipe", # support nested protocols (e.g. within playlist) ] # collect input args input_args = [] diff --git a/music_assistant/server/helpers/compare.py b/music_assistant/server/helpers/compare.py index c980d1c96..afc04bf2f 100644 --- a/music_assistant/server/helpers/compare.py +++ b/music_assistant/server/helpers/compare.py @@ -216,7 +216,7 @@ def compare_album( return left_album.musicbrainz_id == right_album.musicbrainz_id # fallback to comparing - if not compare_strings(left_album.name, right_album.name, False): + if not compare_strings(left_album.name, right_album.name, True): return False if not compare_version(left_album.version, right_album.version): return False @@ -261,7 +261,7 @@ def compare_track(left_track: Track, right_track: Track, strict: bool = True): # track artist(s) must match if not compare_artists(left_track.artists, right_track.artists): return False - # track if both tracks are (not) explicit + # check if both tracks are (not) explicit if strict and not compare_explicit(left_track.metadata, right_track.metadata): return False # exact albumtrack match = 100% match @@ -273,15 +273,25 @@ def compare_track(left_track: Track, right_track: Track, strict: bool = True): and left_track.track_number == right_track.track_number ): return True - # exact album match = 100% match - if left_track.albums and right_track.albums: + # check album match + if ( + not (album_match_found := compare_album(left_track.album, right_track.album)) + and left_track.albums + and right_track.albums + ): for left_album in left_track.albums: for right_album in right_track.albums: if compare_album(left_album, right_album): - return True + album_match_found = True + if ( + (left_album.disc_number or 1) == (right_album.disc_number or 1) + and left_album.track_number + and right_album.track_number + and left_album.track_number == right_album.track_number + ): + # exact albumtrack match = 100% match + return True # fallback: exact album match and (near-exact) track duration match - if abs(left_track.duration - right_track.duration) <= 3 and compare_album( - left_track.album, right_track.album - ): + if album_match_found and abs(left_track.duration - right_track.duration) <= 3: return True return False diff --git a/music_assistant/server/helpers/database.py b/music_assistant/server/helpers/database.py index 12a11ca2d..e76642c57 100755 --- a/music_assistant/server/helpers/database.py +++ b/music_assistant/server/helpers/database.py @@ -4,28 +4,26 @@ from collections.abc import Mapping from typing import Any -from databases import Database as Db -from databases import DatabaseURL -from sqlalchemy.sql import ClauseElement +import aiosqlite class DatabaseConnection: """Class that holds the (connection to the) database with some convenience helper functions.""" - def __init__(self, url: DatabaseURL): + _db: aiosqlite.Connection + + def __init__(self, db_path: str): """Initialize class.""" - self.url = url - # we maintain one global connection - otherwise we run into (dead)lock issues. - # https://github.com/encode/databases/issues/456 - self._db = Db(self.url, timeout=360) + self.db_path = db_path async def setup(self) -> None: """Perform async initialization.""" - await self._db.connect() + self._db = await aiosqlite.connect(self.db_path) + self._db.row_factory = aiosqlite.Row async def close(self) -> None: """Close db connection on exit.""" - await self._db.disconnect() + await self._db.close() async def get_rows( self, @@ -42,7 +40,7 @@ async def get_rows( if order_by is not None: sql_query += f" ORDER BY {order_by}" sql_query += f" LIMIT {limit} OFFSET {offset}" - return await self._db.fetch_all(sql_query, match) + return await self._db.execute_fetchall(sql_query, match) async def get_rows_from_query( self, @@ -53,7 +51,7 @@ async def get_rows_from_query( ) -> list[Mapping]: """Get all rows for given custom query.""" query = f"{query} LIMIT {limit} OFFSET {offset}" - return await self._db.fetch_all(query, params) + return await self._db.execute_fetchall(query, params) async def get_count_from_query( self, @@ -62,8 +60,9 @@ async def get_count_from_query( ) -> int: """Get row count for given custom query.""" query = f"SELECT count() FROM ({query})" - if result := await self._db.fetch_one(query, params): - return result[0] + async with self._db.execute(query, params) as cursor: + if result := await cursor.fetchone(): + return result[0] return 0 async def get_count( @@ -72,21 +71,23 @@ async def get_count( ) -> int: """Get row count for given table.""" query = f"SELECT count(*) FROM {table}" - if result := await self._db.fetch_one(query): - return result[0] + async with self._db.execute(query) as cursor: + if result := await cursor.fetchone(): + return result[0] return 0 async def search(self, table: str, search: str, column: str = "name") -> list[Mapping]: """Search table by column.""" sql_query = f"SELECT * FROM {table} WHERE {column} LIKE :search" params = {"search": f"%{search}%"} - return await self._db.fetch_all(sql_query, params) + return await self._db.execute_fetchall(sql_query, params) async def get_row(self, table: str, match: dict[str, Any]) -> Mapping | None: """Get single row for given table where column matches keys/values.""" sql_query = f"SELECT * FROM {table} WHERE " sql_query += " AND ".join(f"{x} = :{x}" for x in match) - return await self._db.fetch_one(sql_query, match) + async with self._db.execute(sql_query, match) as cursor: + return await cursor.fetchone() async def insert( self, @@ -102,6 +103,7 @@ async def insert( sql_query = f'INSERT INTO {table}({",".join(keys)})' sql_query += f' VALUES ({",".join((f":{x}" for x in keys))})' await self.execute(sql_query, values) + await self._db.commit() # return inserted/replaced item lookup_vals = {key: value for key, value in values.items() if value not in (None, "")} return await self.get_row(table, lookup_vals) @@ -121,6 +123,7 @@ async def update( sql_query = f'UPDATE {table} SET {",".join((f"{x}=:{x}" for x in keys))} WHERE ' sql_query += " AND ".join(f"{x} = :{x}" for x in match) await self.execute(sql_query, {**match, **values}) + await self._db.commit() # return updated item return await self.get_row(table, match) @@ -134,14 +137,15 @@ async def delete(self, table: str, match: dict | None = None, query: str | None sql_query += "WHERE " + query elif query: sql_query += query - await self.execute(sql_query, match) + await self._db.commit() async def delete_where_query(self, table: str, query: str | None = None) -> None: """Delete data in given table using given where clausule.""" sql_query = f"DELETE FROM {table} WHERE {query}" await self.execute(sql_query) + await self._db.commit() - async def execute(self, query: ClauseElement | str, values: dict = None) -> Any: + async def execute(self, query: str | str, values: dict = None) -> Any: """Execute command on the database.""" return await self._db.execute(query, values) diff --git a/music_assistant/server/helpers/tags.py b/music_assistant/server/helpers/tags.py index e8b19a82b..6430bff14 100644 --- a/music_assistant/server/helpers/tags.py +++ b/music_assistant/server/helpers/tags.py @@ -343,7 +343,7 @@ async def chunk_feeder(): data = json.loads(res) if error := data.get("error"): raise InvalidDataError(error["string"]) - if not data.get("streams") or data["streams"][0].get("codec_type") == "video": + if not data.get("streams"): raise InvalidDataError("Not an audio file") tags = AudioTags.parse(data) del res diff --git a/music_assistant/server/models/music_provider.py b/music_assistant/server/models/music_provider.py index 5056921e8..0ce032203 100644 --- a/music_assistant/server/models/music_provider.py +++ b/music_assistant/server/models/music_provider.py @@ -399,14 +399,12 @@ async def sync_library(self, media_types: tuple[MediaType, ...]) -> None: controller = self.mass.music.get_controller(media_type) cur_db_ids = set() async for prov_item in self._get_library_gen(media_type): - db_item: MediaItemType - if not ( - db_item := await controller.get_db_item_by_prov_id( - prov_item.item_id, - prov_item.provider, - ) - ): + db_item = await controller.get_db_item_by_prov_mappings( + prov_item.provider_mappings, + ) + if not db_item: # create full db item + prov_item.in_library = True db_item = await controller.add(prov_item, skip_metadata_lookup=True) elif ( db_item.metadata.checksum and prov_item.metadata.checksum @@ -418,17 +416,14 @@ async def sync_library(self, media_types: tuple[MediaType, ...]) -> None: await controller.set_db_library(db_item.item_id, True) # process deletions (= no longer in library) - async for db_item in controller.iter_db_items(True): - if db_item.item_id in cur_db_ids: - continue - for prov_mapping in db_item.provider_mappings: - provider_domains = {x.provider_domain for x in db_item.provider_mappings} - if len(provider_domains) > 1: - continue - if prov_mapping.provider_instance != self.instance_id: - continue - # only mark the item as not in library and leave the metadata in db - await controller.set_db_library(db_item.item_id, False) + cache_key = f"db_items.{media_type}.{self.instance_id}" + prev_db_items: list[int] | None + if prev_db_items := await self.mass.cache.get(cache_key): + for db_id in prev_db_items: + if db_id not in cur_db_ids: + # only mark the item as not in library and leave the metadata in db + await controller.set_db_library(db_id, False) + await self.mass.cache.set(cache_key, list(cur_db_ids)) # DO NOT OVERRIDE BELOW diff --git a/music_assistant/server/providers/deezer/__init__.py b/music_assistant/server/providers/deezer/__init__.py new file mode 100644 index 000000000..36625c023 --- /dev/null +++ b/music_assistant/server/providers/deezer/__init__.py @@ -0,0 +1,645 @@ +"""Deezer music provider support for MusicAssistant.""" +import hashlib +from asyncio import TaskGroup +from collections.abc import AsyncGenerator +from math import ceil + +import deezer +from aiohttp import ClientTimeout +from asyncio_throttle.throttler import Throttler +from Crypto.Cipher import Blowfish + +from music_assistant.common.models.config_entries import ( + ConfigEntry, + ConfigValueType, + ProviderConfig, +) +from music_assistant.common.models.enums import ( + AlbumType, + ConfigEntryType, + ContentType, + ImageType, + MediaType, + ProviderFeature, +) +from music_assistant.common.models.errors import LoginFailed +from music_assistant.common.models.media_items import ( + Album, + Artist, + BrowseFolder, + ItemMapping, + MediaItemImage, + MediaItemMetadata, + Playlist, + ProviderMapping, + SearchResults, + StreamDetails, + Track, +) +from music_assistant.common.models.provider import ProviderManifest +from music_assistant.server.helpers.app_vars import app_var # pylint: disable=no-name-in-module +from music_assistant.server.helpers.auth import AuthenticationHelper +from music_assistant.server.models import ProviderInstanceType +from music_assistant.server.models.music_provider import MusicProvider +from music_assistant.server.server import MusicAssistant + +from .gw_client import GWClient +from .helpers import Credential, DeezerClient + +SUPPORTED_FEATURES = ( + ProviderFeature.LIBRARY_ARTISTS, + ProviderFeature.LIBRARY_ALBUMS, + ProviderFeature.LIBRARY_TRACKS, + ProviderFeature.LIBRARY_PLAYLISTS, + ProviderFeature.LIBRARY_ALBUMS_EDIT, + ProviderFeature.LIBRARY_TRACKS_EDIT, + ProviderFeature.LIBRARY_ARTISTS_EDIT, + ProviderFeature.LIBRARY_PLAYLISTS_EDIT, + ProviderFeature.ALBUM_METADATA, + ProviderFeature.TRACK_METADATA, + ProviderFeature.ARTIST_METADATA, + ProviderFeature.ARTIST_ALBUMS, + ProviderFeature.ARTIST_TOPTRACKS, + ProviderFeature.BROWSE, + ProviderFeature.SEARCH, + ProviderFeature.PLAYLIST_TRACKS_EDIT, + ProviderFeature.PLAYLIST_CREATE, + ProviderFeature.RECOMMENDATIONS, +) + +CONF_ACCESS_TOKEN = "access_token" +CONF_ACTION_AUTH = "auth" +DEEZER_AUTH_URL = "https://connect.deezer.com/oauth/auth.php" +RELAY_URL = "https://deezer.oauth.jonathanbangert.com/" +DEEZER_PERMS = "basic_access,email,offline_access,manage_library,\ +manage_community,delete_library,listening_history" +DEEZER_APP_ID = app_var(6) +DEEZER_APP_SECRET = app_var(7) + + +async def setup( + mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig +) -> ProviderInstanceType: + """Initialize provider(instance) with given configuration.""" + prov = DeezerProvider(mass, manifest, config) + await prov.handle_setup() + return prov + + +async def get_config_entries( + mass: MusicAssistant, + instance_id: str | None = None, # noqa: ARG001 pylint: disable=W0613 + action: str | None = None, + values: dict[str, ConfigValueType] | None = None, +) -> tuple[ConfigEntry, ...]: + """Return Config entries to setup this provider.""" + # If the action is to launch oauth flow + if action == CONF_ACTION_AUTH: + # We use the AuthenticationHelper to authenticate + async with AuthenticationHelper(mass, values["session_id"]) as auth_helper: # type: ignore + callback_url = auth_helper.callback_url + url = f"{DEEZER_AUTH_URL}?app_id={DEEZER_APP_ID}&redirect_uri={RELAY_URL}\ +&perms={DEEZER_PERMS}&state={callback_url}" + code = (await auth_helper.authenticate(url))["code"] + values[CONF_ACCESS_TOKEN] = await DeezerProvider.update_access_token( # type: ignore + DeezerProvider, DEEZER_APP_ID, DEEZER_APP_SECRET, code, mass.http_session + ) + + return ( + ConfigEntry( + key=CONF_ACCESS_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Access token", + required=True, + action=CONF_ACTION_AUTH, + description="You need to authenticate on Deezer.", + action_label="Authenticate with Deezer", + value=values.get(CONF_ACCESS_TOKEN) if values else None, + ), + ) + + +class DeezerProvider(MusicProvider): + """Deezer provider support.""" + + client: DeezerClient + gw_client: GWClient + creds: Credential + _throttler: Throttler + + async def handle_setup(self) -> None: + """Set up the Deezer provider.""" + self._throttler = Throttler(rate_limit=4, period=1) + self.creds = Credential( + app_id=DEEZER_APP_ID, + app_secret=DEEZER_APP_SECRET, + access_token=self.config.get_value(CONF_ACCESS_TOKEN), # type: ignore + ) + try: + deezer_client = await DeezerClient.get_deezer_client(self=None, creds=self.creds) + self.client = DeezerClient(creds=self.creds, client=deezer_client) + except Exception as error: + raise LoginFailed("Invalid login credentials") from error + + self.gw_client = GWClient(self.mass.http_session, self.config.get_value(CONF_ACCESS_TOKEN)) + await self.gw_client.setup() + + @property + def supported_features(self) -> tuple[ProviderFeature, ...]: + """Return the features supported by this Provider.""" + return SUPPORTED_FEATURES + + async def search( + self, search_query: str, media_types=list[MediaType] | None, limit: int = 5 + ) -> SearchResults: + """Perform search on music provider. + + :param search_query: Search query. + :param media_types: A list of media_types to include. All types if None. + """ + if not media_types: + media_types = [MediaType.ARTIST, MediaType.ALBUM, MediaType.TRACK, MediaType.PLAYLIST] + + tasks = {} + + async with TaskGroup() as taskgroup: + for media_type in media_types: + if media_type == MediaType.TRACK: + tasks[MediaType.TRACK] = taskgroup.create_task( + self.search_and_parse_tracks( + query=search_query, + limit=limit, + user_country=self.gw_client.user_country, + ) + ) + elif media_type == MediaType.ARTIST: + tasks[MediaType.ARTIST] = taskgroup.create_task( + self.search_and_parse_artists(query=search_query, limit=limit) + ) + elif media_type == MediaType.ALBUM: + tasks[MediaType.ALBUM] = taskgroup.create_task( + self.search_and_parse_albums(query=search_query, limit=limit) + ) + elif media_type == MediaType.PLAYLIST: + tasks[MediaType.PLAYLIST] = taskgroup.create_task( + self.search_and_parse_playlists(query=search_query, limit=limit) + ) + + results = SearchResults() + + for media_type, task in tasks.items(): + if media_type == MediaType.ARTIST: + results.artists = task.result() + elif media_type == MediaType.ALBUM: + results.albums = task.result() + elif media_type == MediaType.TRACK: + results.tracks = task.result() + elif media_type == MediaType.PLAYLIST: + results.playlists = task.result() + + return results + + async def get_library_artists(self) -> AsyncGenerator[Artist, None]: + """Retrieve all library artists from Deezer.""" + for artist in await self.client.get_user_artists(): + yield self.parse_artist(artist=artist) + + async def get_library_albums(self) -> AsyncGenerator[Album, None]: + """Retrieve all library albums from Deezer.""" + for album in await self.client.get_user_albums(): + yield self.parse_album(album=album) + + async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]: + """Retrieve all library playlists from Deezer.""" + for playlist in await self.client.get_user_playlists(): + yield self.parse_playlist(playlist=playlist) + + async def get_library_tracks(self) -> AsyncGenerator[Track, None]: + """Retrieve all library tracks from Deezer.""" + for track in await self.client.get_user_tracks(): + if self.track_available(track, self.gw_client.user_country): + yield self.parse_track(track=track, user_country=self.gw_client.user_country) + + async def get_artist(self, prov_artist_id: str) -> Artist: + """Get full artist details by id.""" + return self.parse_artist(artist=await self.client.get_artist(artist_id=int(prov_artist_id))) + + async def get_album(self, prov_album_id: str) -> Album: + """Get full album details by id.""" + try: + return self.parse_album(album=await self.client.get_album(album_id=int(prov_album_id))) + except deezer.exceptions.DeezerErrorResponse as error: + self.logger.warning("Failed getting album: %s", error) + return Album(prov_album_id, self.instance_id, "Not Found") + + async def get_playlist(self, prov_playlist_id: str) -> Playlist: + """Get full playlist details by id.""" + return self.parse_playlist( + playlist=await self.client.get_playlist(playlist_id=int(prov_playlist_id)), + ) + + async def get_track(self, prov_track_id: str) -> Track: + """Get full track details by id.""" + return self.parse_track( + track=await self.client.get_track(track_id=int(prov_track_id)), + user_country=self.gw_client.user_country, + ) + + async def get_album_tracks(self, prov_album_id: str) -> list[Track]: + """Get all albums in a playlist.""" + album = await self.client.get_album(album_id=int(prov_album_id)) + return [ + self.parse_track(track=track, user_country=self.gw_client.user_country) + for track in album.tracks + if self.track_available(track, self.gw_client.user_country) + ] + + async def get_playlist_tracks(self, prov_playlist_id: str) -> AsyncGenerator[Track, None]: + """Get all tracks in a playlist.""" + playlist = await self.client.get_playlist(playlist_id=prov_playlist_id) + for count, track in enumerate(playlist.tracks, start=1): + track_parsed = self.parse_track(track=track, user_country=self.gw_client.user_country) + track_parsed.position = count + track_parsed.id = track.id + yield track_parsed + + async def get_artist_albums(self, prov_artist_id: str) -> list[Album]: + """Get albums by an artist.""" + artist = await self.client.get_artist(artist_id=int(prov_artist_id)) + albums = [] + for album in await self.client.get_albums_by_artist(artist=artist): + albums.append(self.parse_album(album=album)) + return albums + + async def get_artist_toptracks(self, prov_artist_id: str) -> list[Track]: + """Get top 25 tracks of an artist.""" + artist = await self.client.get_artist(artist_id=int(prov_artist_id)) + top_tracks = (await self.client.get_artist_top(artist=artist))[:25] + return [ + self.parse_track(track=track, user_country=self.gw_client.user_country) + for track in top_tracks + if self.track_available(track, self.gw_client.user_country) + ] + + async def library_add(self, prov_item_id: str, media_type: MediaType) -> bool: + """Add an item to the library.""" + result = False + if media_type == MediaType.ARTIST: + result = await self.client.add_user_artists( + artist_id=int(prov_item_id), + ) + elif media_type == MediaType.ALBUM: + result = await self.client.add_user_albums( + album_id=int(prov_item_id), + ) + elif media_type == MediaType.TRACK: + result = await self.client.add_user_tracks( + track_id=int(prov_item_id), + ) + elif media_type == MediaType.PLAYLIST: + result = await self.client.add_user_playlists( + playlist_id=int(prov_item_id), + ) + else: + raise NotImplementedError + return result + + async def library_remove(self, prov_item_id: str, media_type: MediaType) -> bool: + """Remove an item to the library.""" + result = False + if media_type == MediaType.ARTIST: + result = await self.client.remove_user_artists( + artist_id=int(prov_item_id), + ) + elif media_type == MediaType.ALBUM: + result = await self.client.remove_user_albums( + album_id=int(prov_item_id), + ) + elif media_type == MediaType.TRACK: + result = await self.client.remove_user_tracks( + track_id=int(prov_item_id), + ) + elif media_type == MediaType.PLAYLIST: + result = await self.client.remove_user_playlists( + playlist_id=int(prov_item_id), + ) + else: + raise NotImplementedError + return result + + async def recommendations(self) -> list[BrowseFolder]: + """Get deezer's recommendations.""" + browser_folder = BrowseFolder( + item_id="recommendations", + provider=self.domain, + path="recommendations", + name="Recommendations", + label="recommendations", + items=[ + self.parse_track(track=track, user_country=self.gw_client.user_country) + for track in await self.client.get_recommended_tracks() + ], + ) + return [browser_folder] + + async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]): + """Add tra ck(s) to playlist.""" + await self.client.add_playlist_tracks( + playlist_id=prov_playlist_id, tracks=[eval(i) for i in prov_track_ids] + ) + + async def remove_playlist_tracks( + self, prov_playlist_id: str, positions_to_remove: tuple[int, ...] + ): + """Remove track(s) to playlist.""" + playlist_track_ids = [] + async for track in self.get_playlist_tracks(prov_playlist_id): + if track.position in positions_to_remove: + playlist_track_ids.append(track.id) + if len(playlist_track_ids) == len(positions_to_remove): + break + await self.client.remove_playlist_tracks( + playlist_id=prov_playlist_id, tracks=list(playlist_track_ids) + ) + + async def create_playlist(self, name: str) -> Playlist: + """Create a new playlist on provider with given name.""" + playlist = await self.client.create_playlist(playlist_name=name) + return self.parse_playlist(playlist=playlist) + + async def get_stream_details(self, item_id: str) -> StreamDetails | None: + """Return the content details for the given track when it will be streamed.""" + url_details, song_data = await self.gw_client.get_deezer_track_urls(item_id) + url = url_details["sources"][0]["url"] + return StreamDetails( + item_id=item_id, + provider=self.instance_id, + content_type=ContentType.try_parse(url_details["format"].split("_")[0]), + duration=int(song_data["DURATION"]), + data=url, + expires=url_details["exp"], + size=int(song_data[f"FILESIZE_{url_details['format']}"]), + ) + + async def get_audio_stream( + self, streamdetails: StreamDetails, seek_position: int = 0 + ) -> AsyncGenerator[bytes, None]: + """Return the audio stream for the provider item.""" + blowfish_key = self.get_blowfish_key(streamdetails.item_id) + chunk_index = 0 + timeout = ClientTimeout(total=0, connect=30, sock_read=600) + headers = {} + if seek_position and streamdetails.size: + chunk_count = ceil(streamdetails.size / 2048) + chunk_index = int(chunk_count / streamdetails.duration) * seek_position + skip_bytes = chunk_index * 2048 + headers["Range"] = f"bytes={skip_bytes}-" + + buffer = bytearray() + async with self.mass.http_session.get( + streamdetails.data, headers=headers, timeout=timeout + ) as resp: + async for chunk in resp.content.iter_chunked(2048): + buffer += chunk + if len(buffer) >= 2048: + if chunk_index % 3 > 0: + yield bytes(buffer[:2048]) + else: + yield self.decrypt_chunk(bytes(buffer[:2048]), blowfish_key) + chunk_index += 1 + del buffer[:2048] + yield bytes(buffer) + + ### PARSING METADATA FUNCTIONS ### + + def parse_metadata_track(self, track: deezer.Track) -> MediaItemMetadata: + """Parse the track metadata.""" + try: + return MediaItemMetadata( + preview=track.preview, + images=[ + MediaItemImage( + type=ImageType.THUMB, + path=track.album.cover_big, + ) + ], + ) + except AttributeError: + return MediaItemMetadata( + preview=track.preview, + ) + + def parse_metadata_album(self, album: deezer.Album) -> MediaItemMetadata: + """Parse the album metadata.""" + return MediaItemMetadata( + images=[MediaItemImage(type=ImageType.THUMB, path=album.cover_big)], + ) + + def parse_metadata_artist(self, artist: deezer.Artist) -> MediaItemMetadata: + """Parse the artist metadata.""" + return MediaItemMetadata( + images=[MediaItemImage(type=ImageType.THUMB, path=artist.picture_big)], + ) + + ### PARSING FUNCTIONS ### + def parse_artist(self, artist: deezer.Artist) -> Artist: + """Parse the deezer-python artist to a MASS artist.""" + return Artist( + item_id=str(artist.id), + provider=self.domain, + name=artist.name, + media_type=MediaType.ARTIST, + provider_mappings={ + ProviderMapping( + item_id=str(artist.id), + provider_domain=self.domain, + provider_instance=self.instance_id, + ) + }, + metadata=self.parse_metadata_artist(artist=artist), + ) + + def parse_album(self, album: deezer.Album) -> Album: + """Parse the deezer-python album to a MASS album.""" + return Album( + album_type=AlbumType(album.type), + item_id=str(album.id), + provider=self.domain, + name=album.title, + artists=[ + ItemMapping( + MediaType.ARTIST, + str(album.artist.id), + self.instance_id, + album.artist.name, + ) + ], + media_type=MediaType.ALBUM, + provider_mappings={ + ProviderMapping( + item_id=str(album.id), + provider_domain=self.domain, + provider_instance=self.instance_id, + ) + }, + metadata=self.parse_metadata_album(album=album), + ) + + def parse_playlist(self, playlist: deezer.Playlist) -> Playlist: + """Parse the deezer-python playlist to a MASS playlist.""" + return Playlist( + item_id=str(playlist.id), + provider=self.domain, + name=playlist.title, + media_type=MediaType.PLAYLIST, + provider_mappings={ + ProviderMapping( + item_id=str(playlist.id), + provider_domain=self.domain, + provider_instance=self.instance_id, + ) + }, + metadata=MediaItemMetadata( + images=[MediaItemImage(type=ImageType.THUMB, path=playlist.picture_big)], + ), + is_editable=playlist.creator.id == self.client.user.id, + ) + + def parse_track(self, track: deezer.Track, user_country: str) -> Track: + """Parse the deezer-python track to a MASS track.""" + return Track( + item_id=str(track.id), + provider=self.domain, + name=track.title, + media_type=MediaType.TRACK, + sort_name=track.title_short, + position=track.track_position, + duration=track.duration, + artists=[ + ItemMapping( + MediaType.ARTIST, + str(track.artist.id), + self.instance_id, + track.artist.name, + ) + ], + album=ItemMapping( + MediaType.ALBUM, + str(track.album.id), + self.instance_id, + track.album.title, + ), + provider_mappings={ + ProviderMapping( + item_id=str(track.id), + provider_domain=self.domain, + provider_instance=self.instance_id, + available=user_country in track.available_countries, + ) + }, + metadata=self.parse_metadata_track(track=track), + ) + + ### SEARCH AND PARSE FUNCTIONS ### + async def search_and_parse_tracks( + self, query: str, user_country: str, limit: int = 5 + ) -> list[Track]: + """Search for tracks and parse them.""" + deezer_tracks = await self.client.search_track(query=query, limit=limit) + return [ + self.parse_track(track, user_country) + for track in deezer_tracks + if self.track_available(track, user_country) + ] + + async def search_and_parse_artists(self, query: str, limit: int = 5) -> list[Artist]: + """Search for artists and parse them.""" + deezer_artist = await self.client.search_artist(query=query, limit=limit) + return [self.parse_artist(artist=artist) for artist in deezer_artist] + + async def search_and_parse_albums(self, query: str, limit: int = 5) -> list[Album]: + """Search for album and parse them.""" + deezer_albums = await self.client.search_album(query=query, limit=limit) + return [self.parse_album(album=album) for album in deezer_albums] + + async def search_and_parse_playlists(self, query: str, limit: int = 5) -> list[Playlist]: + """Search for playlists and parse them.""" + deezer_playlists = await self.client.search_playlist(query=query, limit=limit) + return [self.parse_playlist(playlist=playlist) for playlist in deezer_playlists] + + ### OTHER PARSING FUNCTIONS ### + def _get_album(self, track: deezer.Track) -> Album | None: + try: + return self.parse_album(album=track.get_album()) + except AttributeError: + return None + + ### OTHER FUNCTIONS ### + async def update_access_token(self, app_id, app_secret, code, http_session=None) -> str: + """Update the access_token.""" + if not http_session: + http_session = self.mass.http_session + response = await self._post_http( # pylint: disable=E1124 + self=self, + http_session=http_session, + url="https://connect.deezer.com/oauth/access_token.php", + data={ + "code": code, + "app_id": app_id, + "secret": app_secret, + }, + params={ + "code": code, + "app_id": app_id, + "secret": app_secret, + }, + headers=None, + ) + try: + return response.split("=")[1].split("&")[0] + except Exception as error: + raise LoginFailed("Invalid auth code") from error + + async def _post_http(self, http_session, url, data, params=None, headers=None) -> str: + async with http_session.post( + url, headers=headers, params=params, json=data, ssl=False + ) as response: + if response.status != 200: + raise ConnectionError(f"HTTP Error {response.status}: {response.reason}") + response_text = await response.text() + return response_text + + async def get_track_content_type(self, gw_client: GWClient, track_id: int): + """Get a tracks contentType.""" + song_data = await gw_client.get_song_data(track_id) + if song_data["results"]["FILESIZE_FLAC"]: + return ContentType.FLAC + + if song_data["results"]["FILESIZE_MP3_320"] or song_data["results"]["FILESIZE_MP3_128"]: + return ContentType.MP3 + + raise NotImplementedError("Unsupported contenttype") + + def track_available(self, track: deezer.Track, user_country: str) -> bool: + """Check if a given track is available in the users country.""" + return user_country in track.available_countries + + def _md5(self, data, data_type="ascii"): + md5sum = hashlib.md5() + md5sum.update(data.encode(data_type)) + return md5sum.hexdigest() + + def get_blowfish_key(self, track_id): + """Get blowfish key to decrypt a chunk of a track.""" + secret = "g4el58wc" + "0zvf9na1" + id_md5 = self._md5(track_id) + return "".join( + chr(ord(id_md5[i]) ^ ord(id_md5[i + 16]) ^ ord(secret[i])) for i in range(16) + ) + + def decrypt_chunk(self, chunk, blowfish_key): + """Decrypt a given chunk using the blow fish key.""" + cipher = Blowfish.new( + blowfish_key.encode("ascii"), Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07" + ) + return cipher.decrypt(chunk) diff --git a/music_assistant/server/providers/deezer/gw_client.py b/music_assistant/server/providers/deezer/gw_client.py new file mode 100644 index 000000000..97e69896a --- /dev/null +++ b/music_assistant/server/providers/deezer/gw_client.py @@ -0,0 +1,153 @@ +"""A minimal client for the unofficial gw-API, which deezer is using on their website and app. + +Credits go out to RemixDev (https://gitlab.com/RemixDev) for figuring out, how to get the arl +cookie based on the api_token. +""" +import datetime +from http.cookies import BaseCookie, Morsel + +from aiohttp import ClientSession +from yarl import URL + +USER_AGENT_HEADER = ( + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/79.0.3945.130 Safari/537.36" +) + +GW_LIGHT_URL = "https://www.deezer.com/ajax/gw-light.php" + + +class DeezerGWError(BaseException): + """Exception type for GWClient related exceptions.""" + + pass + + +class GWClient: + """The GWClient class can be used to perform actions not being of the official API.""" + + _api_token: str + _gw_csrf_token: str | None + _license: str | None + _license_expiration_timestamp: int + session: ClientSession + formats: list[dict[str, str]] = [ + {"cipher": "BF_CBC_STRIPE", "format": "MP3_128"}, + ] + user_country: str + + def __init__(self, session: ClientSession, api_token: str): + """Provide an aiohttp ClientSession and the deezer api_token.""" + self._api_token = api_token + self.session = session + + async def _get_cookie(self): + await self.session.get( + "https://api.deezer.com/platform/generic/track/3135556", + headers={"Authorization": f"Bearer {self._api_token}", "User-Agent": USER_AGENT_HEADER}, + ) + json_response = await self._gw_api_call("user.getArl", False, http_method="GET") + arl = json_response.get("results") + + cookie = Morsel() + + cookie.set("arl", arl, arl) + cookie.domain = ".deezer.com" + cookie.path = "/" + cookie.httponly = {"HttpOnly": True} + + self.session.cookie_jar.update_cookies(BaseCookie({"arl": cookie}), URL(GW_LIGHT_URL)) + + async def _update_user_data(self): + user_data = await self._gw_api_call("deezer.getUserData", False) + if not user_data["results"]["USER"]["USER_ID"]: + await self._get_cookie() + user_data = await self._gw_api_call("deezer.getUserData", False) + + if not user_data["results"]["OFFER_ID"]: + raise DeezerGWError("Free subscriptions cannot be used in MA.") + + self._gw_csrf_token = user_data["results"]["checkForm"] + self._license = user_data["results"]["USER"]["OPTIONS"]["license_token"] + self._license_expiration_timestamp = user_data["results"]["USER"]["OPTIONS"][ + "expiration_timestamp" + ] + web_qualities = user_data["results"]["USER"]["OPTIONS"]["web_sound_quality"] + mobile_qualities = user_data["results"]["USER"]["OPTIONS"]["mobile_sound_quality"] + if web_qualities["high"] or mobile_qualities["high"]: + self.formats.insert(0, {"cipher": "BF_CBC_STRIPE", "format": "MP3_320"}) + if web_qualities["lossless"] or mobile_qualities["lossless"]: + self.formats.insert(0, {"cipher": "BF_CBC_STRIPE", "format": "FLAC"}) + + self.user_country = user_data["results"]["COUNTRY"] + + async def setup(self): + """Call this to let the client get its cookies, license and tokens.""" + await self._get_cookie() + await self._update_user_data() + + async def _get_license(self): + if ( + self._license_expiration_timestamp + < (datetime.datetime.now() + datetime.timedelta(days=1)).timestamp() + ): + await self._update_user_data() + return self._license + + async def _gw_api_call( + self, method, use_csrf_token=True, args=None, params=None, http_method="POST", retry=True + ): + csrf_token = self._gw_csrf_token if use_csrf_token else "null" + if params is None: + params = {} + p = {"api_version": "1.0", "api_token": csrf_token, "input": "3", "method": method} + p.update(params) + result = await self.session.request( + http_method, + GW_LIGHT_URL, + params=p, + timeout=30, + json=args, + headers={"User-Agent": USER_AGENT_HEADER}, + ) + result_json = await result.json() + if result_json["error"]: + if retry: + await self._update_user_data() + return await self._gw_api_call( + method, use_csrf_token, args, params, http_method, False + ) + else: + raise DeezerGWError("Failed to call GW-API", result_json["error"]) + return result_json + + async def get_song_data(self, track_id): + """Get data such as the track token for a given track.""" + return await self._gw_api_call("song.getData", args={"SNG_ID": track_id}) + + async def get_deezer_track_urls(self, track_id): + """Get the URL for a given track id.""" + dz_license = await self._get_license() + song_data = await self.get_song_data(track_id) + track_token = song_data["results"]["TRACK_TOKEN"] + url_data = { + "license_token": dz_license, + "media": [ + { + "type": "FULL", + "formats": self.formats, + } + ], + "track_tokens": [track_token], + } + url_response = await self.session.post( + "https://media.deezer.com/v1/get_url", + json=url_data, + headers={"User-Agent": USER_AGENT_HEADER}, + ) + result_json = await url_response.json() + + if error := result_json["data"][0].get("errors"): + raise DeezerGWError("Received an error from API", error) + + return result_json["data"][0]["media"][0], song_data["results"] diff --git a/music_assistant/server/providers/deezer/helpers.py b/music_assistant/server/providers/deezer/helpers.py new file mode 100644 index 000000000..01b96356b --- /dev/null +++ b/music_assistant/server/providers/deezer/helpers.py @@ -0,0 +1,322 @@ +"""Helper module for parsing the Deezer API. Also helper for getting audio streams. + +This helpers file is an async wrapper around the excellent deezer-python package. +While the deezer-python package does an excellent job at parsing the Deezer results, +it is unfortunately not async, which is required for Music Assistant to run smoothly. +This also nicely separates the parsing logic from the Deezer provider logic. + +CREDITS: +deezer-python: https://github.com/browniebroke/deezer-python by @browniebroke +""" + +import asyncio +from dataclasses import dataclass + +import deezer + + +@dataclass +class Credential: + """Class for storing credentials.""" + + def __init__(self, app_id: int, app_secret: str, access_token: str): + """Set the correct things.""" + self.app_id = app_id + self.app_secret = app_secret + self.access_token = access_token + + app_id: int + app_secret: str + access_token: str + + +class DeezerClient: + """Async wrapper of the deezer-python library.""" + + _client: deezer.Client + _creds: Credential + user: deezer.User + + def __init__(self, creds: Credential, client: deezer.Client): + """Initialize the client.""" + self._creds = creds + self._client = client + self.user = self._client.get_user() + + async def get_deezer_client(self, creds: Credential) -> deezer.Client: # type: ignore + """ + Return a deezer-python Client. + + If credentials are given the client is authorized. + If no credentials are given the deezer client is not authorized. + + :param creds: Credentials. If none are given client is not authorized, defaults to None + :type creds: credential, optional + """ + if not isinstance(creds, Credential): + raise TypeError("Creds must be of type credential") + + def _authorize(): + return deezer.Client( + app_id=creds.app_id, app_secret=creds.app_secret, access_token=creds.access_token + ) + + return await asyncio.to_thread(_authorize) + + async def get_artist(self, artist_id: int) -> deezer.Artist: + """Async wrapper of the deezer-python get_artist function.""" + + def _get_artist(): + artist = self._client.get_artist(artist_id=artist_id) + return artist + + return await asyncio.to_thread(_get_artist) + + async def get_album(self, album_id: int) -> deezer.Album: + """Async wrapper of the deezer-python get_album function.""" + + def _get_album(): + album = self._client.get_album(album_id=album_id) + return album + + return await asyncio.to_thread(_get_album) + + async def get_playlist(self, playlist_id) -> deezer.Playlist: + """Async wrapper of the deezer-python get_playlist function.""" + + def _get_playlist(): + playlist = self._client.get_playlist(playlist_id=playlist_id) + return playlist + + return await asyncio.to_thread(_get_playlist) + + async def get_track(self, track_id: int) -> deezer.Track: + """Async wrapper of the deezer-python get_track function.""" + + def _get_track(): + track = self._client.get_track(track_id=track_id) + return track + + return await asyncio.to_thread(_get_track) + + async def get_user_artists(self) -> deezer.PaginatedList: + """Async wrapper of the deezer-python get_user_artists function.""" + + def _get_artist(): + artists = self._client.get_user_artists() + return artists + + return await asyncio.to_thread(_get_artist) + + async def get_user_playlists(self) -> deezer.PaginatedList: + """Async wrapper of the deezer-python get_user_playlists function.""" + + def _get_playlist(): + playlists = self._client.get_user().get_playlists() + return playlists + + return await asyncio.to_thread(_get_playlist) + + async def get_user_albums(self) -> deezer.PaginatedList: + """Async wrapper of the deezer-python get_user_albums function.""" + + def _get_album(): + albums = self._client.get_user_albums() + return albums + + return await asyncio.to_thread(_get_album) + + async def get_user_tracks(self) -> deezer.PaginatedList: + """Async wrapper of the deezer-python get_user_tracks function.""" + + def _get_track(): + tracks = self._client.get_user_tracks() + return tracks + + return await asyncio.to_thread(_get_track) + + async def add_user_albums(self, album_id: int) -> bool: + """Async wrapper of the deezer-python add_user_albums function.""" + + def _get_track(): + success = self._client.add_user_album(album_id=album_id) + return success + + return await asyncio.to_thread(_get_track) + + async def remove_user_albums(self, album_id: int) -> bool: + """Async wrapper of the deezer-python remove_user_albums function.""" + + def _get_track(): + success = self._client.remove_user_album(album_id=album_id) + return success + + return await asyncio.to_thread(_get_track) + + async def add_user_tracks(self, track_id: int) -> bool: + """Async wrapper of the deezer-python add_user_tracks function.""" + + def _get_track(): + success = self._client.add_user_track(track_id=track_id) + return success + + return await asyncio.to_thread(_get_track) + + async def remove_user_tracks(self, track_id: int) -> bool: + """Async wrapper of the deezer-python remove_user_tracks function.""" + + def _get_track(): + success = self._client.remove_user_track(track_id=track_id) + return success + + return await asyncio.to_thread(_get_track) + + async def add_user_artists(self, artist_id: int) -> bool: + """Async wrapper of the deezer-python add_user_artists function.""" + + def _get_artist(): + success = self._client.add_user_artist(artist_id=artist_id) + return success + + return await asyncio.to_thread(_get_artist) + + async def remove_user_artists(self, artist_id: int) -> bool: + """Async wrapper of the deezer-python remove_user_artists function.""" + + def _get_artist(): + success = self._client.remove_user_artist(artist_id=artist_id) + return success + + return await asyncio.to_thread(_get_artist) + + async def add_user_playlists(self, playlist_id: int) -> bool: + """Async wrapper of the deezer-python add_user_playlists function.""" + + def _get_playlist(): + success = self._client.add_user_playlist(playlist_id=playlist_id) + return success + + return await asyncio.to_thread(_get_playlist) + + async def remove_user_playlists(self, playlist_id: int) -> bool: + """Async wrapper of the deezer-python remove_user_playlists function.""" + + def _get_playlist(): + success = self._client.remove_user_playlist(playlist_id=playlist_id) + return success + + return await asyncio.to_thread(_get_playlist) + + async def search_album(self, query: str, limit: int = 5) -> list[deezer.Album]: + """Async wrapper of the deezer-python search_albums function.""" + + def _search(): + result = self._client.search_albums(query=query)[:limit] + return result + + return await asyncio.to_thread(_search) + + async def search_track(self, query: str, limit: int = 5) -> list[deezer.Track]: + """Async wrapper of the deezer-python search function.""" + + def _search(): + result = self._client.search(query=query)[:limit] + return result + + return await asyncio.to_thread(_search) + + async def search_artist(self, query: str, limit: int = 5) -> list[deezer.Artist]: + """Async wrapper of the deezer-python search_artist function.""" + + def _search(): + result = self._client.search_artists(query=query)[:limit] + return result + + return await asyncio.to_thread(_search) + + async def search_playlist(self, query: str, limit: int = 5) -> list[deezer.Playlist]: + """Async wrapper of the deezer-python search_playlist function.""" + + def _search(): + result = self._client.search_playlists(query=query)[:limit] + return result + + return await asyncio.to_thread(_search) + + async def get_album_from_track(self, track: deezer.Track) -> deezer.Album: + """Get track's artist.""" + + def _get_album_from_track(): + try: + return track.get_album() + except deezer.exceptions.DeezerErrorResponse: + return None + + return await asyncio.to_thread(_get_album_from_track) + + async def get_artist_from_track(self, track: deezer.Track) -> deezer.Artist: + """Get track's artist.""" + + def _get_artist_from_track(): + return track.get_artist() + + return await asyncio.to_thread(_get_artist_from_track) + + async def get_artist_from_album(self, album: deezer.Album) -> deezer.Artist: + """Get track's artist.""" + + def _get_artist_from_album(): + return album.get_artist() + + return await asyncio.to_thread(_get_artist_from_album) + + async def get_albums_by_artist(self, artist: deezer.Artist) -> deezer.PaginatedList: + """Get albums by an artist.""" + + def _get_albums_by_artist(): + return artist.get_albums() + + return await asyncio.to_thread(_get_albums_by_artist) + + async def get_artist_top(self, artist: deezer.Artist) -> deezer.PaginatedList: + """Get top tracks by an artist.""" + + def _get_artist_top(): + return artist.get_top() + + return await asyncio.to_thread(_get_artist_top) + + async def get_recommended_tracks(self) -> deezer.PaginatedList: + """Get recommended tracks for user.""" + + def _get_recommended_tracks(): + return self._client.get_user_recommended_tracks() + + return await asyncio.to_thread(_get_recommended_tracks) + + async def create_playlist(self, playlist_name) -> deezer.Playlist: + """Create a playlist on deezer.""" + + def _create_playlist(): + playlist_id = self._client.create_playlist(playlist_name=playlist_name) + return self._client.get_playlist(playlist_id=playlist_id) + + return await asyncio.to_thread(_create_playlist) + + async def add_playlist_tracks(self, playlist_id: int, tracks: list[int]): + """Add tracks to playlist.""" + + def _add_playlist_tracks(): + playlist = self._client.get_playlist(playlist_id=playlist_id) + playlist.add_tracks(tracks=tracks) + + return await asyncio.to_thread(_add_playlist_tracks) + + async def remove_playlist_tracks(self, playlist_id: int, tracks: list[int]): + """Remove tracks from playlist.""" + + def _remove_playlist_tracks(): + playlist = self._client.get_playlist(playlist_id=playlist_id) + playlist.delete_tracks(tracks=tracks) + + return await asyncio.to_thread(_remove_playlist_tracks) diff --git a/music_assistant/server/providers/deezer/icon.svg b/music_assistant/server/providers/deezer/icon.svg new file mode 100644 index 000000000..0704de6cc --- /dev/null +++ b/music_assistant/server/providers/deezer/icon.svg @@ -0,0 +1,53 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/music_assistant/server/providers/deezer/manifest.json b/music_assistant/server/providers/deezer/manifest.json new file mode 100644 index 000000000..1ce7bc07f --- /dev/null +++ b/music_assistant/server/providers/deezer/manifest.json @@ -0,0 +1,9 @@ +{ + "type": "music", + "domain": "deezer", + "name": "Deezer", + "description": "Support for the Deezer streaming provider in Music Assistant.", + "codeowners": ["@Un10ck3d", "@micha91"], + "requirements": ["deezer-python==5.10.0", "pycryptodome==3.17"], + "multi_instance": true +} diff --git a/music_assistant/server/providers/plex/helpers.py b/music_assistant/server/providers/plex/helpers.py index 15461d957..87f514a22 100644 --- a/music_assistant/server/providers/plex/helpers.py +++ b/music_assistant/server/providers/plex/helpers.py @@ -26,9 +26,11 @@ def _get_libraries(): # create a listing of available music libraries on all servers all_libraries: list[str] = [] plex_account = MyPlexAccount(token=auth_token) - for server_resource in plex_account.resources(): + for resource in plex_account.resources(): + if "server" not in resource.provides: + continue try: - plex_server: PlexServer = server_resource.connect(None, 10) + plex_server: PlexServer = resource.connect(None, 10) except plexapi.exceptions.NotFound: continue for media_section in plex_server.library.sections(): @@ -36,7 +38,7 @@ def _get_libraries(): if media_section.type != PlexMusicSection.TYPE: continue # TODO: figure out what plex uses as stable id and use that instead of names - all_libraries.append(f"{server_resource.name} / {media_section.title}") + all_libraries.append(f"{resource.name} / {media_section.title}") return all_libraries if cache := await mass.cache.get(cache_key, checksum=auth_token): diff --git a/music_assistant/server/providers/qobuz/__init__.py b/music_assistant/server/providers/qobuz/__init__.py index 0c0bf172a..3d843b5c2 100644 --- a/music_assistant/server/providers/qobuz/__init__.py +++ b/music_assistant/server/providers/qobuz/__init__.py @@ -486,7 +486,7 @@ async def _parse_album(self, album_obj: dict, artist_obj: dict = None): album.barcode.add(album_obj["upc"]) if "label" in album_obj: album.metadata.label = album_obj["label"]["name"] - if album_obj.get("released_at"): + if (released_at := album_obj.get("released_at")) and released_at != 0: album.year = datetime.datetime.fromtimestamp(album_obj["released_at"]).year if album_obj.get("copyright"): album.metadata.copyright = album_obj["copyright"] diff --git a/music_assistant/server/providers/radiobrowser/__init__.py b/music_assistant/server/providers/radiobrowser/__init__.py index 15456565d..aff95fb46 100644 --- a/music_assistant/server/providers/radiobrowser/__init__.py +++ b/music_assistant/server/providers/radiobrowser/__init__.py @@ -193,9 +193,9 @@ async def browse(self, path: str) -> BrowseFolder: sub_items: list[BrowseFolder] = [] for country in await self.radios.countries(order=Order.NAME): folder = BrowseFolder( - item_id=country.name.lower(), + item_id=country.code.lower(), provider=self.domain, - path=path + "/" + country.name.lower(), + path=path + "/" + country.code.lower(), name="", label=country.name, ) @@ -220,7 +220,7 @@ async def browse(self, path: str) -> BrowseFolder: items=[x for x in await self.get_by_tag(subsubpath)], ) - if subsubpath in await self.get_country_names(): + if subsubpath in await self.get_country_codes(): return BrowseFolder( item_id="radios", provider=self.domain, @@ -244,13 +244,13 @@ async def get_tag_names(self): tag_names.append(tag.name.lower()) return tag_names - async def get_country_names(self): + async def get_country_codes(self): """Get a list of country names.""" countries = await self.radios.countries(order=Order.NAME) - country_names = [] + country_codes = [] for country in countries: - country_names.append(country.name.lower()) - return country_names + country_codes.append(country.code.lower()) + return country_codes async def get_by_popularity(self): """Get radio stations by popularity.""" @@ -279,12 +279,12 @@ async def get_by_tag(self, tag: str): items.append(await self._parse_radio(station)) return items - async def get_by_country(self, country: str): + async def get_by_country(self, country_code: str): """Get radio stations by country.""" items = [] stations = await self.radios.stations( - filter_by=FilterBy.COUNTRY_EXACT, - filter_term=country, + filter_by=FilterBy.COUNTRY_CODE_EXACT, + filter_term=country_code, hide_broken=True, order=Order.NAME, reverse=False, diff --git a/music_assistant/server/providers/sonos/__init__.py b/music_assistant/server/providers/sonos/__init__.py index b58cc66d3..c98dc32f6 100644 --- a/music_assistant/server/providers/sonos/__init__.py +++ b/music_assistant/server/providers/sonos/__init__.py @@ -405,9 +405,7 @@ async def _run_discovery(self) -> None: try: self._discovery_running = True self.logger.debug("Sonos discovery started...") - discovered_devices: set[soco.SoCo] = await asyncio.to_thread( - soco.discover, 30, allow_network_scan=True - ) + discovered_devices: set[soco.SoCo] = await asyncio.to_thread(soco.discover, 60) if discovered_devices is None: discovered_devices = set() new_device_ids = {item.uid for item in discovered_devices} diff --git a/music_assistant/server/providers/spotify/__init__.py b/music_assistant/server/providers/spotify/__init__.py index 4705f9036..908c27e4f 100644 --- a/music_assistant/server/providers/spotify/__init__.py +++ b/music_assistant/server/providers/spotify/__init__.py @@ -537,7 +537,7 @@ async def login(self) -> dict: if ( self._auth_token and os.path.isdir(self._cache_dir) - and (self._auth_token["expiresAt"] > int(time.time()) + 20) + and (self._auth_token["expiresAt"] > int(time.time()) + 600) ): return self._auth_token tokeninfo, userinfo = None, self._sp_user diff --git a/music_assistant/server/providers/tidal/__init__.py b/music_assistant/server/providers/tidal/__init__.py index 278d3f9c0..53b164399 100644 --- a/music_assistant/server/providers/tidal/__init__.py +++ b/music_assistant/server/providers/tidal/__init__.py @@ -14,6 +14,7 @@ from tidalapi import Quality as TidalQuality from tidalapi import Session as TidalSession from tidalapi import Track as TidalTrack +from tidalapi.media import Lyrics as TidalLyrics from music_assistant.common.helpers.uri import create_uri from music_assistant.common.helpers.util import create_sort_name @@ -273,36 +274,26 @@ async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]: async def get_album_tracks(self, prov_album_id: str) -> list[Track]: """Get album tracks for given album id.""" tidal_session = await self._get_tidal_session() - result = [] - tracks = await get_album_tracks(tidal_session, prov_album_id) - for index, track_obj in enumerate(tracks, 1): - if track_obj.available: - track = await self._parse_track(track_obj=track_obj) - track.position = index - result.append(track) - return result + return [ + await self._parse_track(track_obj=track) + for track in await get_album_tracks(tidal_session, prov_album_id) + ] async def get_artist_albums(self, prov_artist_id: str) -> list[Album]: """Get a list of all albums for the given artist.""" tidal_session = await self._get_tidal_session() - result = [] - albums = await get_artist_albums(tidal_session, prov_artist_id) - for album_obj in albums: - album = await self._parse_album(album_obj=album_obj) - result.append(album) - return result + return [ + await self._parse_album(album_obj=album) + for album in await get_artist_albums(tidal_session, prov_artist_id) + ] async def get_artist_toptracks(self, prov_artist_id: str) -> list[Track]: """Get a list of 10 most popular tracks for the given artist.""" tidal_session = await self._get_tidal_session() - result = [] - tracks = await get_artist_toptracks(tidal_session, prov_artist_id) - for index, track_obj in enumerate(tracks, 1): - if track_obj.available: - track = await self._parse_track(track_obj=track_obj) - track.position = index - result.append(track) - return result + return [ + await self._parse_track(track_obj=track) + for track in await get_artist_toptracks(tidal_session, prov_artist_id) + ] async def get_playlist_tracks(self, prov_playlist_id: str) -> AsyncGenerator[Track, None]: """Get all playlist tracks for given playlist id.""" @@ -320,13 +311,10 @@ async def get_playlist_tracks(self, prov_playlist_id: str) -> AsyncGenerator[Tra async def get_similar_tracks(self, prov_track_id: str, limit=25) -> list[Track]: """Get similar tracks for given track id.""" tidal_session = await self._get_tidal_session() - similar_tracks_obj = await get_similar_tracks(tidal_session, prov_track_id, limit) - tracks = [] - for track_obj in similar_tracks_obj: - if track_obj.available: - track = await self._parse_track(track_obj=track_obj) - tracks.append(track) - return tracks + return [ + await self._parse_track(track_obj=track) + for track in await get_similar_tracks(tidal_session, prov_track_id, limit) + ] async def library_add(self, prov_item_id: str, media_type: MediaType): """Add item to library.""" @@ -370,8 +358,7 @@ async def create_playlist(self, name: str) -> Playlist: """Create a new playlist on provider with given name.""" tidal_session = await self._get_tidal_session() playlist_obj = await create_playlist(tidal_session, self._tidal_user_id, name) - playlist = await self._parse_playlist(playlist_obj=playlist_obj) - return playlist + return await self._parse_playlist(playlist_obj=playlist_obj) async def get_stream_details(self, item_id: str) -> StreamDetails: """Return the content details for the given track when it will be streamed.""" @@ -506,19 +493,18 @@ async def _parse_artist(self, artist_obj: TidalArtist, full_details: bool = Fals ) ) # metadata - if full_details: - image_url = None - if artist_obj.name != "Various Artists": - try: - image_url = await asyncio.to_thread(artist_obj.image(750)) - except Exception: - self.logger.info(f"Artist {artist_obj.id} has no available picture") - artist.metadata.images = [ - MediaItemImage( - ImageType.THUMB, - image_url, - ) - ] + if full_details and artist_obj.name != "Various Artists": + try: + image_url = await self._get_image_url(artist_obj, 750) + artist.metadata.images = [ + MediaItemImage( + ImageType.THUMB, + image_url, + ) + ] + except Exception: + self.logger.info(f"Artist {artist_obj.id} has no available picture") + return artist async def _parse_album(self, album_obj: TidalAlbum, full_details: bool = False) -> Album: @@ -554,17 +540,17 @@ async def _parse_album(self, album_obj: TidalAlbum, full_details: bool = False) album.metadata.explicit = album_obj.explicit album.metadata.popularity = album_obj.popularity if full_details: - image_url = None try: - image_url = await asyncio.to_thread(album_obj.image(1280)) + image_url = await self._get_image_url(album_obj, 1280) + album.metadata.images = [ + MediaItemImage( + ImageType.THUMB, + image_url, + ) + ] except Exception: self.logger.info(f"Album {album_obj.id} has no available picture") - album.metadata.images = [ - MediaItemImage( - ImageType.THUMB, - image_url, - ) - ] + return album async def _parse_track(self, track_obj: TidalTrack, full_details: bool = False) -> Track: @@ -609,7 +595,7 @@ async def _parse_track(self, track_obj: TidalTrack, full_details: bool = False) track.metadata.copyright = track_obj.copyright if full_details: try: - if lyrics_obj := await asyncio.to_thread(track_obj.lyrics): + if lyrics_obj := await self._get_lyrics(track_obj): track.metadata.lyrics = lyrics_obj.text except Exception: self.logger.info(f"Track {track_obj.id} has no available lyrics") @@ -642,16 +628,27 @@ async def _parse_playlist( playlist.metadata.checksum = str(playlist_obj.last_updated) playlist.metadata.popularity = playlist_obj.popularity if full_details: - image_url = None try: - image_url = await asyncio.to_thread(playlist_obj.image(1080)) + image_url = await self._get_image_url(playlist_obj, 1080) + playlist.metadata.images = [ + MediaItemImage( + ImageType.THUMB, + image_url, + ) + ] except Exception: self.logger.info(f"Playlist {playlist_obj.id} has no available picture") - playlist.metadata.images = [ - MediaItemImage( - ImageType.THUMB, - image_url, - ) - ] return playlist + + async def _get_image_url(self, item, size: int): + def inner() -> str: + return item.image(size) + + return await asyncio.to_thread(inner) + + async def _get_lyrics(self, item): + def inner() -> TidalLyrics: + return item.lyrics + + return await asyncio.to_thread(inner) diff --git a/music_assistant/server/providers/ytmusic/__init__.py b/music_assistant/server/providers/ytmusic/__init__.py index ce65e3f96..2438055ba 100644 --- a/music_assistant/server/providers/ytmusic/__init__.py +++ b/music_assistant/server/providers/ytmusic/__init__.py @@ -9,7 +9,6 @@ from urllib.parse import unquote import pytube -import ytmusicapi from music_assistant.common.helpers.uri import create_uri from music_assistant.common.helpers.util import create_sort_name @@ -36,7 +35,7 @@ StreamDetails, Track, ) -from music_assistant.constants import CONF_USERNAME +from music_assistant.server.helpers.auth import AuthenticationHelper from music_assistant.server.models.music_provider import MusicProvider from .helpers import ( @@ -53,6 +52,8 @@ library_add_remove_album, library_add_remove_artist, library_add_remove_playlist, + login_oauth, + refresh_oauth_token, search, ) @@ -64,6 +65,11 @@ CONF_COOKIE = "cookie" +CONF_ACTION_AUTH = "auth" +CONF_AUTH_TOKEN = "auth_token" +CONF_REFRESH_TOKEN = "refresh_token" +CONF_TOKEN_TYPE = "token_type" +CONF_EXPIRY_TIME = "expiry_time" YT_DOMAIN = "https://www.youtube.com" YTM_DOMAIN = "https://music.youtube.com" @@ -97,7 +103,7 @@ async def setup( async def get_config_entries( mass: MusicAssistant, - instance_id: str | None = None, + instance_id: str | None = None, # noqa: ARG001 action: str | None = None, values: dict[str, ConfigValueType] | None = None, ) -> tuple[ConfigEntry, ...]: @@ -108,18 +114,44 @@ async def get_config_entries( action: [optional] action key called from config entries UI. values: the (intermediate) raw values for config entries sent with the action. """ - # ruff: noqa: ARG001 + if action == CONF_ACTION_AUTH: + async with AuthenticationHelper(mass, values["session_id"]) as auth_helper: + token = await login_oauth(auth_helper) + values[CONF_AUTH_TOKEN] = token["access_token"] + values[CONF_REFRESH_TOKEN] = token["refresh_token"] + values[CONF_EXPIRY_TIME] = token["expires_in"] + values[CONF_TOKEN_TYPE] = token["token_type"] + # return the collected config entries return ( ConfigEntry( - key=CONF_USERNAME, type=ConfigEntryType.STRING, label="Username", required=True + key=CONF_AUTH_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Authentication token for Youtube Music", + description="You need to link Music Assistant to your Youtube Music account.", + action=CONF_ACTION_AUTH, + action_label="Authenticate on Youtube Music", + value=values.get(CONF_AUTH_TOKEN) if values else None, ), ConfigEntry( - key=CONF_COOKIE, + key=CONF_REFRESH_TOKEN, type=ConfigEntryType.SECURE_STRING, - label="Login Cookie", - required=True, - description="The Login cookie you grabbed from an existing session, " - "see the documentation.", + label=CONF_REFRESH_TOKEN, + hidden=True, + value=values.get(CONF_REFRESH_TOKEN) if values else None, + ), + ConfigEntry( + key=CONF_EXPIRY_TIME, + type=ConfigEntryType.INTEGER, + label="Expiry time of auth token for Youtube Music", + hidden=True, + value=values.get(CONF_EXPIRY_TIME) if values else None, + ), + ConfigEntry( + key=CONF_TOKEN_TYPE, + type=ConfigEntryType.STRING, + label="The token type required to create headers", + hidden=True, + value=values.get(CONF_TOKEN_TYPE) if values else None, ), ) @@ -135,9 +167,9 @@ class YoutubeMusicProvider(MusicProvider): async def handle_setup(self) -> None: """Set up the YTMusic provider.""" - if not self.config.get_value(CONF_USERNAME) or not self.config.get_value(CONF_COOKIE): + if not self.config.get_value(CONF_AUTH_TOKEN): raise LoginFailed("Invalid login credentials") - await self._initialize_headers(cookie=self.config.get_value(CONF_COOKIE)) + await self._initialize_headers() await self._initialize_context() self._cookies = {"CONSENT": "YES+1"} self._signature_timestamp = await self._get_signature_timestamp() @@ -185,32 +217,36 @@ async def search( async def get_library_artists(self) -> AsyncGenerator[Artist, None]: """Retrieve all library artists from Youtube Music.""" + await self._check_oauth_token() artists_obj = await get_library_artists( - headers=self._headers, username=self.config.get_value(CONF_USERNAME) + headers=self._headers, ) for artist in artists_obj: yield await self._parse_artist(artist) async def get_library_albums(self) -> AsyncGenerator[Album, None]: """Retrieve all library albums from Youtube Music.""" + await self._check_oauth_token() albums_obj = await get_library_albums( - headers=self._headers, username=self.config.get_value(CONF_USERNAME) + headers=self._headers, ) for album in albums_obj: yield await self._parse_album(album, album["browseId"]) async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]: """Retrieve all library playlists from the provider.""" + await self._check_oauth_token() playlists_obj = await get_library_playlists( - headers=self._headers, username=self.config.get_value(CONF_USERNAME) + headers=self._headers, ) for playlist in playlists_obj: yield await self._parse_playlist(playlist) async def get_library_tracks(self) -> AsyncGenerator[Track, None]: """Retrieve library tracks from Youtube Music.""" + await self._check_oauth_token() tracks_obj = await get_library_tracks( - headers=self._headers, username=self.config.get_value(CONF_USERNAME) + headers=self._headers, ) for track in tracks_obj: # Library tracks sometimes do not have a valid artist id @@ -257,21 +293,17 @@ async def get_track(self, prov_track_id) -> Track: async def get_playlist(self, prov_playlist_id) -> Playlist: """Get full playlist details by id.""" + await self._check_oauth_token() if playlist_obj := await get_playlist( - prov_playlist_id=prov_playlist_id, - headers=self._headers, - username=self.config.get_value(CONF_USERNAME), + prov_playlist_id=prov_playlist_id, headers=self._headers ): return await self._parse_playlist(playlist_obj) raise MediaNotFoundError(f"Item {prov_playlist_id} not found") async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]: """Get all playlist tracks for given playlist id.""" - playlist_obj = await get_playlist( - prov_playlist_id=prov_playlist_id, - headers=self._headers, - username=self.config.get_value(CONF_USERNAME), - ) + await self._check_oauth_token() + playlist_obj = await get_playlist(prov_playlist_id=prov_playlist_id, headers=self._headers) if "tracks" not in playlist_obj: return for index, track in enumerate(playlist_obj["tracks"]): @@ -316,27 +348,19 @@ async def get_artist_toptracks(self, prov_artist_id) -> list[Track]: async def library_add(self, prov_item_id, media_type: MediaType) -> None: """Add an item to the library.""" + await self._check_oauth_token() result = False if media_type == MediaType.ARTIST: result = await library_add_remove_artist( - headers=self._headers, - prov_artist_id=prov_item_id, - add=True, - username=self.config.get_value(CONF_USERNAME), + headers=self._headers, prov_artist_id=prov_item_id, add=True ) elif media_type == MediaType.ALBUM: result = await library_add_remove_album( - headers=self._headers, - prov_item_id=prov_item_id, - add=True, - username=self.config.get_value(CONF_USERNAME), + headers=self._headers, prov_item_id=prov_item_id, add=True ) elif media_type == MediaType.PLAYLIST: result = await library_add_remove_playlist( - headers=self._headers, - prov_item_id=prov_item_id, - add=True, - username=self.config.get_value(CONF_USERNAME), + headers=self._headers, prov_item_id=prov_item_id, add=True ) elif media_type == MediaType.TRACK: raise NotImplementedError @@ -344,27 +368,19 @@ async def library_add(self, prov_item_id, media_type: MediaType) -> None: async def library_remove(self, prov_item_id, media_type: MediaType): """Remove an item from the library.""" + await self._check_oauth_token() result = False if media_type == MediaType.ARTIST: result = await library_add_remove_artist( - headers=self._headers, - prov_artist_id=prov_item_id, - add=False, - username=self.config.get_value(CONF_USERNAME), + headers=self._headers, prov_artist_id=prov_item_id, add=False ) elif media_type == MediaType.ALBUM: result = await library_add_remove_album( - headers=self._headers, - prov_item_id=prov_item_id, - add=False, - username=self.config.get_value(CONF_USERNAME), + headers=self._headers, prov_item_id=prov_item_id, add=False ) elif media_type == MediaType.PLAYLIST: result = await library_add_remove_playlist( - headers=self._headers, - prov_item_id=prov_item_id, - add=False, - username=self.config.get_value(CONF_USERNAME), + headers=self._headers, prov_item_id=prov_item_id, add=False ) elif media_type == MediaType.TRACK: raise NotImplementedError @@ -372,23 +388,20 @@ async def library_remove(self, prov_item_id, media_type: MediaType): async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]) -> None: """Add track(s) to playlist.""" + await self._check_oauth_token() return await add_remove_playlist_tracks( headers=self._headers, prov_playlist_id=prov_playlist_id, prov_track_ids=prov_track_ids, add=True, - username=self.config.get_value(CONF_USERNAME), ) async def remove_playlist_tracks( self, prov_playlist_id: str, positions_to_remove: tuple[int, ...] ) -> None: """Remove track(s) from playlist.""" - playlist_obj = await get_playlist( - prov_playlist_id=prov_playlist_id, - headers=self._headers, - username=self.config.get_value(CONF_USERNAME), - ) + await self._check_oauth_token() + playlist_obj = await get_playlist(prov_playlist_id=prov_playlist_id, headers=self._headers) if "tracks" not in playlist_obj: return None tracks_to_delete = [] @@ -406,15 +419,14 @@ async def remove_playlist_tracks( prov_playlist_id=prov_playlist_id, prov_track_ids=tracks_to_delete, add=False, - username=self.config.get_value(CONF_USERNAME), ) async def get_similar_tracks(self, prov_track_id, limit=25) -> list[Track]: """Retrieve a dynamic list of tracks based on the provided item.""" + await self._check_oauth_token() result = [] result = await get_song_radio_tracks( headers=self._headers, - username=self.config.get_value(CONF_USERNAME), prov_item_id=prov_track_id, limit=limit, ) @@ -465,6 +477,8 @@ async def get_stream_details(self, item_id: str) -> StreamDetails: return stream_details async def _post_data(self, endpoint: str, data: dict[str, str], **kwargs): # noqa: ARG002 + """Post data to the given endpoint.""" + await self._check_oauth_token() url = f"{YTM_BASE_URL}{endpoint}" data.update(self._context) async with self.mass.http_session.post( @@ -477,13 +491,27 @@ async def _post_data(self, endpoint: str, data: dict[str, str], **kwargs): # no return await response.json() async def _get_data(self, url: str, params: dict = None): + """Get data from the given URL.""" + await self._check_oauth_token() async with self.mass.http_session.get( url, headers=self._headers, params=params, cookies=self._cookies ) as response: return await response.text() - async def _initialize_headers(self, cookie: str) -> dict[str, str]: + async def _check_oauth_token(self) -> None: + """Verify the OAuth token is valid and refresh if needed.""" + if self.config.get_value(CONF_EXPIRY_TIME) < time(): + token = await refresh_oauth_token( + self.mass.http_session, self.config.get_value(CONF_REFRESH_TOKEN) + ) + self.config.update({CONF_AUTH_TOKEN: token["access_token"]}) + self.config.update({CONF_EXPIRY_TIME: time() + token["expires_in"]}) + self.config.update({CONF_TOKEN_TYPE: token["token_type"]}) + await self._initialize_headers() + + async def _initialize_headers(self) -> dict[str, str]: """Return headers to include in the requests.""" + auth = f"{self.config.get_value(CONF_TOKEN_TYPE)} {self.config.get_value(CONF_AUTH_TOKEN)}" headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:72.0) Gecko/20100101 Firefox/72.0", # noqa: E501 "Accept": "*/*", @@ -491,11 +519,9 @@ async def _initialize_headers(self, cookie: str) -> dict[str, str]: "Content-Type": "application/json", "X-Goog-AuthUser": "0", "x-origin": "https://music.youtube.com", - "Cookie": cookie, + "X-Goog-Request-Time": str(int(time())), + "Authorization": auth, } - sapisid = ytmusicapi.helpers.sapisid_from_cookie(cookie) - origin = headers.get("origin", headers.get("x-origin")) - headers["Authorization"] = ytmusicapi.helpers.get_authorization(sapisid + " " + origin) self._headers = headers async def _initialize_context(self) -> dict[str, str]: diff --git a/music_assistant/server/providers/ytmusic/helpers.py b/music_assistant/server/providers/ytmusic/helpers.py index 27491a51e..c18ece95a 100644 --- a/music_assistant/server/providers/ytmusic/helpers.py +++ b/music_assistant/server/providers/ytmusic/helpers.py @@ -11,6 +11,17 @@ from time import time import ytmusicapi +from aiohttp import ClientSession +from ytmusicapi.constants import ( + OAUTH_CLIENT_ID, + OAUTH_CLIENT_SECRET, + OAUTH_CODE_URL, + OAUTH_SCOPE, + OAUTH_TOKEN_URL, + OAUTH_USER_AGENT, +) + +from music_assistant.server.helpers.auth import AuthenticationHelper async def get_artist(prov_artist_id: str) -> dict[str, str]: @@ -40,14 +51,11 @@ def _get_album(): return await asyncio.to_thread(_get_album) -async def get_playlist( - prov_playlist_id: str, headers: dict[str, str], username: str -) -> dict[str, str]: +async def get_playlist(prov_playlist_id: str, headers: dict[str, str]) -> dict[str, str]: """Async wrapper around the ytmusicapi get_playlist function.""" def _get_playlist(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) playlist = ytm.get_playlist(playlistId=prov_playlist_id, limit=None) playlist["checksum"] = get_playlist_checksum(playlist) return playlist @@ -80,12 +88,11 @@ def _get_song(): return await asyncio.to_thread(_get_song) -async def get_library_artists(headers: dict[str, str], username: str) -> dict[str, str]: +async def get_library_artists(headers: dict[str, str]) -> dict[str, str]: """Async wrapper around the ytmusicapi get_library_artists function.""" def _get_library_artists(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) artists = ytm.get_library_subscriptions(limit=9999) # Sync properties with uniformal artist object for artist in artists: @@ -98,23 +105,21 @@ def _get_library_artists(): return await asyncio.to_thread(_get_library_artists) -async def get_library_albums(headers: dict[str, str], username: str) -> dict[str, str]: +async def get_library_albums(headers: dict[str, str]) -> dict[str, str]: """Async wrapper around the ytmusicapi get_library_albums function.""" def _get_library_albums(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) return ytm.get_library_albums(limit=9999) return await asyncio.to_thread(_get_library_albums) -async def get_library_playlists(headers: dict[str, str], username: str) -> dict[str, str]: +async def get_library_playlists(headers: dict[str, str]) -> dict[str, str]: """Async wrapper around the ytmusicapi get_library_playlists function.""" def _get_library_playlists(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) playlists = ytm.get_library_playlists(limit=9999) # Sync properties with uniformal playlist object for playlist in playlists: @@ -126,12 +131,11 @@ def _get_library_playlists(): return await asyncio.to_thread(_get_library_playlists) -async def get_library_tracks(headers: dict[str, str], username: str) -> dict[str, str]: +async def get_library_tracks(headers: dict[str, str]) -> dict[str, str]: """Async wrapper around the ytmusicapi get_library_tracks function.""" def _get_library_tracks(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) tracks = ytm.get_library_songs(limit=9999) return tracks @@ -139,13 +143,12 @@ def _get_library_tracks(): async def library_add_remove_artist( - headers: dict[str, str], prov_artist_id: str, add: bool = True, username: str = None + headers: dict[str, str], prov_artist_id: str, add: bool = True ) -> bool: """Add or remove an artist to the user's library.""" def _library_add_remove_artist(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) if add: return "actions" in ytm.subscribe_artists(channelIds=[prov_artist_id]) if not add: @@ -156,14 +159,13 @@ def _library_add_remove_artist(): async def library_add_remove_album( - headers: dict[str, str], prov_item_id: str, add: bool = True, username: str = None + headers: dict[str, str], prov_item_id: str, add: bool = True ) -> bool: """Add or remove an album or playlist to the user's library.""" album = await get_album(prov_album_id=prov_item_id) def _library_add_remove_album(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) playlist_id = album["audioPlaylistId"] if add: return ytm.rate_playlist(playlist_id, "LIKE") @@ -175,13 +177,12 @@ def _library_add_remove_album(): async def library_add_remove_playlist( - headers: dict[str, str], prov_item_id: str, add: bool = True, username: str = None + headers: dict[str, str], prov_item_id: str, add: bool = True ) -> bool: """Add or remove an album or playlist to the user's library.""" def _library_add_remove_playlist(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) if add: return "actions" in ytm.rate_playlist(prov_item_id, "LIKE") if not add: @@ -192,17 +193,12 @@ def _library_add_remove_playlist(): async def add_remove_playlist_tracks( - headers: dict[str, str], - prov_playlist_id: str, - prov_track_ids: list[str], - add: bool, - username: str = None, + headers: dict[str, str], prov_playlist_id: str, prov_track_ids: list[str], add: bool ) -> bool: """Async wrapper around adding/removing tracks to a playlist.""" def _add_playlist_tracks(): - user = username if is_brand_account(username) else None - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) if add: return ytm.add_playlist_items(playlistId=prov_playlist_id, videoIds=prov_track_ids) if not add: @@ -213,13 +209,12 @@ def _add_playlist_tracks(): async def get_song_radio_tracks( - headers: dict[str, str], username: str, prov_item_id: str, limit=25 + headers: dict[str, str], prov_item_id: str, limit=25 ) -> dict[str, str]: """Async wrapper around the ytmusicapi radio function.""" - user = username if is_brand_account(username) else None def _get_song_radio_tracks(): - ytm = ytmusicapi.YTMusic(auth=json.dumps(headers), user=user) + ytm = ytmusicapi.YTMusic(auth=json.dumps(headers)) playlist_id = f"RDAMVM{prov_item_id}" result = ytm.get_watch_playlist(videoId=prov_item_id, playlistId=playlist_id, limit=limit) # Replace inconsistensies for easier parsing @@ -285,3 +280,75 @@ def get_sec(time_str): if len(parts) == 2: return int(parts[0]) * 60 + int(parts[1]) return 0 + + +async def login_oauth(auth_helper: AuthenticationHelper): + """Use device login to get a token.""" + http_session = auth_helper.mass.http_session + code = await get_oauth_code(http_session) + token = await visit_oauth_auth_url(auth_helper, code) + return token + + +def _get_data_and_headers(data: dict): + """Prepare headers for OAuth requests.""" + data.update({"client_id": OAUTH_CLIENT_ID}) + headers = {"User-Agent": OAUTH_USER_AGENT} + return data, headers + + +async def get_oauth_code(session: ClientSession): + """Get the OAuth code from the server.""" + data, headers = _get_data_and_headers({"scope": OAUTH_SCOPE}) + async with session.post(OAUTH_CODE_URL, json=data, headers=headers) as code_response: + return await code_response.json() + + +async def visit_oauth_auth_url(auth_helper: AuthenticationHelper, code: dict[str, str]): + """Redirect the user to the OAuth login page and wait for the token.""" + auth_url = f"{code['verification_url']}?user_code={code['user_code']}" + auth_helper.send_url(auth_url=auth_url) + device_code = code["device_code"] + expiry = code["expires_in"] + interval = code["interval"] + while expiry > 0: + token = await get_oauth_token_from_code(auth_helper.mass.http_session, device_code) + if token.get("access_token"): + return token + await asyncio.sleep(interval) + expiry -= interval + raise TimeoutError("You took too long to log in") + + +async def get_oauth_token_from_code(session: ClientSession, device_code: str): + """Check if the OAuth token is ready yet.""" + data, headers = _get_data_and_headers( + data={ + "client_secret": OAUTH_CLIENT_SECRET, + "grant_type": "http://oauth.net/grant_type/device/1.0", + "code": device_code, + } + ) + async with session.post( + OAUTH_TOKEN_URL, + json=data, + headers=headers, + ) as token_response: + return await token_response.json() + + +async def refresh_oauth_token(session: ClientSession, refresh_token: str): + """Refresh an expired OAuth token.""" + data, headers = _get_data_and_headers( + { + "client_secret": OAUTH_CLIENT_SECRET, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + ) + async with session.post( + OAUTH_TOKEN_URL, + json=data, + headers=headers, + ) as response: + return await response.json() diff --git a/music_assistant/server/providers/ytmusic/manifest.json b/music_assistant/server/providers/ytmusic/manifest.json index 82f16c6d1..0bbe9174d 100644 --- a/music_assistant/server/providers/ytmusic/manifest.json +++ b/music_assistant/server/providers/ytmusic/manifest.json @@ -4,7 +4,7 @@ "name": "YouTube Music", "description": "Support for the YouTube Music streaming provider in Music Assistant.", "codeowners": ["@MarvinSchenkel"], - "requirements": ["ytmusicapi==0.25.1", "git+https://github.com/pytube/pytube.git@refs/pull/1501/head"], + "requirements": ["ytmusicapi==1.0.0", "git+https://github.com/pytube/pytube.git@refs/pull/1501/head"], "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/606", "multi_instance": true } diff --git a/pyproject.toml b/pyproject.toml index ec1c7b96a..b754f5730 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,26 +32,25 @@ server = [ "aiofiles==23.1.0", "aiorun==2022.11.1", "coloredlogs==15.0.1", - "databases==0.7.0", "aiosqlite==0.19.0", "python-slugify==8.0.1", "mashumaro==3.7", "memory-tempfile==2.2.3", - "music-assistant-frontend==20230420.0", + "music-assistant-frontend==20230510.0", "pillow==9.5.0", "unidecode==1.3.6", "xmltodict==0.13.0", - "orjson==3.8.9", + "orjson==3.8.12", "shortuuid==1.0.11", - "zeroconf==0.56.0", + "zeroconf==0.62.0", "cryptography==40.0.2" ] test = [ - "black==23.1.0", + "black==23.3.0", "codespell==2.2.4", "mypy==1.2.0", - "ruff==0.0.261", - "pytest==7.2.2", + "ruff==0.0.265", + "pytest==7.3.1", "pytest-asyncio==0.21.0", "pytest-aiohttp==1.0.4", "pytest-cov==4.0.0", diff --git a/requirements_all.txt b/requirements_all.txt index fa53c4016..164edb89d 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -11,22 +11,23 @@ async-upnp-client==0.33.1 asyncio-throttle==1.0.2 coloredlogs==15.0.1 cryptography==40.0.2 -databases==0.7.0 +deezer-python==5.10.0 faust-cchardet>=2.1.18 git+https://github.com/gieljnssns/python-radios.git@main git+https://github.com/jozefKruszynski/python-tidal.git@v0.7.1 git+https://github.com/pytube/pytube.git@refs/pull/1501/head mashumaro==3.7 memory-tempfile==2.2.3 -music-assistant-frontend==20230420.0 -orjson==3.8.9 +music-assistant-frontend==20230510.0 +orjson==3.8.12 pillow==9.5.0 plexapi==4.13.4 PyChromecast==13.0.7 +pycryptodome==3.17 python-slugify==8.0.1 shortuuid==1.0.11 soco==0.29.1 unidecode==1.3.6 xmltodict==0.13.0 -ytmusicapi==0.25.1 -zeroconf==0.56.0 +ytmusicapi==1.0.0 +zeroconf==0.62.0