Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reformat file and add documentation #280

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions download_model.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
"""Downloads GPT-2 Model.

Options:
python download_model.py 117M|124M|345M|774M|1558M
"""
import os
import sys
import requests
from tqdm import tqdm

if len(sys.argv) != 2:
print('You must enter the model name as a parameter, e.g.: download_model.py 124M')
sys.exit(1)
print('You must enter the model name as a parameter, e.g.: download_model.py '
'124M')
sys.exit(1)

model = sys.argv[1]

subdir = os.path.join('models', model)
if not os.path.exists(subdir):
os.makedirs(subdir)
subdir = subdir.replace('\\','/') # needed for Windows

for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:

r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/" + subdir + "/" + filename, stream=True)
os.makedirs(subdir)
subdir = subdir.replace('\\', '/') # needed for Windows

with open(os.path.join(subdir, filename), 'wb') as f:
file_size = int(r.headers["content-length"])
chunk_size = 1000
with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
# 1k for chunk_size, since Ethernet packet size is around 1500 bytes
for chunk in r.iter_content(chunk_size=chunk_size):
f.write(chunk)
pbar.update(chunk_size)
for filename in [
'checkpoint', 'encoder.json', 'hparams.json',
'model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta',
'vocab.bpe'
]:
url = 'https://openaipublic.blob.core.windows.net/gpt-2/{}/{}'.format(
subdir, filename)
r = requests.get(url, stream=True)
with open(os.path.join(subdir, filename), 'wb') as f:
file_size = int(r.headers['content-length'])
chunk_size = 1000
with tqdm(
ncols=100,
desc='Fetching ' + filename,
total=file_size,
unit_scale=True) as pbar:
# 1k for chunk_size, since Ethernet packet size is around 1500 bytes.
for chunk in r.iter_content(chunk_size=chunk_size):
f.write(chunk)
pbar.update(chunk_size)