-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main'
- Loading branch information
Showing
13 changed files
with
357 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from typing import List | ||
|
||
import torch | ||
from transformers import AutoTokenizer, AutoModel | ||
|
||
from aloha.service.streamer import ManagedModel | ||
|
||
SEED = 0 | ||
torch.manual_seed(SEED) | ||
torch.cuda.manual_seed(SEED) | ||
|
||
|
||
class TextUnmaskModel:
    """Predict the token hidden behind ``[MASK]`` in each input sentence.

    Wraps a Hugging Face tokenizer and a masked-language-model checkpoint.
    Every sentence is truncated/padded to exactly ``max_sent_len`` tokens
    (counting ``[CLS]`` and ``[SEP]``) so a batch forms a rectangular tensor.
    """

    def __init__(self, max_sent_len=16, model_path="bert-base-uncased", device=None):
        """Load the tokenizer and model weights and move the model to ``device``.

        Args:
            max_sent_len: Fixed sequence length, counting ``[CLS]``/``[SEP]``.
            model_path: Checkpoint name or path passed to ``from_pretrained``.
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available and ``"cpu"`` otherwise.
        """
        # Imported locally: the module-level import only brings in the
        # head-less ``AutoModel``, whose first output is the hidden state
        # (batch, seq, hidden_size). Vocabulary scores need the masked-LM head.
        from transformers import AutoModelForMaskedLM

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_path = model_path
        self.max_sent_len = max_sent_len
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.transformer = AutoModelForMaskedLM.from_pretrained(self.model_path)
        self.transformer.eval()
        self.transformer.to(device=self.device)

    def predict(self, batch: List[str]) -> List[str]:
        """Return the most likely token for the ``[MASK]`` slot of each text.

        Args:
            batch: Sentences that each contain a ``[MASK]`` token within the
                first ``max_sent_len - 2`` word pieces.

        Returns:
            One predicted token string per input sentence, in input order.
            An empty batch yields an empty list.

        Raises:
            ValueError: If a sentence has no ``[MASK]`` token left after
                truncation.
        """
        if not batch:
            return []

        max_body_len = self.max_sent_len - 2  # room reserved for [CLS] and [SEP]
        batch_inputs = []
        attention_masks = []
        masked_indexes = []

        for text in batch:
            pieces = self.tokenizer.tokenize(text)[:max_body_len]
            tokens = ['[CLS]'] + pieces + ['[SEP]']
            if '[MASK]' not in tokens:
                raise ValueError(
                    "No [MASK] token within the first %d tokens of: %r"
                    % (max_body_len, text)
                )

            real_len = len(tokens)
            pad_len = self.max_sent_len - real_len
            tokens += ['[PAD]'] * pad_len

            batch_inputs.append(self.tokenizer.convert_tokens_to_ids(tokens))
            # 1 for real tokens, 0 for padding, so padding is not attended to.
            attention_masks.append([1] * real_len + [0] * pad_len)
            masked_indexes.append(tokens.index('[MASK]'))

        input_ids = torch.tensor(batch_inputs, device=self.device)
        attention_mask = torch.tensor(attention_masks, device=self.device)

        with torch.no_grad():
            # prediction_scores: (batch_size, sequence_length, vocab_size)
            prediction_scores = self.transformer(
                input_ids, attention_mask=attention_mask
            )[0]

        # Pick, for every row, the score vector at that row's [MASK] position.
        rows = torch.arange(len(batch_inputs), device=prediction_scores.device)
        cols = torch.tensor(masked_indexes, device=prediction_scores.device)
        predicted_ids = prediction_scores[rows, cols].argmax(dim=-1).tolist()

        return self.tokenizer.convert_ids_to_tokens(predicted_ids)
|
||
|
||
class ManagedBertModel(ManagedModel):
    """``ManagedModel`` adapter that delegates to a ``TextUnmaskModel``."""

    def init_model(self):
        """Create the wrapped ``TextUnmaskModel`` with its default settings."""
        # NOTE(review): presumably invoked by ManagedModel before predict();
        # confirm against aloha.service.streamer.
        self.model = TextUnmaskModel()

    def predict(self, batch):
        """Forward ``batch`` to the wrapped model and return its predictions."""
        unmasker = self.model
        return unmasker.predict(batch)
