Fix type hint problems for resource clients
Closes #152
vdusek committed Dec 7, 2023
1 parent cbed2b7 commit 1a4e5f0
Showing 4 changed files with 79 additions and 77 deletions.
@@ -8,6 +8,8 @@
 from apify_shared.utils import ignore_docs
 
 if TYPE_CHECKING:
+    from typing_extensions import Self
+
     from ..memory_storage_client import MemoryStorageClient

@@ -48,9 +50,9 @@ def _get_storages_dir(cls: type[BaseResourceClient], memory_storage_client: Memo
     @classmethod
     @abstractmethod
     def _get_storage_client_cache(
-        cls: type[BaseResourceClient],
+        cls,  # noqa: ANN102 # type annotated cls seems not to be working with Self as a return type
         memory_storage_client: MemoryStorageClient,
-    ) -> list[BaseResourceClient]:
+    ) -> list[Self]:
         raise NotImplementedError('You must override this method in the subclass!')

     @abstractmethod
@@ -60,21 +62,21 @@ def _to_resource_info(self: BaseResourceClient) -> dict:
     @classmethod
     @abstractmethod
     def _create_from_directory(
-        cls: type[BaseResourceClient],
+        cls,  # noqa: ANN102 # type annotated cls seems not to be working with Self as a return type
         storage_directory: str,
         memory_storage_client: MemoryStorageClient,
         id: str | None = None,  # noqa: A002
         name: str | None = None,
-    ) -> BaseResourceClient:
+    ) -> Self:
         raise NotImplementedError('You must override this method in the subclass!')

     @classmethod
     def _find_or_create_client_by_id_or_name(
-        cls: type[BaseResourceClient],
+        cls,  # noqa: ANN102 # type annotated cls seems not to be working with Self as a return type
         memory_storage_client: MemoryStorageClient,
         id: str | None = None,  # noqa: A002
         name: str | None = None,
-    ) -> BaseResourceClient | None:
+    ) -> Self | None:
         assert id is not None or name is not None  # noqa: S101

         storage_client_cache = cls._get_storage_client_cache(memory_storage_client)
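The core of the fix is visible in the base class above: instead of annotating `cls` (which defeats `Self` inference, hence the `ANN102` suppressions), the abstract methods now return `Self`, imported from `typing_extensions` under `TYPE_CHECKING`. A minimal sketch of the pattern, assuming `typing_extensions` is installed; `BaseResource`, `Dataset`, and `create` are hypothetical stand-ins, not names from this commit:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # typing_extensions backports Self (PEP 673) for Python < 3.11;
    # importing it only for type checking avoids a runtime dependency.
    from typing_extensions import Self


class BaseResource:
    @classmethod
    def create(cls) -> Self:  # cls left unannotated, as in the commit
        # Annotating cls as type[BaseResource] would pin the return type
        # to the base class; with Self, each subclass's create() is typed
        # as returning that subclass.
        return cls()


class Dataset(BaseResource):
    pass


dataset = Dataset.create()  # a checker infers Dataset, not BaseResource
```

This is what allows the `-> BaseResourceClient` return types to become `-> Self` without losing precision in the subclasses.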
34 changes: 17 additions & 17 deletions src/apify/_memory_storage/resource_clients/dataset.py
@@ -74,8 +74,8 @@ async def get(self: DatasetClient) -> dict | None:
         found = self._find_or_create_client_by_id_or_name(memory_storage_client=self._memory_storage_client, id=self._id, name=self._name)

         if found:
-            async with found._file_operation_lock:  # type: ignore
-                await found._update_timestamps(has_been_modified=False)  # type: ignore
+            async with found._file_operation_lock:
+                await found._update_timestamps(has_been_modified=False)
             return found._to_resource_info()

         return None
@@ -103,7 +103,7 @@ async def update(self: DatasetClient, *, name: str | None = None) -> dict:
         if name is None:
             return existing_dataset_by_id._to_resource_info()

-        async with existing_dataset_by_id._file_operation_lock:  # type: ignore
+        async with existing_dataset_by_id._file_operation_lock:
             # Check that name is not in use already
             existing_dataset_by_name = next(
                 (dataset for dataset in self._memory_storage_client._datasets_handled if dataset._name and dataset._name.lower() == name.lower()),
@@ -122,7 +122,7 @@ async def update(self: DatasetClient, *, name: str | None = None) -> dict:
             await force_rename(previous_dir, existing_dataset_by_id._resource_directory)

             # Update timestamps
-            await existing_dataset_by_id._update_timestamps(has_been_modified=True)  # type: ignore
+            await existing_dataset_by_id._update_timestamps(has_been_modified=True)

         return existing_dataset_by_id._to_resource_info()

@@ -193,19 +193,19 @@ async def list_items(
         if existing_dataset_by_id is None:
             raise_on_non_existing_storage(StorageTypes.DATASET, self._id)

-        async with existing_dataset_by_id._file_operation_lock:  # type: ignore
-            start, end = existing_dataset_by_id._get_start_and_end_indexes(  # type: ignore
-                max(existing_dataset_by_id._item_count - (offset or 0) - (limit or LIST_ITEMS_LIMIT), 0) if desc else offset or 0,  # type: ignore
+        async with existing_dataset_by_id._file_operation_lock:
+            start, end = existing_dataset_by_id._get_start_and_end_indexes(
+                max(existing_dataset_by_id._item_count - (offset or 0) - (limit or LIST_ITEMS_LIMIT), 0) if desc else offset or 0,
                 limit,
             )

             items = []

             for idx in range(start, end):
                 entry_number = self._generate_local_entry_name(idx)
-                items.append(existing_dataset_by_id._dataset_entries[entry_number])  # type: ignore
+                items.append(existing_dataset_by_id._dataset_entries[entry_number])

-            await existing_dataset_by_id._update_timestamps(has_been_modified=False)  # type: ignore
+            await existing_dataset_by_id._update_timestamps(has_been_modified=False)

             if desc:
                 items.reverse()
@@ -217,7 +217,7 @@ async def list_items(
                     'items': items,
                     'limit': limit or LIST_ITEMS_LIMIT,
                     'offset': offset or 0,
-                    'total': existing_dataset_by_id._item_count,  # type: ignore
+                    'total': existing_dataset_by_id._item_count,
                 }
             )

@@ -308,16 +308,16 @@ async def push_items(self: DatasetClient, items: JSONSerializable) -> None:

         added_ids: list[str] = []
         for entry in normalized:
-            existing_dataset_by_id._item_count += 1  # type: ignore
-            idx = self._generate_local_entry_name(existing_dataset_by_id._item_count)  # type: ignore
+            existing_dataset_by_id._item_count += 1
+            idx = self._generate_local_entry_name(existing_dataset_by_id._item_count)

-            existing_dataset_by_id._dataset_entries[idx] = entry  # type: ignore
+            existing_dataset_by_id._dataset_entries[idx] = entry
             added_ids.append(idx)

-        data_entries = [(id, existing_dataset_by_id._dataset_entries[id]) for id in added_ids]  # type: ignore # noqa: A001
+        data_entries = [(id, existing_dataset_by_id._dataset_entries[id]) for id in added_ids]  # noqa: A001

-        async with existing_dataset_by_id._file_operation_lock:  # type: ignore
-            await existing_dataset_by_id._update_timestamps(has_been_modified=True)  # type: ignore
+        async with existing_dataset_by_id._file_operation_lock:
+            await existing_dataset_by_id._update_timestamps(has_been_modified=True)

             await _update_dataset_items(
                 data=data_entries,
@@ -385,7 +385,7 @@ def _get_storages_dir(cls: type[DatasetClient], memory_storage_client: MemorySto
         return memory_storage_client._datasets_directory

     @classmethod
-    def _get_storage_client_cache(  # type: ignore
+    def _get_storage_client_cache(
         cls: type[DatasetClient],
         memory_storage_client: MemoryStorageClient,
     ) -> list[DatasetClient]:
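The dropped `# type: ignore` comments in this file follow directly from the base-class change: `_find_or_create_client_by_id_or_name` now returns `Self | None`, so within `DatasetClient` the result is typed as `DatasetClient | None` and private members such as `_file_operation_lock` resolve for the checker. A rough sketch of the effect; `BaseClient` and `_find_by_name` are hypothetical stand-ins for the real API:

```python
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing_extensions import Self


class BaseClient:
    @classmethod
    def _find_by_name(cls, name: str) -> Self | None:
        # Hypothetical lookup: returns an instance of the calling class.
        return cls() if name else None


class DatasetClient(BaseClient):
    def __init__(self) -> None:
        self._file_operation_lock = asyncio.Lock()

    async def get(self, name: str) -> None:
        # found is inferred as DatasetClient | None rather than
        # BaseClient | None, so the attribute access below type-checks
        # without an ignore comment.
        found = self._find_by_name(name)
        if found:
            async with found._file_operation_lock:
                ...


asyncio.run(DatasetClient().get('my-dataset'))
```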
44 changes: 22 additions & 22 deletions src/apify/_memory_storage/resource_clients/key_value_store.py
@@ -100,8 +100,8 @@ async def get(self: KeyValueStoreClient) -> dict | None:
         found = self._find_or_create_client_by_id_or_name(memory_storage_client=self._memory_storage_client, id=self._id, name=self._name)

         if found:
-            async with found._file_operation_lock:  # type: ignore
-                await found._update_timestamps(has_been_modified=False)  # type: ignore
+            async with found._file_operation_lock:
+                await found._update_timestamps(has_been_modified=False)
             return found._to_resource_info()

         return None
@@ -127,7 +127,7 @@ async def update(self: KeyValueStoreClient, *, name: str | None = None) -> dict:
         if name is None:
             return existing_store_by_id._to_resource_info()

-        async with existing_store_by_id._file_operation_lock:  # type: ignore
+        async with existing_store_by_id._file_operation_lock:
             # Check that name is not in use already
             existing_store_by_name = next(
                 (store for store in self._memory_storage_client._key_value_stores_handled if store._name and store._name.lower() == name.lower()),
@@ -146,7 +146,7 @@ async def update(self: KeyValueStoreClient, *, name: str | None = None) -> dict:
             await force_rename(previous_dir, existing_store_by_id._resource_directory)

             # Update timestamps
-            await existing_store_by_id._update_timestamps(has_been_modified=True)  # type: ignore
+            await existing_store_by_id._update_timestamps(has_been_modified=True)

         return existing_store_by_id._to_resource_info()

@@ -187,7 +187,7 @@ async def list_keys(

         items = []

-        for record in existing_store_by_id._records.values():  # type: ignore
+        for record in existing_store_by_id._records.values():
             size = len(record['value'])
             items.append(
                 {
@@ -222,8 +222,8 @@ async def list_keys(
         is_last_selected_item_absolutely_last = last_item_in_store == last_selected_item
         next_exclusive_start_key = None if is_last_selected_item_absolutely_last else last_selected_item['key']

-        async with existing_store_by_id._file_operation_lock:  # type: ignore
-            await existing_store_by_id._update_timestamps(has_been_modified=False)  # type: ignore
+        async with existing_store_by_id._file_operation_lock:
+            await existing_store_by_id._update_timestamps(has_been_modified=False)

         return {
             'count': len(items),
@@ -247,7 +247,7 @@ async def _get_record_internal(
         if existing_store_by_id is None:
             raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self._id)

-        stored_record = existing_store_by_id._records.get(key)  # type: ignore
+        stored_record = existing_store_by_id._records.get(key)

         if stored_record is None:
             return None
@@ -264,8 +264,8 @@ async def _get_record_internal(
         except ValueError:
             logger.exception('Error parsing key-value store record')

-        async with existing_store_by_id._file_operation_lock:  # type: ignore
-            await existing_store_by_id._update_timestamps(has_been_modified=False)  # type: ignore
+        async with existing_store_by_id._file_operation_lock:
+            await existing_store_by_id._update_timestamps(has_been_modified=False)

         return record

@@ -324,22 +324,22 @@ async def set_record(self: KeyValueStoreClient, key: str, value: Any, content_ty
         if 'application/json' in content_type and not is_file_or_bytes(value) and not isinstance(value, str):
             value = json_dumps(value).encode('utf-8')

-        async with existing_store_by_id._file_operation_lock:  # type: ignore
-            await existing_store_by_id._update_timestamps(has_been_modified=True)  # type: ignore
+        async with existing_store_by_id._file_operation_lock:
+            await existing_store_by_id._update_timestamps(has_been_modified=True)
             record: KeyValueStoreRecord = {
                 'key': key,
                 'value': value,
                 'contentType': content_type,
             }

-            old_record = existing_store_by_id._records.get(key)  # type: ignore
-            existing_store_by_id._records[key] = record  # type: ignore
+            old_record = existing_store_by_id._records.get(key)
+            existing_store_by_id._records[key] = record

             if self._memory_storage_client._persist_storage:
                 if old_record is not None and _filename_from_record(old_record) != _filename_from_record(record):
-                    await existing_store_by_id._delete_persisted_record(old_record)  # type: ignore
+                    await existing_store_by_id._delete_persisted_record(old_record)

-                await existing_store_by_id._persist_record(record)  # type: ignore
+                await existing_store_by_id._persist_record(record)

     async def _persist_record(self: KeyValueStoreClient, record: KeyValueStoreRecord) -> None:
         store_directory = self._resource_directory
@@ -385,14 +385,14 @@ async def delete_record(self: KeyValueStoreClient, key: str) -> None:
         if existing_store_by_id is None:
             raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self._id)

-        record = existing_store_by_id._records.get(key)  # type: ignore
+        record = existing_store_by_id._records.get(key)

         if record is not None:
-            async with existing_store_by_id._file_operation_lock:  # type: ignore
-                del existing_store_by_id._records[key]  # type: ignore
-                await existing_store_by_id._update_timestamps(has_been_modified=True)  # type: ignore
+            async with existing_store_by_id._file_operation_lock:
+                del existing_store_by_id._records[key]
+                await existing_store_by_id._update_timestamps(has_been_modified=True)
                 if self._memory_storage_client._persist_storage:
-                    await existing_store_by_id._delete_persisted_record(record)  # type: ignore
+                    await existing_store_by_id._delete_persisted_record(record)

     async def _delete_persisted_record(self: KeyValueStoreClient, record: KeyValueStoreRecord) -> None:
         store_directory = self._resource_directory
@@ -437,7 +437,7 @@ def _get_storages_dir(cls: type[KeyValueStoreClient], memory_storage_client: Mem
         return memory_storage_client._key_value_stores_directory

     @classmethod
-    def _get_storage_client_cache(  # type: ignore
+    def _get_storage_client_cache(
         cls: type[KeyValueStoreClient],
         memory_storage_client: MemoryStorageClient,
     ) -> list[KeyValueStoreClient]:
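Note that the concrete subclasses keep fully annotated overrides (`cls: type[KeyValueStoreClient]`, `-> list[KeyValueStoreClient]`). When a checker validates the override, it binds `Self` in the base signature to the subclass, so the concrete annotations match and the former `# type: ignore` on the `def` line can go. A simplified sketch, with a hypothetical `_get_cache` standing in for `_get_storage_client_cache` and the `memory_storage_client` parameter omitted:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing_extensions import Self


class BaseResourceClient:
    @classmethod
    def _get_cache(cls) -> list[Self]:
        raise NotImplementedError('You must override this method in the subclass!')


class KeyValueStoreClient(BaseResourceClient):
    _cache: list[KeyValueStoreClient] = []

    @classmethod
    def _get_cache(cls: type[KeyValueStoreClient]) -> list[KeyValueStoreClient]:
        # Self is bound to KeyValueStoreClient when this override is
        # checked, so the concrete signature is compatible with the base.
        return cls._cache


print(KeyValueStoreClient._get_cache())  # -> []
```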