Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Housekeeping - upgrade to Torch 1.3.1 #28

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,25 @@ Please consider citing the [paper](https://arxiv.org/abs/1708.00524) of DeepMoji

## Installation

We assume that you're using [Python 2.7-3.5](https://www.python.org/downloads/) with [pip](https://pip.pypa.io/en/stable/installing/) installed.

First you need to install [pyTorch (version 0.2+)](http://pytorch.org/), currently by:
```bash
conda install pytorch -c pytorch
```
At the present stage the model can't make efficient use of CUDA. See details in the [Hugging Face blog post](https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983).

When pyTorch is installed, run the following in the root directory to install the remaining dependencies:
Assuming you have [Conda](https://conda.io) installed, run:

```bash
conda create -n torchMoji -f environment.yml
conda activate torchMoji
pip install -e .
```

This will install the following dependencies:

* [PyTorch](https://pytorch.org)
* [scikit-learn](https://github.com/scikit-learn/scikit-learn)
* [text-unidecode](https://github.com/kmike/text-unidecode)
* [emoji](https://github.com/carpedm20/emoji)

If you do not want to use Conda, please install `torch==1.3.1` from PIP and then run `pip install -e .` from the root directory (don't forget to set up a virtual environment).

At the present stage the model can't make efficient use of CUDA. See details in the [Hugging Face blog post](https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983).

Then, run the download script to downloads the pretrained torchMoji weights (~85MB) from [here](https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0) and put them in the model/ directory:

```bash
Expand Down
41 changes: 41 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: torchMoji
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1
- blas=1.0
- ca-certificates=2019.11.27
- certifi=2019.11.28
- cffi=1.13.2
- cudatoolkit=10.1.243
- intel-openmp=2019.4
- libedit=3.1.20181209
- libffi=3.2.1
- libgcc-ng=9.1.0
- libgfortran-ng=7.3.0
- libstdcxx-ng=9.1.0
- mkl=2018.0.3
- ncurses=6.1
- ninja=1.9.0
- nose=1.3.7
- numpy=1.13.1
- openssl=1.1.1d
- pip=19.3.1
- pycparser=2.19
- python=3.6.9
- pytorch=1.3.1
- readline=7.0
- scikit-learn=0.19.0
- scipy=0.19.1
- setuptools=42.0.2
- sqlite=3.30.1
- text-unidecode=1.0
- tk=8.6.8
- wheel=0.33.6
- xz=5.2.4
- zlib=1.2.11
- pip:
- emoji==0.4.5
prefix: /home/cbowdon/miniconda3/envs/torchMoji

10 changes: 5 additions & 5 deletions torchmoji/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import BatchSampler, SequentialSampler
from torch.nn.utils import clip_grad_norm
from torch.nn.utils import clip_grad_norm_

from sklearn.metrics import f1_score

Expand Down Expand Up @@ -521,7 +521,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
torch.save(model.state_dict(), checkpoint_path)

model.eval()
best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy() for xv, yv in val_gen])
print("original val loss", best_loss)

epoch_without_impr = 0
Expand All @@ -535,17 +535,17 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
output = model(X_train)
loss = calc_loss(loss_op, output, y_train)
loss.backward()
clip_grad_norm(model.parameters(), 1)
clip_grad_norm_(model.parameters(), 1)
optim_op.step()

acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy()[0], "train acc", acc)
print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy(), "train acc", acc)

model.eval()
acc = evaluate_using_acc(model, val_gen)
print("val acc", acc)

val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy() for xv, yv in val_gen])
print("val loss", val_loss)
if best_loss is not None and val_loss >= best_loss:
epoch_without_impr += 1
Expand Down
7 changes: 4 additions & 3 deletions torchmoji/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def reset_parameters(self):
def forward(self, input, hx=None):
is_packed = isinstance(input, PackedSequence)
if is_packed:
input, batch_sizes = input
batch_sizes = input.batch_sizes
input = input.data
max_batch_size = batch_sizes[0]
else:
batch_sizes = None
Expand Down Expand Up @@ -337,11 +338,11 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):

ingate = hard_sigmoid(ingate)
forgetgate = hard_sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
cellgate = torch.tanh(cellgate)
outgate = hard_sigmoid(outgate)

cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * F.tanh(cy)
hy = outgate * torch.tanh(cy)

return hy, cy

Expand Down
16 changes: 6 additions & 10 deletions torchmoji/model_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(self, nb_classes, nb_tokens, feature_output=False, output_logits=Fa
self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1)))
else:
self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1),
nn.Softmax() if self.nb_classes > 2 else nn.Sigmoid()))
nn.Softmax(dim=1) if self.nb_classes > 2 else nn.Sigmoid()))
self.init_weights()
# Put model in evaluation mode by default
self.eval()
Expand All @@ -156,15 +156,15 @@ def init_weights(self):
ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
b = (param.data for name, param in self.named_parameters() if 'bias' in name)
nn.init.uniform(self.embed.weight.data, a=-0.5, b=0.5)
nn.init.uniform_(self.embed.weight.data, a=-0.5, b=0.5)
for t in ih:
nn.init.xavier_uniform(t)
nn.init.xavier_uniform_(t)
for t in hh:
nn.init.orthogonal(t)
nn.init.orthogonal_(t)
for t in b:
nn.init.constant(t, 0)
nn.init.constant_(t, 0)
if not self.feature_output:
nn.init.xavier_uniform(self.output_layer[0].weight.data)
nn.init.xavier_uniform_(self.output_layer[0].weight.data)

def forward(self, input_seqs):
""" Forward pass.
Expand All @@ -177,10 +177,8 @@ def forward(self, input_seqs):
"""
# Check if we have Torch.LongTensor inputs or not Torch.Variable (assume Numpy array in this case), take note to return same format
return_numpy = False
return_tensor = False
if isinstance(input_seqs, (torch.LongTensor, torch.cuda.LongTensor)):
input_seqs = Variable(input_seqs)
return_tensor = True
elif not isinstance(input_seqs, Variable):
input_seqs = Variable(torch.from_numpy(input_seqs.astype('int64')).long())
return_numpy = True
Expand Down Expand Up @@ -246,8 +244,6 @@ def forward(self, input_seqs):
outputs = reorered

# Adapt return format if needed
if return_tensor:
outputs = outputs.data
if return_numpy:
outputs = outputs.data.numpy()

Expand Down