Export Gemma 2 from Keras to Saved Model? #595
After downloading the tar file, this was sufficient for me to export it:

```python
import keras_hub

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma2-keras-gemma2_instruct_2b_en-v1")
gemma_lm.export("gemma2-2b-it-saved-model")
```

The saved model is massive (~19GB) for some reason I don't understand, but it looks like it should be callable in Java.
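To see where those ~19GB actually go, one option is to sum the exported files by extension. A minimal standard-library sketch (the directory name matches the export call above; any SavedModel directory works):

```python
import os
from collections import defaultdict

def sizes_by_extension(root: str) -> dict[str, int]:
    """Total file size under `root`, grouped by file extension, in bytes."""
    totals: dict[str, int] = defaultdict(int)
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            ext = os.path.splitext(name)[1] or "(none)"
            totals[ext] += os.path.getsize(os.path.join(dirpath, name))
    return dict(totals)

# Largest categories first; for a SavedModel, the variables data files
# usually dominate.
for ext, size in sorted(sizes_by_extension("gemma2-2b-it-saved-model").items(),
                        key=lambda kv: -kv[1]):
    print(f"{ext:12s} {size / 1024**3:6.2f} GiB")
```

If one extension accounts for nearly all of the space, that narrows down whether the bloat is in the variables or in the graph definition.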
@Craigacp, I got a similar file size using a slightly different Python script.
I thought I had made a mistake, but now I see we got similar results. What exactly is the size of your files? Besides, I'm primarily a Java developer, and all this conversion to the SavedModel format is a hassle for me. It would be really nice if there were a repository of popular models in SavedModel format (or at least a collection of Python conversion scripts), so they would be ready to use with the TensorFlow Java API. I would build something like that myself and publish it somewhere, but I'm not sure I can; I don't have enough knowledge for that :-(
Now, let's try to load this monster using Java code. We'll see what happens.
It ate up all 16GB of RAM on my computer and failed to load :-(
It seems too bulky to use.
I loaded it in jshell, but my laptop has 48GB of RAM. The model should take about 8GB for the parameters in fp32; I'm not sure why it's 19GB. I don't have time to get the tokenizer up and running on TF-Java, but it should be possible. The problem is that Google doesn't release things in TensorFlow SavedModel format anymore, as they are focusing on Keras and JAX, so it's a second-class citizen in the ecosystem. They deprecated TensorFlow Hub, which used to host SavedModels, and pushed everything over to Kaggle. Older models there still have TF SavedModel checkpoints, but none of the LLMs do.

If you've got a Hugging Face account you can request access to Llama 3.2 (which should be straightforward unless you work at a massive multinational social media company) and then export the 1B version of that. When I did it I got a 4.6GB SavedModel, which I also managed to load and execute to generate the probability distribution over tokens. I didn't check whether it was producing sensible output; I don't have a harness with a tokenizer where I can easily generate from a TF-Java based LLM.

```python
In [1]: import huggingface_hub

In [2]: huggingface_hub.login(token="<huggingface-api-token-for-your-account>")

In [3]: import keras_hub

In [4]: model = keras_hub.models.Llama3CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16")
model.safetensors: 100%|██████████| 2.47G/2.47G [00:58<00:00, 42.1MB/s]
tokenizer.json: 100%|██████████| 9.09M/9.09M [00:00<00:00, 45.2MB/s]
tokenizer_config.json: 100%|██████████| 54.5k/54.5k [00:00<00:00, 259MB/s]

In [5]: model.export("llama-3.2-1b-instruct-saved-model")
```
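A back-of-the-envelope check on those file sizes. The parameter counts below are assumptions (approximate published figures, roughly 2.6B for Gemma 2 2B and 1.24B for Llama 3.2 1B), not values read out of the exported models:

```python
# Rough parameter-memory arithmetic for the SavedModel sizes discussed above.
# The parameter counts are assumptions, not measured from the exports.

GIB = 1024 ** 3

def param_bytes(n_params: int, bytes_per_param: int) -> float:
    """Memory needed for the raw weights, in GiB."""
    return n_params * bytes_per_param / GIB

gemma2_2b_params = 2_600_000_000   # assumed ~2.6B parameters
llama32_1b_params = 1_240_000_000  # assumed ~1.24B parameters

print(f"Gemma 2 2B in fp32:   {param_bytes(gemma2_2b_params, 4):.1f} GiB")
print(f"Llama 3.2 1B in fp32: {param_bytes(llama32_1b_params, 4):.1f} GiB")
print(f"Llama 3.2 1B in bf16: {param_bytes(llama32_1b_params, 2):.1f} GiB")
```

The ~4.6GB Llama export matches the fp32 figure even though the model was loaded in bfloat16, which suggests the export upcasts the weights to fp32. A ~19GB Gemma 2 artifact is roughly twice its fp32 weight size (~9.7 GiB), so it may contain a duplicated copy of the variables; both readings are speculation from the arithmetic, not verified against the files.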
@Craigacp, thank you for your honest and detailed insight. If Google has indeed abandoned TensorFlow, then the situation is not good for projects like the TensorFlow Java API, because, as far as I understand, the TensorFlow Java API can only load the SavedModel format. Overall, I've tried a few Java tools for working with pre-trained AI models. None of them is great, and my impression is that unless some serious progress is made, Java will be left out of the AI story. What I'm trying to do is build an artificial intelligence system (let's call it that), based on pre-trained models, that runs locally, meaning it wouldn't rely on the services of large providers. I plan to do only as much training and fine-tuning of the model as necessary, preferably as little as possible. But after a month or so, I have to admit it's going quite slowly, and I don't know how, or with what tools, to continue. All in all, the situation is not easy at all :-(
We did have an effort a few years ago to support Keras's model formats, but we don't have enough contributors to complete it, and Keras changed its file format again in the meantime. You're right that we don't have good support for LLMs in TF-Java; we've been focused on existing deep learning production use cases. Now that we've finished the 1.0 release we have a bit more bandwidth for other things, but I'm not sure what we're going to focus on. There are other Java implementations for LLMs; the ones I've been involved in are onnxruntime-genai (I wrote the ONNX Runtime Java API and reviewed the ORT GenAI Java API for the Microsoft team) and llama3.java (Alfonso is a colleague here at Oracle Labs). There's also jlama, but I've not investigated that much. If you want to train LoRAs or fine-tune LLMs in Java, that's currently very difficult. TF-Java allows training, and you could potentially fine-tune a model in Amazon's DJL, but I don't think anything on the Java platform has LoRA training support. It could be written on top of TF-Java, but it would be extremely complicated.
Ok, downloaded Gemma 2 from
https://www.kaggle.com/models/google/gemma-2/keras
- unpacked and got the following files:
- tried to load the model using the following Python script:
- at the line

model = keras.models.model_from_json(json.dumps(model_config))

got the following error:

ValueError: Unknown layer: GemmaBackbone. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

Questions:
- not sure what to do with GemmaBackbone... huh?
- any help?
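That error is Keras's generic "unknown layer" failure: `model_from_json` can only rebuild layers it finds in its serialization registry (or in the `custom_objects` argument), and `GemmaBackbone` lives in keras_hub, not in core Keras. A minimal pure-Python sketch of that lookup mechanism (this is a stand-in registry written for illustration, not the real Keras implementation):

```python
# Toy version of Keras's deserialization lookup, showing why the ValueError
# appears and how custom_objects fixes it. Stand-in code, not Keras itself.

BUILTIN_LAYERS = {"Dense": object, "Embedding": object}  # stand-in registry

def layer_from_config(class_name, custom_objects=None):
    registry = dict(BUILTIN_LAYERS)
    registry.update(custom_objects or {})
    if class_name not in registry:
        raise ValueError(
            f"Unknown layer: {class_name}. Please ensure this object is "
            "passed to the `custom_objects` argument."
        )
    return registry[class_name]

# Fails, just like model_from_json did:
try:
    layer_from_config("GemmaBackbone")
except ValueError as e:
    print(e)

# Succeeds once the class is supplied explicitly:
class GemmaBackbone:  # stand-in for keras_hub.models.GemmaBackbone
    pass

layer_from_config("GemmaBackbone", custom_objects={"GemmaBackbone": GemmaBackbone})
```

In practice, the least fragile route is to skip `model_from_json` entirely and load via keras_hub, e.g. `keras_hub.models.GemmaCausalLM.from_preset(...)` as in the export snippet earlier in this thread. Alternatively, importing keras_hub before deserializing may be enough, since its classes register themselves with Keras's serialization registry; failing that, passing `custom_objects={"GemmaBackbone": keras_hub.models.GemmaBackbone}` explicitly should work.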