Skip to content

Commit

Permalink
add server.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Feb 11, 2022
1 parent da089fe commit 5d44507
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 4 additions & 4 deletions autocomplete/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def predict_with_original_gpt2(prompts):
for prompt in prompts:
# Generate text using the model. Verbose set to False to prevent logging generated sequences.
generated = model.generate(prompt, verbose=False)

generated = generated[0]
print("=============================================================================")
print(generated)
print("=============================================================================")
print("=" * 20)


def train(model_dir="outputs/fine-tuned/", train_file="download/train.txt", valid_file="download/valid.txt",
def train(model_dir="outputs/fine-tuned/",
train_file="download/train.txt",
valid_file="download/valid.txt",
num_train_epochs=3):
train_args = {
"reprocess_input_data": True,
Expand Down
8 changes: 6 additions & 2 deletions autocomplete/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
@author:XuMing([email protected])
@description: Server
"""
import argparse
import uvicorn
import sys
import os
Expand All @@ -16,8 +17,11 @@
pwd_path = os.path.abspath(os.path.dirname(__file__))
use_cuda = torch.cuda.is_available()
# Use finetuned GPT2 model
model_dir = os.path.join(pwd_path, "outputs/fine-tuned/")
gpt2_infer = Infer(model_name="gpt2", model_dir=model_dir, use_cuda=use_cuda)
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="shibing624/code-autocomplete-gpt2-base",
help="Model save dir or model name")
args = parser.parse_args()
gpt2_infer = Infer(model_name="gpt2", model_dir=args.model_name_or_path, use_cuda=use_cuda)

# define the app
app = FastAPI()
Expand Down

0 comments on commit 5d44507

Please sign in to comment.