Export Gemma 2 from Keras to Saved Model? #595

Open

JeeDevUser opened this issue Dec 7, 2024 · 7 comments

@JeeDevUser
OK, I downloaded Gemma 2 from
https://www.kaggle.com/models/google/gemma-2/keras

unpacked it, and got the following files:

/assets
config.json
metadata.json
model.weights.h5
tokenizer.json

then tried to load the model using the following Python script:

import json
from tensorflow import keras

# 1. Path to the config and weights:
config_path = "d:/Install/TensorFlow/models/Gemma_2/config.json"
weights_path = "d:/Install/TensorFlow/models/Gemma_2/model.weights.h5"
saved_model_dir = "d:/Install/TensorFlow/models/Gemma_2/gemma2_saved_model"

# 2. load model configuration
with open(config_path, 'r') as f:
    model_config = json.load(f)

# 3. Reconstruct the model:
model = keras.models.model_from_json(json.dumps(model_config))

# 4. Load the weights:
model.load_weights(weights_path)

# 5. Finally, save the model in TensorFlow SavedModel format:
model.save(saved_model_dir, save_format="tf")

print("Model saved to:", saved_model_dir)

At the line

model = keras.models.model_from_json(json.dumps(model_config))

I got the following error:

ValueError: Unknown layer: GemmaBackbone. Please ensure this object is passed to the custom_objects argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

Question:

I'm not sure what to do with GemmaBackbone. Any help?
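
The error points at the custom_objects argument. I guess it wants something like the sketch below, with GemmaBackbone coming from keras_hub (formerly keras_nlp), though I haven't verified that the Kaggle config.json is even a Keras model JSON:

import json
import keras_hub
from tensorflow import keras

with open(config_path, 'r') as f:
    model_config = json.load(f)

# Hand the deserializer the custom class it can't resolve on its own:
model = keras.models.model_from_json(
    json.dumps(model_config),
    custom_objects={"GemmaBackbone": keras_hub.models.GemmaBackbone},
)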

@Craigacp
Collaborator

Craigacp commented Dec 7, 2024

After downloading the tar file this was sufficient for me to export it:

import keras_hub
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma2-keras-gemma2_instruct_2b_en-v1")
gemma_lm.export("gemma2-2b-it-saved-model")

The saved model is massive (~19GB) for some reason I don't understand, but it looks like it should be callable in Java:

jshell> import org.tensorflow.*

jshell> var model = SavedModelBundle.load("gemma2-2b-it-saved-model","serve")
model ==> org.tensorflow.SavedModelBundle@194fad1

jshell> model.signatures()
$3 ==> [Signature for "serving_default":
	Method: "tensorflow/serving/predict"
	Inputs:
		"padding_mask": dtype=DT_FLOAT, shape=(-1, -1)
		"token_ids": dtype=DT_FLOAT, shape=(-1, -1)
	Outputs:
		"output_0": dtype=DT_FLOAT, shape=(-1, -1, 256000)
, Signature for "serve":
	Method: "tensorflow/serving/predict"
	Inputs:
		"token_ids": dtype=DT_FLOAT, shape=(-1, -1)
		"padding_mask": dtype=DT_FLOAT, shape=(-1, -1)
	Outputs:
		"output_0": dtype=DT_FLOAT, shape=(-1, -1, 256000)
]
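
A hedged aside (untested): keras_hub's from_preset accepts a dtype argument (it's used below for the Llama export), so loading the preset at reduced precision before exporting might shrink the SavedModel considerably:

import keras_hub

# Same preset as above, but loaded in bfloat16 before export.
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
    "gemma2-keras-gemma2_instruct_2b_en-v1", dtype="bfloat16"
)
gemma_lm.export("gemma2-2b-it-saved-model-bf16")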

@JeeDevUser
Author

JeeDevUser commented Dec 8, 2024

@Craigacp, I got a similar file size using a slightly different Python script:

/variables/variables.data-00000-of-00001 - 20,917,805,737 bytes (~20.9 GB)
saved_model.pb - 4,056,008 bytes

I thought I had made a mistake, but now I see we got similar results. What exactly are the sizes of your files?

Besides... I'm primarily a Java developer, and all this conversion to SavedModel format is just a hassle for me. It would be really nice if there were a repository of popular models in SavedModel format (or at least a collection of Python conversion scripts), so they would be ready to use with the TensorFlow Java API. I would build something like that myself and publish it somewhere, but I'm not sure I can. I don't have enough knowledge for that :-(

@JeeDevUser
Author

Now, let's try to load this monster using Java code. We'll see what happens.

@JeeDevUser
Author

It ate up all 16GB of RAM on my computer and failed to load :-( (in hindsight, no surprise: the variables file alone is ~21GB)...

import org.tensorflow.SavedModelBundle;

   public static void main(String[] args) {
      SavedModelBundle model = null;
      try {
         String modelPath = "d:/Install/TensorFlow/models/Gemma_2/gemma2_saved_model";
         // Load the SavedModel with the "serve" tag.
         model = SavedModelBundle.load(modelPath, "serve");
         System.out.println("Successfully loaded the model: " + model);
      } finally {
         // Release the native resources backing the model.
         if (model != null) {
            model.close();
         }
      }
   }// of main()

It seems too bulky to use...
@Craigacp, have you tried loading it from Java?

@Craigacp
Collaborator

Craigacp commented Dec 8, 2024

I loaded it in jshell, but my laptop has 48GB of RAM. The model should take about 8GB for the parameters in fp32 (roughly 2 billion parameters × 4 bytes each); I'm not sure why it's 19GB. I don't have time to get the tokenizer up and running on TF-Java, but it should be possible.

The problem is that Google don't release things in TensorFlow SavedModel format anymore, as they are focusing on Keras and JAX, so it's a second-class citizen in the ecosystem. They deprecated TensorFlow Hub, which used to host SavedModels, and pushed everything over to Kaggle. Older models on there still have TF SavedModel checkpoints, but none of the LLMs do.

If you've got a HuggingFace account you can request access to Llama 3.2 (which should be straightforward unless you work at a massive multinational social media company) and then export the 1B version of it. When I did that I got a 4.6GB SavedModel, which I also managed to load and execute to generate the probability distribution over tokens. I didn't check whether it was producing sensible output; I don't have a harness with a tokenizer that would let me easily generate from a TF-Java based LLM.

In [1]: import huggingface_hub

In [2]: huggingface_hub.login(token="<huggingface-api-token-for-your-account>")

In [3]: import keras_hub

In [4]: model = keras_hub.models.Llama3CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16")
model.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.47G/2.47G [00:58<00:00, 42.1MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.09M/9.09M [00:00<00:00, 45.2MB/s]
tokenizer_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 54.5k/54.5k [00:00<00:00, 259MB/s]

In [5]: model.export("llama-3.2-1b-instruct-saved-model")
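
If you want to sanity-check an export before touching Java, a minimal sketch is to reload it with plain TensorFlow and call the serving signature (the float input dtypes below follow the Gemma signature shown above and may need adjusting for this export):

import numpy as np
import tensorflow as tf

reloaded = tf.saved_model.load("llama-3.2-1b-instruct-saved-model")
infer = reloaded.signatures["serving_default"]

# Dummy batch of 8 token ids.
token_ids = tf.constant(np.zeros((1, 8), dtype=np.float32))
padding_mask = tf.constant(np.ones((1, 8), dtype=np.float32))

out = infer(token_ids=token_ids, padding_mask=padding_mask)
print({name: t.shape for name, t in out.items()})  # logits over the vocab at each position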

@JeeDevUser
Author

@Craigacp, thank you for your honest and detailed insight. If Google has indeed abandoned TensorFlow, then the situation is not good for projects like the TensorFlow Java API, because, as far as I understand, it can only load the SavedModel format.

Overall, I've tried a few Java tools for working with pre-trained AI models. None of them is anywhere near great, and my impression is that unless some serious progress is made, Java will be eliminated from the AI story.

What I'm trying to do is create an artificial intelligence system (let's call it that), based on pre-trained models, that runs locally, meaning it wouldn't rely on the services of large providers. I intend to do only as much training and fine-tuning of the models as necessary - preferably as little as possible. But after a month or so, I have to admit it's going quite hard, and I don't know how, or with what tools, to continue.

All in all, the situation is not easy at all :-(

@Craigacp
Collaborator

Craigacp commented Dec 9, 2024

We did have an effort a few years ago to support Keras's model formats, but we don't currently have enough contributors to complete that, and Keras changed its file format again in the meantime.

You're right that we don't have good support for LLMs in TF-Java; we've been focused on existing deep learning production use cases. Now that we've finished the 1.0 release we have a bit more bandwidth for other things, but I'm not sure what we're going to focus on. There are other Java implementations for LLMs; the ones I've been involved in are onnxruntime-genai (I wrote the ONNX Runtime Java API and reviewed the ORT GenAI Java API for the Microsoft team) and llama3.java (Alfonso is a colleague here at Oracle Labs). There's also jlama, but I've not investigated that much.

If you want to train LoRAs or fine-tune LLMs in Java, that's currently very difficult. TF-Java supports training, and you could potentially fine-tune a model in Amazon's DJL, but I don't think anything on the Java platform has LoRA training support. It could be written on top of TF-Java, but it would be extremely complicated.
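
For contrast, on the Python side keras_hub exposes LoRA directly on the backbone. A hedged sketch (enable_lora per the keras_hub docs, preset name as above, training data purely illustrative):

import keras
import keras_hub

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma2-keras-gemma2_instruct_2b_en-v1")
gemma_lm.backbone.enable_lora(rank=4)  # inject trainable low-rank adapters; base weights stay frozen

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# The CausalLM preprocessor tokenizes raw strings on the fly.
gemma_lm.fit(["Instruction: say hi.\nResponse: hi!"], epochs=1, batch_size=1)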
