Bunmi e #117
Open · wants to merge 2 commits into master
6 changes: 6 additions & 0 deletions .ipynb_checkpoints/Untitled-checkpoint.ipynb
@@ -0,0 +1,6 @@
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
30 changes: 30 additions & 0 deletions 3.0
@@ -0,0 +1,30 @@
Collecting opentok
Downloading https://files.pythonhosted.org/packages/e5/36/39ff61bab71ad73bfaeab4a5f347dba94ab7889ef7d22958b84c424455b4/opentok-3.1.0.tar.gz
Installing build dependencies: started
Installing build dependencies: still running...
Installing build dependencies: finished with status 'done'
Getting requirements to build wheel: started
Getting requirements to build wheel: finished with status 'done'
Preparing wheel metadata: started
Preparing wheel metadata: finished with status 'done'
Collecting python-jose (from opentok)
Downloading https://files.pythonhosted.org/packages/bd/2d/e94b2f7bab6773c70efc70a61d66e312e1febccd9e0db6b9e0adf58cbad1/python_jose-3.3.0-py2.py3-none-any.whl
Requirement already satisfied: requests in c:\users\bunmi\onedrive\documents\lib\site-packages (from opentok) (2.25.0)
Requirement already satisfied: six in c:\users\bunmi\onedrive\documents\lib\site-packages (from opentok) (1.15.0)
Requirement already satisfied: pytz in c:\users\bunmi\onedrive\documents\lib\site-packages (from opentok) (2019.3)
Requirement already satisfied: pyasn1 in c:\users\bunmi\onedrive\documents\lib\site-packages (from python-jose->opentok) (0.4.8)
Collecting ecdsa!=0.15 (from python-jose->opentok)
Downloading https://files.pythonhosted.org/packages/4a/b6/b678b080967b2696e9a201c096dc076ad756fb35c87dca4e1d1a13496ff7/ecdsa-0.17.0-py2.py3-none-any.whl (119kB)
Requirement already satisfied: rsa in c:\users\bunmi\onedrive\documents\lib\site-packages (from python-jose->opentok) (4.6)
Requirement already satisfied: idna<3,>=2.5 in c:\users\bunmi\onedrive\documents\lib\site-packages (from requests->opentok) (2.10)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\users\bunmi\onedrive\documents\lib\site-packages (from requests->opentok) (1.26.2)
Requirement already satisfied: certifi>=2017.4.17 in c:\users\bunmi\onedrive\documents\lib\site-packages (from requests->opentok) (2020.11.8)
Requirement already satisfied: chardet<4,>=3.0.2 in c:\users\bunmi\onedrive\documents\lib\site-packages (from requests->opentok) (3.0.4)
Building wheels for collected packages: opentok
Building wheel for opentok (PEP 517): started
Building wheel for opentok (PEP 517): finished with status 'done'
Created wheel for opentok: filename=opentok-3.1.0-cp37-none-any.whl size=26318 sha256=db4026e59a81901a990ff2badc1e78a888fe684e8773ef83e0c6464aab24ff45
Stored in directory: C:\Users\Bunmi\AppData\Local\pip\Cache\wheels\f4\fb\17\0db41a13ec5785399e8047b1d2437725745a19eda37752d691
Successfully built opentok
Installing collected packages: ecdsa, python-jose, opentok
Successfully installed ecdsa-0.17.0 opentok-3.1.0 python-jose-3.3.0
473 changes: 473 additions & 0 deletions Untitled.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions app.py
@@ -0,0 +1,6 @@
import uvicorn

if __name__ == "__main__":
    # Launch the FastAPI app defined in fast.py; reload=True restarts the server on code changes.
    uvicorn.run("fast:app", host="0.0.0.0", port=8000, reload=True)


176 changes: 176 additions & 0 deletions fast.py
@@ -0,0 +1,176 @@
# Copyright (c) 2019-present, HuggingFace Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import random
from argparse import ArgumentParser
from itertools import chain
from pprint import pformat
import warnings

import torch
import torch.nn.functional as F

from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2LMHeadModel, GPT2Tokenizer
from train import SPECIAL_TOKENS, build_input_from_segments, add_special_tokens_
from utils import get_dataset, download_pretrained_model

from fastapi import FastAPI

app = FastAPI()


@app.get("/")
def hello():
    return {"message": "Hello TutLinks.com"}

def top_filtering(logits, top_k=0, top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k: <=0: no filtering, >0: keep only top k tokens with highest probability.
            top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset
                whose total probability mass is greater than or equal to the threshold top_p.
                In practice, we select the highest-probability tokens whose cumulative probability mass exceeds
                the threshold top_p.
            threshold: a minimal threshold to keep logits
    """
    assert logits.dim() == 1  # Only works for batch size 1 for now - batching would obfuscate the code a bit
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the mask to the right so the first token above the threshold is also kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map back to the unsorted indices and set the filtered logits to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits
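As an aside, here is a minimal sketch (not part of this PR) of what the nucleus branch does to a toy distribution; the logit values are invented for illustration:

import torch
import torch.nn.functional as F

# Toy logits over a 5-token vocabulary; softmax gives roughly [0.61, 0.22, 0.08, 0.05, 0.03].
logits = torch.tensor([3.0, 2.0, 1.0, 0.5, 0.1])

# With top_p=0.9 the three most likely tokens form the nucleus; the tail is set
# to -inf, so it gets exactly zero probability after the softmax. clone() is
# needed because top_filtering mutates its input in place.
filtered = top_filtering(logits.clone(), top_k=0, top_p=0.9)
print(F.softmax(filtered, dim=-1))  # last two entries are exactly 0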


def sample_sequence(personality, history, tokenizer, model, args, current_output=None):
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []

    for i in range(args.max_length):
        instance = build_input_from_segments(personality, history, current_output, tokenizer, with_eos=False)

        input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0)

        logits = model(input_ids, token_type_ids=token_type_ids)
        if isinstance(logits, tuple):  # for gpt2 and maybe others
            logits = logits[0]
        logits = logits[0, -1, :] / args.temperature
        logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p)
        probs = F.softmax(logits, dim=-1)

        prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1)
        if i < args.min_length and prev.item() in special_tokens_ids:
            while prev.item() in special_tokens_ids:
                if probs.max().item() == 1:
                    warnings.warn("Warning: model generating special token with probability 1.")
                    break  # avoid infinitely looping over special token
                prev = torch.multinomial(probs, num_samples=1)

        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())

    return current_output
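The `prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1)` line above is the greedy-versus-sampling switch; a tiny standalone illustration with invented values:

import torch

probs = torch.tensor([0.6, 0.3, 0.1])
greedy = torch.topk(probs, 1)[1]       # always index 0, the argmax
sampled = torch.multinomial(probs, 1)  # index 0 about 60% of the time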


def run():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--model", type=str, default="openai-gpt", help="Model type (openai-gpt or gpt2)", choices=['openai-gpt', 'gpt2'])  # anything besides gpt2 will load openai-gpt
    parser.add_argument("--model_checkpoint", type=str, default="", help="Path, url or short name of the model")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous utterances to keep in history")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")

    parser.add_argument("--no_sample", action='store_true', help="Set to use greedy decoding instead of sampling")
    parser.add_argument("--max_length", type=int, default=20, help="Maximum length of the output utterances")
    parser.add_argument("--min_length", type=int, default=1, help="Minimum length of the output utterances")
    parser.add_argument("--seed", type=int, default=0, help="Seed")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling softmax temperature")
    parser.add_argument("--top_k", type=int, default=0, help="Filter top-k tokens before sampling (<=0: no filtering)")
    parser.add_argument("--top_p", type=float, default=0.9, help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__file__)
    logger.info(pformat(args))

    if args.model_checkpoint == "":
        if args.model == 'gpt2':
            raise ValueError("Interacting with GPT2 requires passing a finetuned model_checkpoint")
        else:
            args.model_checkpoint = download_pretrained_model()

    if args.seed != 0:
        random.seed(args.seed)
        torch.random.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    logger.info("Get pretrained model and tokenizer")
    tokenizer_class, model_class = (GPT2Tokenizer, GPT2LMHeadModel) if args.model == 'gpt2' else (OpenAIGPTTokenizer, OpenAIGPTLMHeadModel)
    global tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    global model
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    add_special_tokens_(model, tokenizer)

    logger.info("Sample a personality")
    dataset = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)
    personalities = [dialog["personality"] for dataset in dataset.values() for dialog in dataset]
    personality = random.choice(personalities)
    logger.info("Selected personality: %s", tokenizer.decode(chain(*personality)))

    return model, tokenizer, args, personality

model, tokenizer, args, personality = run()

@app.get('/predict/{raw_text}')
def predictions(raw_text: str):
    # history is rebuilt on every request, so the bot keeps no memory between calls
    if not raw_text:
        raw_text = 'hi'  # fall back to a default prompt rather than failing on an empty input
    history = [tokenizer.encode(raw_text)]
    with torch.no_grad():
        out_ids = sample_sequence(personality, history, tokenizer, model, args)
    out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
    return {'text': out_text}
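One possible way to exercise this endpoint once a launcher (app.py or main.py) is running; the host and port follow the uvicorn.run call in app.py, and requests is already available per the install log above:

import requests

# GET /predict/{raw_text} runs one generation pass and returns the reply as JSON.
resp = requests.get("http://localhost:8000/predict/hello")
print(resp.json())  # e.g. {'text': '...model reply...'}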



3 changes: 2 additions & 1 deletion interact.py
@@ -133,6 +133,7 @@ def run():
     dataset = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)
     personalities = [dialog["personality"] for dataset in dataset.values() for dialog in dataset]
     personality = random.choice(personalities)
+    print(personalities)
     logger.info("Selected personality: %s", tokenizer.decode(chain(*personality)))

     history = []
@@ -151,4 +152,4 @@ def run():


 if __name__ == "__main__":
-    run()
+    run()
6 changes: 6 additions & 0 deletions main.py
@@ -0,0 +1,6 @@
import uvicorn

if __name__ == "__main__":
    uvicorn.run("fast:app", host="0.0.0.0", port=8118, reload=True)


8 changes: 5 additions & 3 deletions requirements.txt
@@ -1,6 +1,8 @@
-torch
+torch == 1.7.1
 pytorch-ignite
 transformers==2.5.1
-tensorboardX==1.8
-tensorflow # for tensorboardX
+fastapi ==0.70.0
+uvicorn == 0.15.0
 spacy
+tensorboardX==1.8
+tensorflow # for tensorboardX
50 changes: 50 additions & 0 deletions style.css
@@ -0,0 +1,50 @@

body{
    font: 15px/1.5 Arial, Helvetica, sans-serif;
    padding: 0;
    background-color: #f4f3f3;
}

.container{
    text-align: center;
    width: 100%;
    margin: auto;
    overflow: hidden;
}

header{
    background: #03a9f4;
    border-bottom: #740a39 3px solid;
    height: 120px;
    width: 100%;
    padding-top: 30px;
}

.main-header{
    text-align: center;
    background-color: #03a9f4;
    height: 100px;
    width: 100%;
    margin: 0;
}

.brandname{
    text-align: center;
    font-size: 30px;
    color: #161515;
    margin: 10px;
}

header h2{
    text-align: center;
    color: #fff;
}

.results{
    border-radius: 15px 50px;
    background: #345fe4;
    padding: 20px;
    width: 200px;
    height: 150px;
}