diff --git a/autocomplete/gpt2.py b/autocomplete/gpt2.py
index 7f9bbe2..31f4979 100644
--- a/autocomplete/gpt2.py
+++ b/autocomplete/gpt2.py
@@ -20,14 +20,14 @@ def predict_with_original_gpt2(prompts):
     for prompt in prompts:
         # Generate text using the model. Verbose set to False to prevent logging generated sequences.
         generated = model.generate(prompt, verbose=False)
-        generated = generated[0]
-        print("=============================================================================")
         print(generated)
-        print("=============================================================================")
+        print("=" * 20)
 
 
 
-def train(model_dir="outputs/fine-tuned/", train_file="download/train.txt", valid_file="download/valid.txt",
+def train(model_dir="outputs/fine-tuned/",
+          train_file="download/train.txt",
+          valid_file="download/valid.txt",
           num_train_epochs=3):
     train_args = {
         "reprocess_input_data": True,
diff --git a/autocomplete/server.py b/autocomplete/server.py
index 791ad0b..e612a0c 100644
--- a/autocomplete/server.py
+++ b/autocomplete/server.py
@@ -2,6 +2,7 @@
 @author:XuMing(xuming624@qq.com)
 @description: Server
 """
+import argparse
 import uvicorn
 import sys
 import os
@@ -16,8 +17,11 @@ pwd_path = os.path.abspath(os.path.dirname(__file__))
 use_cuda = torch.cuda.is_available()
 
 # Use finetuned GPT2 model
-model_dir = os.path.join(pwd_path, "outputs/fine-tuned/")
-gpt2_infer = Infer(model_name="gpt2", model_dir=model_dir, use_cuda=use_cuda)
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_name_or_path", type=str, default="shibing624/code-autocomplete-gpt2-base",
+                    help="Model save dir or model name")
+args = parser.parse_args()
+gpt2_infer = Infer(model_name="gpt2", model_dir=args.model_name_or_path, use_cuda=use_cuda)
 
 # define the app
 app = FastAPI()