Commit

Add Falcon Preprocessor. (keras-team#1498)
SamanehSaadat committed Mar 8, 2024
1 parent 7ef18a1 commit fe5a53b
Showing 4 changed files with 547 additions and 0 deletions.
178 changes: 178 additions & 0 deletions keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py
@@ -0,0 +1,178 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.models.falcon.falcon_preprocessor import FalconPreprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.FalconCausalLMPreprocessor")
class FalconCausalLMPreprocessor(FalconPreprocessor):
"""Falcon Causal LM preprocessor.
This preprocessing layer is meant for use with
`keras_nlp.models.FalconCausalLM`. By default, it will take in batches of
strings, and return outputs in a `(x, y, sample_weight)` format, where the
`y` label is the next token id in the `x` sequence.
For use with generation, the layer also exposes two methods
`generate_preprocess()` and `generate_postprocess()`. When this preprocessor
is attached to a `keras_nlp.models.FalconCausalLM` instance, these methods
will be called implicitly in `generate()`. They can also be called
standalone (e.g. to precompute preprocessing inputs for generation in a
separate process).
Args:
tokenizer: A `keras_nlp.models.FalconTokenizer` instance.
sequence_length: The length of the packed inputs.
add_start_token: If `True`, the preprocessor will prepend the tokenizer
start token to each input sequence.
add_end_token: If `True`, the preprocessor will append the tokenizer
end token to each input sequence.
Call arguments:
x: A string, `tf.Tensor` or list of python strings.
y: Label data. Should always be `None` as the layer generates labels.
sample_weight: Label weights. Should always be `None` as the layer
generates label weights.
sequence_length: Pass to override the configured `sequence_length` of
the layer.
Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.FalconCausalLMPreprocessor.from_preset(
"falcon_refinedweb_1b_en"
)
# Tokenize and pack a single sentence.
sentence = tf.constant("League of legends")
preprocessor(sentence)
# Same output.
preprocessor("League of legends")
# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
preprocessor(sentences)
# Same output.
preprocessor(["Taco tuesday", "Fish taco please!"])
# Map a dataset to preprocess a single sentence.
features = tf.constant(
[
"Avatar 2 is amazing!",
"Well, I am not sure.",
]
)
labels = tf.constant([1, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Map a dataset to preprocess unlabled sentences.
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
```
"""

def call(
self,
x,
y=None,
sample_weight=None,
sequence_length=None,
):
if y is not None or sample_weight is not None:
logging.warning(
"`FalconCausalLMPreprocessor` generates `y` and `sample_weight` "
"based on your input data, but your data already contains `y` "
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)
sequence_length = sequence_length or self.sequence_length

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)

def generate_preprocess(
self,
x,
sequence_length=None,
):
"""Convert strings to integer token input for generation.
Similar to calling the layer for training, this method takes in strings
or tensor strings, tokenizes and packs the input, and computes a padding
mask masking all inputs not filled in with a padded value.
Unlike calling the layer for training, this method does not compute
labels and will never append a `tokenizer.end_token_id` to the end of
the sequence (as generation is expected to continue at the end of the
inputted prompt).
"""
if not self.built:
self.build(None)

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
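        # Pack without appending an end token, so generation can continue from
        # the end of the prompt.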
token_ids, padding_mask = self.packer(
x, sequence_length=sequence_length, add_end_value=False
)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}

def generate_postprocess(
self,
x,
):
"""Convert integer token output to strings for generation.
This method reverses `generate_preprocess()`, by first removing all
padding and start/end tokens, and then converting the integer sequence
back to a string.
"""
if not self.built:
self.build(None)

token_ids, padding_mask = x["token_ids"], x["padding_mask"]
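        # Convert to numpy so the masking below behaves the same on any backend.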
token_ids = ops.convert_to_numpy(token_ids)
padding_mask = ops.convert_to_numpy(padding_mask)
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
return self.tokenizer.detokenize(token_ids)
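
A minimal usage sketch of the new preprocessor, for illustration only (it is not part of this diff). It assumes the `falcon_refinedweb_1b_en` preset referenced in the class docstring is available to download.

```python
# Illustration only; assumes the "falcon_refinedweb_1b_en" preset exists.
import keras_nlp

preprocessor = keras_nlp.models.FalconCausalLMPreprocessor.from_preset(
    "falcon_refinedweb_1b_en"
)

# Training path: strings in, (x, y, sample_weight) out, where `y` is
# `x["token_ids"]` shifted one position ahead.
x, y, sample_weight = preprocessor(["airplane at airport"])

# Generation path: no end token is appended, so the model can keep generating
# from the prompt; postprocess strips padding and special tokens again.
inputs = preprocessor.generate_preprocess(["airplane at airport"])
text = preprocessor.generate_postprocess(inputs)
```
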
94 changes: 94 additions & 0 deletions keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py
@@ -0,0 +1,94 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.models.falcon.falcon_causal_lm_preprocessor import (
FalconCausalLMPreprocessor,
)
from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer
from keras_nlp.tests.test_case import TestCase


class FalconCausalLMPreprocessorTest(TestCase):
def setUp(self):
self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
self.vocab += ["<|endoftext|>"]
self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
self.tokenizer = FalconTokenizer(
vocabulary=self.vocab,
merges=self.merges,
)
self.init_kwargs = {
"tokenizer": self.tokenizer,
"sequence_length": 8,
}
self.input_data = ["airplane at airport"]

def test_preprocessor_basics(self):
self.run_preprocessor_test(
cls=FalconCausalLMPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output=(
{
"token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
[[1, 3, 4, 2, 5, 6, 0, 0]], # Pass through labels.
[[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights.
),
)

def test_no_start_end_token(self):
input_data = ["airplane at airport"] * 4

preprocessor = FalconCausalLMPreprocessor(
**self.init_kwargs,
add_start_token=False,
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)

def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])

def test_generate_postprocess(self):
input_data = {
"token_ids": [6, 1, 3, 4, 2, 5, 0, 0],
"padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
}
preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
self.assertAllEqual(x, "airplane at airport")

@pytest.mark.extra_large
def test_all_presets(self):
for preset in FalconCausalLMPreprocessor.presets:
self.run_preset_test(
cls=FalconCausalLMPreprocessor,
preset=preset,
input_data=self.input_data,
)