Add dtype flexibility #49
(cc @raphaelsty)
Awesome findings @bclavie! Feel free to create an MR, or share some code here and I'll co-author the commit with you :)
Thanks! My personal branch is a complete mess since it's mostly RAGatouille-related, so if you don't mind I'll share a few snippets here instead.

In the model loading code:

```python
# Cast the model to fp16 when the user requested it via model_kwargs.
if self.model_kwargs and self.model_kwargs.get("torch_dtype") in (torch.float16, "float16"):
    self.half()
```

In `rerank`:

```python
def rerank(
    documents_ids: list[list[int | str]],
    queries_embeddings: list[list[float | int] | np.ndarray | torch.Tensor],
    documents_embeddings: list[list[float | int] | np.ndarray | torch.Tensor],
    device: str = None,
    fp16: bool = False,
) -> list[list[dict[str, float]]]:
    ...
    # Optionally downcast the embeddings before scoring.
    if fp16:
        query_embeddings = query_embeddings.astype(np.float16)
        query_documents_embeddings = [
            doc_embedding.astype(np.float16)
            for doc_embedding in query_documents_embeddings
        ]
    ...
```

In the `ColBERT` retriever:

```python
class ColBERT:
    ...
    def __init__(self, index: Voyager, fp16: bool = False) -> None:
        self.index = index
        self.fp16 = fp16

    def retrieve(
        self,
        ...
    ):
        ...
        # Downcast the documents fetched from the Voyager index before reranking.
        if self.fp16:
            documents_embeddings = [
                [doc_embedding.astype(np.float16) for doc_embedding in query_docs]
                for query_docs in documents_embeddings
            ]
        ...
        reranking_results.extend(
            rerank(
                documents_ids=documents_ids,
                queries_embeddings=queries_embeddings_batch,
                documents_embeddings=documents_embeddings,
                device=device,
                fp16=self.fp16,
            )
        )
```

Sorry this is messy and a bit hardcoded, as you'd probably want the option to also use bfloat16 for the model loading.
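For context, a sketch of how the flag would surface to users under this change (index folder/name are placeholders, and the fp16 argument is the proposed addition, not an existing PyLate parameter):

```python
from pylate import indexes, retrieve

# Hypothetical user-facing flow once the fp16 flag exists on the retriever.
index = indexes.Voyager(index_folder="pylate-index", index_name="index")
retriever = retrieve.ColBERT(index=index, fp16=True)  # fp16 is the flag proposed above
```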
Out of scope for the main issue, but if you're looking for better dtype support in numpy for future improvements, ml-dtypes adds full support for bfloat16 and various fp8 formats (including the Voyager-friendly E4M3) to numpy arrays.
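For illustration, a minimal sketch of how that looks on plain numpy arrays (using dtype names from the ml_dtypes package; nothing PyLate-specific):

```python
import numpy as np
import ml_dtypes  # pip install ml-dtypes

embeddings = np.random.rand(4, 128).astype(np.float32)

# bfloat16 and fp8 E4M3 behave like regular numpy dtypes once ml_dtypes is installed.
embeddings_bf16 = embeddings.astype(ml_dtypes.bfloat16)
embeddings_fp8 = embeddings.astype(ml_dtypes.float8_e4m3fn)
```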
Passing on this for now; I can explore more once I am back from vacation.
May I once again suggest that you are actually off on your time off? 😂
Oh I see. IMO figuring out a "perfect" solution is pretty important, especially as I hear some people maintain a late-interaction wrapper library and are really looking forward to making it completely seamless/invisible to the user to switch backends back and forth between pylate and stanfordnlp. It's mostly going smoothly so far, save for some small issues, the dtype problem, and having full interoperability between models. I'll open a separate issue in the next few days to request utils to convert models on the fly 😄
I cleaned up the loading logic in #52 to load the weights directly from the stanford-nlp repository. This means that we do not need a subfolder for the transformer module: it's either a stanford-nlp repo or a PyLate one, and so it lives at the root. You'll have to roll back the colbert-small repo to the previous commit so it only includes the stanford-nlp weights, but you can now load the model in fp16 by passing the torch_dtype through model_kwargs. Tell me if I am wrong, but for the rest, adding a cast to rank.py seems enough, with an attribute that can be set in rerank and is forwarded by the retrieve function.

Edit: this naive casting seems to hurt the performance compared to fp32 on my high-end setup.
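For reference, a guess at what that naive cast might look like, based on the rerank snippet shared earlier in the thread (a hypothetical helper, not the actual code from #52):

```python
import numpy as np

def cast_embeddings_to_fp16(queries_embeddings, documents_embeddings):
    """Hypothetical helper: downcast query and document embeddings to float16
    before ColBERT MaxSim scoring."""
    queries_embeddings = [np.asarray(q, dtype=np.float16) for q in queries_embeddings]
    documents_embeddings = [
        [np.asarray(d, dtype=np.float16) for d in docs]
        for docs in documents_embeddings
    ]
    return queries_embeddings, documents_embeddings
```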
Hey! Congrats on the release 😄
My first issue, as promised to @NohTow: scoring is pretty slow, and I think it could be greatly improved by adding some extra flexibility in terms of dtype. Notably:

- Passing `model_kwargs={"torch_dtype": torch.float16}` results in the model being fp32, which slows things down a lot for almost no performance improvement, especially on weaker hardware. A separate `.half()` call afterwards is needed.
- `retrieve` doesn't have an option to convert the documents fetched from Voyager from float32 to float16, and neither does `colbert_scores` or `rerank`. This means we end up needing to do expensive float32 scoring without being able to opt out.

On my machine, with very dirty changes, going from the hardcoded float32 to this version brought the eval time on SciFact from ~1.35 s/query down to ~0.85 s/query. I think this is well worth implementing since the complexity isn't gigantic!
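For reference, a minimal sketch of the workaround this describes (the model name is just an illustrative checkpoint):

```python
import torch
from pylate import models

# Current workaround: model_kwargs alone leaves the weights in fp32,
# so an explicit .half() call is needed after loading.
model = models.ColBERT(
    model_name_or_path="lightonai/colbertv2.0",  # example checkpoint
    model_kwargs={"torch_dtype": torch.float16},
)
model = model.half()  # actually casts the weights to fp16
```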
More minor typing flexibility change