Skip to content

Latest commit

 

History

History
95 lines (70 loc) · 3.92 KB

README.md

File metadata and controls

95 lines (70 loc) · 3.92 KB

starcoder-jax

Introduction

This repository is a Jax/Flax implementation of the StarCoder model. We implement the inference code of GPTBigCode architecture. With this repository, you can run GPTBigCode based models such as starcoder, starcoderbase and starcoderplus.

The StarCoder models have 15.5B parameters and it requires about 63GB of memory for parameters only. Since tpu-v3-8 consists of 8 cores of 16GB, it is necessary to shard the parameters into multiple devices. Therefore this repository provides 2D parallelism (data parallelism and model parallelism) for inference.

Requirements

The below libraries are required to run the starcoder.

  • jax
  • flax
  • chex
  • torch
  • transformers

If you are trying to run on Cloud TPU VM, run the below commands:

$ pip install torch --index-url https://download.pytorch.org/whl/cpu
$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$ pip install flax chex transformers

Also you may need to login to the huggingface hub. Use the below command:

$ ~/.local/bin/huggingface-cli login [token]

Usage

This repository provides an interface to generate a text from the model. First of all, create a device mesh for parallelism and load model weights. The Generator class will automatically convert the PyTorch weights to Jax/Flax format. Note that you must specify your huggingface API token to load StarCoder models (because of the licence agreement).

# Define a parallelism rule.
mesh = Mesh(mesh_utils.create_device_mesh((1, 8)), ("dp", "mp"))

# Load the model from huggingface and shard the parameters into multiple devices.
generator = Generator.from_huggingface("bigcode/starcoder", use_auth_token=True)
generator.shard(mesh)

After loading the weights, you should prepare an initial input for the prompt context. The Generator class also provides a method to encode the text and its generation options:

context = """
def print_len(x):
    '''print the length of the string.'''
"""

initial, hparams = generator.prepare_context(
    context,
    max_length=8192,
    temperature=0.8,
    top_p=0.92,
)

The output hparams contains the hyperparameters for generation (temperature and top_p). As you can see below, it is reused while predicting next tokens. You can stack the hyperparameters with their initial inputs to make a batch with using different generation options.

Iterative generation

Like ChatGPT, you can iteratively generate next tokens from the model for streaming the generation progress.

with mesh:
    outputs = generator.generate_from_context(**initial, **hparams, rng=rng)
    print(generator.tokenizer.decode(int(outputs[0][0])), end="", flush=True)

    for _ in range(1024):
        outputs = generator.generate_next_tokens(*outputs, **hparams)
        print(generator.tokenizer.decode(int(outputs[0][0])), end="", flush=True)

Generate at once

Instead, you can generate a sentence at once like Bard. It can be accomplished by putting the above codes in a single function and compiling it. generator.generate_at_once performs the above codes with aggregating the tokens.

with mesh:
    tokens, rng = generator.generate_at_once(**initial, **hparams, rng=rng, max_new_tokens=1024)
print(generator.tokenizer.decode(tokens[0][tokens[0] != -1]))

For more details, check out the examples.

Examples

Acknowledgements

Tensorflow Research Cloud provides the TPU Resources for testing.

License

MIT