Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

add bert model #398

Merged
merged 1 commit into from
Nov 12, 2023
Merged

add bert model #398

merged 1 commit into from
Nov 12, 2023

Conversation

oppiliappan
Copy link
Contributor

No description provided.

Co-authored-by: Lukas Kreussel <[email protected]>
Co-authored-by: Philpax <[email protected]>
@philpax
Copy link
Collaborator

philpax commented Aug 13, 2023

Sorry for not engaging more with this earlier! Very cool - what's the current status of it, and how do you use it? My understanding is that the existing InferenceSession abstractions don't necessarily make sense - I'm wondering if we should change that to expose different methods based on whether the model is autoregressive or seq2seq.

(cc @LLukas22)

@LLukas22
Copy link
Contributor

As far as i know the model loads and infers fine, we only need a more generic InferenceSession, which can handle the logits/embeddings generated by the BERT model. Maybe we could introduce some sort of autoregressive and seq2seq traits but im not to sure about that yet. Also the pooling and norm operations are currently part of the inference call and should probably be moved out into an embedding function. Maybe we could also handle #295 while integrating this. 🤔

@philpax
Copy link
Collaborator

philpax commented Aug 13, 2023

Hmm yeah, the InferenceSession is not parameterised over the model (which is a good thing!) which makes it difficult to switch parts of it on and off. Maybe we should refactor it into a InferenceSessionBase and then reimplement InferenceSessionAutoregressive and InferenceSessionSeq2Seq, then add ModelAutoregressive and ModelSeq2Seq traits that supply the actual start_session method?

Do you have a test case for BERT already that we can use to test/do API design with?

@LLukas22
Copy link
Contributor

Use cases for BERT models can be a bit iffy as a lot of things are possible and need to be implemented accordingly. I would suggest to simply expose the logits somehow and implement a embed function which takes some sort of pooling function as an argument which performs the mean and pooling operations on the ggml side. Alternatively we could move the pooling to the rust side as we mainly deal with small outputs.

99% of user will use these models only to generate embeddings, we should probably focus on that.

@oppiliappan
Copy link
Contributor Author

Sorry for not engaging more with this earlier! Very cool - what's the current status of it, and how do you use it? My understanding is that the existing InferenceSession abstractions don't necessarily make sense - I'm wondering if we should change that to expose different methods based on whether the model is autoregressive or seq2seq

its currently in a state of flux for my usecase: i am looking to add batching support and have this run properly on metal. i agree that the current outputs of the inference model may not be perfect, but for now, it does suffice for just embeddings (with the hacks done in this PR: such as disabling offloading for pooling). i do believe this is usable for embeddings in current state - the embeddings example is on par with the embeddings produced by bert.cpp for the same model files.

@LLukas22
Copy link
Contributor

@nerdypepper Looks like they are working on proper matrix x matrix support for metal: ggerganov/llama.cpp#2615

@carlgronvald
Copy link

carlgronvald commented Oct 31, 2023

I am interested in using BERT for encodings, is there any work I could do on this pull request to get it in a state where it can be merged? Thank you

@philpax
Copy link
Collaborator

philpax commented Oct 31, 2023

Hm, that's a good question. @nerdypepper I'll try to get #428 across the line, but is that a blocker for merging this?

@oppiliappan
Copy link
Contributor Author

@philpax i believe it is, inference on metal does not work without latest kernels from llama-cpp. that being said, i have no idea how BERT performs against the branch on #428, this will require some testing. would be happy to pick this up over the weekend. this pr needs some patches from this branch on my fork.

side note, gpu inference of any form mostly only benefits from batching, for which, curently, there is no interface in KnownModel, (perhaps something like KnownModel::batch_evaluate).

@philpax philpax changed the base branch from main to develop November 12, 2023 20:16
@philpax philpax merged commit 52c2bb6 into rustformers:develop Nov 12, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants