Commit

Tree: Swap from map to list comprehensions
List comprehensions are the more "pythonic" way to map values into a list.
They are also more flexible across different collection types than the
built-in map function. It's best to keep to one convention rather than
splitting across two.
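
As a quick illustration of the convention change (this snippet is not from the
commit itself; the values and variable names are made up):

    values_mb = [96, 128]

    # map() returns a lazy iterator in Python 3, so it has to be wrapped in list()
    reserve_map = list(map(lambda value: value * 1024**2, values_mb))

    # The equivalent comprehension builds the list directly and reads left to right
    reserve_comp = [value * 1024**2 for value in values_mb]

    assert reserve_map == reserve_comp  # both are [100663296, 134217728]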

Signed-off-by: kingbri <[email protected]>
bdashore3 committed May 26, 2024
1 parent 6bf75d9 commit ae6a4af
Showing 4 changed files with 20 additions and 27 deletions.
27 changes: 11 additions & 16 deletions backends/exllamav2/model.py
@@ -134,12 +134,9 @@ def progress(loaded_modules: int, total_modules: int,
             self.gpu_split_auto = gpu_split_auto
 
             autosplit_reserve_megabytes = unwrap(kwargs.get("autosplit_reserve"), [96])
-            self.autosplit_reserve = list(
-                map(
-                    lambda value: int(math.ceil(value * 1024**2)),
-                    autosplit_reserve_megabytes,
-                )
-            )
+            self.autosplit_reserve = [
+                int(math.ceil(value * 1024**2)) for value in autosplit_reserve_megabytes
+            ]
         elif gpu_count > 1:
             # Manual GPU split
             self.gpu_split = kwargs.get("gpu_split")
@@ -681,21 +678,19 @@ def get_special_tokens(
         }
 
     def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
-        top_tokens = list(
-            map(
-                lambda index: self.tokenizer.extended_id_to_piece.get(
-                    index, self.tokenizer.id_to_piece[index]
-                ),
-                token_ids.flatten().tolist(),
-            )
-        )
+        top_tokens = [
+            self.tokenizer.extended_id_to_piece.get(
+                index, self.tokenizer.id_to_piece[index]
+            )
+            for index in token_ids.flatten().tolist()
+        ]
 
         top_values = torch.log(token_probs).flatten().tolist()
 
         # Cannot return -inf in JSON
-        cleaned_values = list(
-            map(lambda value: -1000 if value == float("-inf") else value, top_values)
-        )
+        cleaned_values = [
+            -1000 if value == float("-inf") else value for value in top_values
+        ]
 
         return dict(zip_longest(top_tokens, cleaned_values))

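As an aside on the "Cannot return -inf in JSON" comment above: torch.log of a
zero probability yields -inf, and strict JSON has no representation for
infinities, which is presumably why those values are clamped to -1000. A
minimal sketch of that cleanup with plain floats and made-up values (not code
from this repository):

    import json

    # A zero probability logs to -inf; simulate that with plain floats here
    top_values = [-0.35, -1.6, float("-inf")]

    # Strict JSON cannot encode -Infinity, so clamp it to a large negative number
    cleaned_values = [-1000 if value == float("-inf") else value for value in top_values]

    # allow_nan=False enforces strict JSON; the raw -inf would raise ValueError here
    print(json.dumps(cleaned_values, allow_nan=False))  # [-0.35, -1.6, -1000]
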
16 changes: 7 additions & 9 deletions common/downloader.py
@@ -63,15 +63,13 @@ def _get_repo_info(repo_id, revision, token):
 
     api_client = HfApi()
     repo_tree = api_client.list_repo_files(repo_id, revision=revision, token=token)
-    return list(
-        map(
-            lambda filename: {
-                "filename": filename,
-                "url": hf_hub_url(repo_id, filename, revision=revision),
-            },
-            repo_tree,
-        )
-    )
+    return [
+        {
+            "filename": filename,
+            "url": hf_hub_url(repo_id, filename, revision=revision),
+        }
+        for filename in repo_tree
+    ]
 
 
 def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str]):
2 changes: 1 addition & 1 deletion common/sampling.py
@@ -367,7 +367,7 @@ def get_all_presets():
     """Fetches all sampler override presets from the overrides directory"""
 
     override_directory = pathlib.Path("sampler_overrides")
-    preset_files = map(lambda file: file.stem, override_directory.glob("*.yml"))
+    preset_files = [file.stem for file in override_directory.glob("*.yml")]
 
     return preset_files

2 changes: 1 addition & 1 deletion endpoints/OAI/router.py
@@ -191,7 +191,7 @@ async def unload_model():
 @router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
 async def get_templates():
     templates = get_all_templates()
-    template_strings = list(map(lambda template: template.stem, templates))
+    template_strings = [template.stem for template in templates]
     return TemplateList(data=template_strings)


