Commit: Improve

pm3310 committed Mar 7, 2024
1 parent b1d1e14 commit 3b55e44

Showing 4 changed files with 96 additions and 6 deletions.
71 changes: 70 additions & 1 deletion docs/index.md
@@ -826,7 +826,10 @@ batch_inference(
    aws_profile='sagemaker-dev',
    aws_region='us-east-1',
    num_instances=1,
    ec2_type='ml.p3.2xlarge'
    ec2_type='ml.p3.2xlarge',
    aws_access_key_id='YOUR_AWS_ACCESS_KEY_ID',
    aws_secret_access_key='YOUR_AWS_SECRET_ACCESS_KEY',
    wait=True
)
```

@@ -2045,3 +2048,69 @@ It builds gateway docker image and starts the gateway locally.
`--platform PLATFORM`: Operating system. Platform in the format `os[/arch[/variant]]`.

`--start-local`: Flag to indicate whether to start the gateway locally.


### LLM Batch Inference

#### Name

Command to execute an LLM batch inference job

#### Synopsis
```sh
sagify llm batch-inference --model MODEL --s3-input-location S3_INPUT_LOCATION --s3-output-location S3_OUTPUT_LOCATION --aws-profile AWS_PROFILE --aws-region AWS_REGION --num-instances NUMBER_OF_EC2_INSTANCES --ec2-type EC2_TYPE [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID] [--wait] [--job-name JOB_NAME] [--max-concurrent-transforms MAX_CONCURRENT_TRANSFORMS]
```

#### Description

This command triggers a batch inference job for a given LLM model and batch input.

- The input S3 path should contain one or more JSONL files. Example of a file:
```json
{"id":1,"text_inputs":"what is the recipe of mayonnaise?"}
{"id":2,"text_inputs":"what is the recipe of fish and chips?"}
```

Each line contains a unique identifier (`id`) and the corresponding text input (`text_inputs`). This identifier links each input to its respective output, as illustrated in the output format below:

```json
{"id": 1, "embedding": [-0.029919596, -0.0011845357, ..., 0.08851079, 0.021398442]}
{"id": 2, "embedding": [-0.041918136, 0.007127975, ..., 0.060178414, 0.031050885]}
```

Keeping the `id` field consistent between input and output files lets you join every prediction back to its source record.
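For illustration, here is a minimal sketch (not part of sagify; the local file names are hypothetical) of joining predictions back to their inputs by `id`, after downloading one input and one output JSONL file:

```python
import json

# Hypothetical local copies of an input file and its corresponding output file.
with open('input.jsonl') as f:
    inputs = {record['id']: record for record in map(json.loads, f)}

with open('output.jsonl') as f:
    for prediction in map(json.loads, f):
        # The shared id links each embedding to the text that produced it.
        source = inputs[prediction['id']]
        print(source['text_inputs'], prediction['embedding'][:3])
```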

#### Required Flags

`--model MODEL`: LLM model name

`--s3-input-location S3_INPUT_LOCATION` or `-i S3_INPUT_LOCATION`: S3 input data location

`--s3-output-location S3_OUTPUT_LOCATION` or `-o S3_OUTPUT_LOCATION`: S3 location to save predictions

`--num-instances NUMBER_OF_EC2_INSTANCES` or `-n NUMBER_OF_EC2_INSTANCES`: Number of EC2 instances

`--ec2-type EC2_TYPE` or `-e EC2_TYPE`: EC2 instance type. Refer to https://aws.amazon.com/sagemaker/pricing/instance-types/

`--aws-profile AWS_PROFILE`: The AWS profile to use for the batch inference job

`--aws-region AWS_REGION`: The AWS region to use for the batch inference job

#### Optional Flags

`--aws-tags TAGS` or `-a TAGS`: Tags for labeling an inference job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.

`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for the inference job with *SageMaker*

`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external ID to use when assuming an IAM role

`--wait`: Optional flag to wait until the batch inference job finishes (default: don't wait)

`--job-name JOB_NAME`: Optional name for the SageMaker batch inference job

`--max-concurrent-transforms MAX_CONCURRENT_TRANSFORMS`: Optional maximum number of HTTP requests to be made to each individual inference container at one time. Default value: 1

#### Example
```sh
sagify llm batch-inference --model gte-small --s3-input-location s3://sagify-llm-playground/batch-input-data-example/embeddings/ --s3-output-location s3://sagify-llm-playground/batch-output-data-example/embeddings/1/ --aws-profile sagemaker-dev --aws-region us-east-1 --num-instances 1 --ec2-type ml.p3.2xlarge --wait
```
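The same job can also be launched programmatically. Below is a minimal sketch using `sagify.api.llm.batch_inference`, the function changed in this commit; the keyword names for the model and S3 locations are inferred from the CLI flags above and may differ from the actual signature:

```python
from sagify.api.llm import batch_inference

# Sketch only: argument names mirror the CLI flags and the docs example above.
# aws_access_key_id/aws_secret_access_key are the parameters added in this commit;
# replace the placeholder values with real credentials.
batch_inference(
    model='gte-small',
    s3_input_location='s3://sagify-llm-playground/batch-input-data-example/embeddings/',
    s3_output_location='s3://sagify-llm-playground/batch-output-data-example/embeddings/1/',
    aws_profile='sagemaker-dev',
    aws_region='us-east-1',
    num_instances=1,
    ec2_type='ml.p3.2xlarge',
    wait=True
)
```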
10 changes: 8 additions & 2 deletions sagify/api/llm.py
@@ -15,7 +15,9 @@ def batch_inference(
    wait=True,
    job_name=None,
    model_version='1.*',
    max_concurrent_transforms=None
    max_concurrent_transforms=None,
    aws_access_key_id=None,
    aws_secret_access_key=None,
):
    """
    Executes a batch inference job given a foundation model on SageMaker
@@ -49,6 +51,8 @@ def batch_inference(
    :param job_name: [str, default=None], name for the SageMaker batch transform job
    :param model_version: [str, default='1.*'], model version to use
    :param max_concurrent_transforms: [int, default=None], max number of concurrent transforms
    :param aws_access_key_id: [str, default=None], AWS access key id
    :param aws_secret_access_key: [str, default=None], AWS secret access key
    :return: [str], transform job status if wait=True.
        Valid values: 'InProgress'|'Completed'|'Failed'|'Stopping'|'Stopped'
@@ -57,7 +61,9 @@
        aws_profile=aws_profile,
        aws_region=aws_region,
        aws_role=aws_role,
        external_id=external_id
        external_id=external_id,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key
    )

    return sage_maker_client.foundation_model_batch_transform(
2 changes: 1 addition & 1 deletion sagify/commands/llm.py
@@ -735,7 +735,7 @@ def gateway(image, start_local, platform):
    required=False,
    default=None,
    type=int,
    help=" The maximum number of HTTP requests to be made to each individual inference container at one time"
    help="The maximum number of HTTP requests to be made to each individual inference container at one time"
)
@click.option(
    u"-a", u"--aws-tags",
19 changes: 17 additions & 2 deletions sagify/sagemaker/sagemaker.py
@@ -23,9 +23,24 @@


class SageMakerClient(object):
    def __init__(self, aws_profile, aws_region, aws_role=None, external_id=None):
    def __init__(
        self,
        aws_profile,
        aws_region,
        aws_role=None,
        external_id=None,
        aws_access_key_id=None,
        aws_secret_access_key=None
    ):

        if aws_role:
        if aws_access_key_id and aws_secret_access_key:
            logger.info("AWS access key and secret access key were provided. Using these credentials...")
            self.boto_session = boto3.Session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                region_name=aws_region
            )
        elif aws_role:
            logger.info("An IAM role and corresponding external id were provided. Attempting to assume that role...")

            sts_client = boto3.client('sts')
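For context, a hedged usage sketch of the new constructor path, assuming only the import path and the signature shown in this diff: explicit keys take precedence over both role assumption and the named profile.

```python
from sagify.sagemaker.sagemaker import SageMakerClient

# Explicit credentials win: the constructor builds the boto3 Session directly
# from them, skipping the assume-role and profile branches in the diff above.
client = SageMakerClient(
    aws_profile='sagemaker-dev',
    aws_region='us-east-1',
    aws_access_key_id='YOUR_AWS_ACCESS_KEY_ID',
    aws_secret_access_key='YOUR_AWS_SECRET_ACCESS_KEY'
)
```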