|
||
|
||
def test_simple():
    """Smoke test: print the model's guess for one masked sentence.

    Downloads ``bert-base-uncased`` if it is not cached, runs a single
    sentence through the masked-LM head and prints both the raw model
    output and the token predicted for the ``[MASK]`` position.
    """
    # Imported locally: the head-less ``AutoModel`` imported at module level
    # returns hidden states / a pooled vector, not vocabulary scores.
    from transformers import AutoModelForMaskedLM

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    model.eval()

    inputs = tokenizer("Hello! My name is [MASK]!", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs)

    # outputs[0] holds the logits, shaped (batch, seq_len, vocab_size);
    # read the row that sits at the [MASK] token's position.
    mask_position = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)
    predicted_index = torch.argmax(outputs[0][0, mask_position]).item()
    predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
    print(predicted_token)
|
||
|
||
def test_batch():
    """Run the unmask model over three sample sentences and print the guesses."""
    sentences = [
        "twinkle twinkle [MASK] star.",
        "Happy birthday to [MASK].",
        'the answer to life, the [MASK], and everything.',
    ]
    unmasker = TextUnmaskModel()
    print(unmasker.predict(sentences))
|
||
|
||
if __name__ == "__main__": | ||
test_simple() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from aloha.logger import LOG | ||
from aloha.service.api.v0 import APIHandler | ||
|
||
|
||
class MultipartHandler(APIHandler):
    """Echo handler: logs the ``params`` it receives and returns them."""

    def response(self, params=None, *args, **kwargs):
        """Log ``params`` at debug level and hand them back unchanged.

        Extra positional and keyword arguments are accepted and ignored.
        """
        LOG.debug(params)
        return params
|
||
|
||
default_handlers = [ | ||
# internal API: echoes the received request params back (multipart upload endpoint) | ||
(r"/api_internal/multipart", MultipartHandler), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import time | ||
|
||
import requests | ||
|
||
from ...logger import LOG | ||
|
||
|
||
def iter_over_request_files(request, url_files, timeout=None):
    """Yield every file attached to a request, then every file given by URL.

    Args:
        request: Object whose ``files`` attribute maps a form field name to a
            list of dicts with ``"filename"``, ``"content_type"`` and
            (optionally) ``"body"`` entries.
        url_files: Iterable of URLs to download, or ``None``/empty for none.
            Duplicates are dropped and URLs are fetched in sorted order.
        timeout: Passed through to ``requests.get``. The default ``None``
            keeps the previous behaviour of waiting indefinitely.

    Yields:
        ``(file_key, file_name, content_type, body)`` tuples. For downloaded
        files ``file_key`` is ``'url_files'`` and ``file_name`` is the URL.

    Raises:
        RuntimeError: If a download answers with a status code other than 200.
        requests.RequestException: Propagated unchanged from ``requests.get``.
    """
    # Files uploaded as multipart form data.
    for file_key, files in request.files.items():
        for f in files:
            file_name, content_type = f["filename"], f["content_type"]
            body = f.get('body', b"")
            LOG.info(f"File {file_name} from multipart has content type {content_type} and length bytes={len(body)}")
            yield file_key, file_name, content_type, body

    # Files referenced by URL.
    for url in sorted(set(url_files or [])):
        t_start = time.time()
        # stream=True defers the body download until the status is known;
        # the context manager releases the connection on the error path too.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            if resp.status_code != 200:
                raise RuntimeError("Failed to download file after %s seconds with code=%s from URL %s" % (
                    time.time() - t_start, resp.status_code, url
                ))
            body = resp.content
            content_type = resp.headers.get("Content-Type", "UNKNOWN")
        t_cost = time.time() - t_start
        LOG.info(f"File {url} has content type {content_type} and length bytes={len(body)}, downloaded in {t_cost} seconds")
        yield 'url_files', url, content_type, body
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.