This repository has been archived by the owner on Nov 17, 2022. It is now read-only.

Commit

initial commit
Change-Id: I0d0b8682af66032b072a148fc204496c043a410d
lianrongzhong committed May 26, 2020
1 parent 3bbc94d commit 6581eba
Showing 24 changed files with 78,893 additions and 12,050 deletions.
11 changes: 6 additions & 5 deletions README.md
@@ -1,5 +1,3 @@


## Quick Start

### Requirements
@@ -12,10 +10,13 @@ Put train/valid/test data files in the `data/` folder with the same prefix:
* `$prefix.train`
* `$prefix.valid`
* `$prefix.test`
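For example, with prefix `sample` the layout would be (illustrative):

    data/sample.train
    data/sample.valid
    data/sample.test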

Add embeddings if you need a warm start.

We provide only a small dataset for demonstration.

### Training
* `python run_seq2seq.py` : Seq2Seq
* `python run_knowledge.py` : PostKS with HGFU

### Testing
* `python run_seq2seq.py --test --ckpt $model_state_file`
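For example, pointing `--ckpt` at a trained state file (the path below is hypothetical):

    # hypothetical checkpoint path; use the state file produced by training
    python run_seq2seq.py --test --ckpt models/best.model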


1,000 changes: 0 additions & 1,000 deletions data/toy/dial.test

10,000 changes: 0 additions & 10,000 deletions data/toy/dial.train

1,000 changes: 0 additions & 1,000 deletions data/toy/dial.valid

100 changes: 100 additions & 0 deletions data/wizard/sample.test

500 changes: 500 additions & 0 deletions data/wizard/sample.train

100 changes: 100 additions & 0 deletions data/wizard/sample.valid

Binary file added data/wizard/sample_20000.data.pt

Binary file added data/wizard/sample_20000.vocab.pt
197 changes: 197 additions & 0 deletions dialnlp/models/gold_seq2seq.py
@@ -0,0 +1,197 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Copyright (c) 2018 Baidu.com, Inc. All Rights Reserved
#
########################################################################

"""
File: gold_seq2seq.py
Author: lianrongzhong([email protected])
Date: 2018/11/19 17:43:46
"""
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

from dialnlp.models.base_model import BaseModel
from dialnlp.modules.embedder import Embedder
from dialnlp.modules.encoders.rnn_encoder import RNNEncoder
from dialnlp.modules.decoders.attention_decoder import RNNDecoder
from dialnlp.utils.criterions import NLLLoss
from dialnlp.utils.misc import Pack
from dialnlp.utils.metrics import accuracy

class GoldSeq2Seq(BaseModel):
    """Attentional Seq2Seq model that additionally encodes a gold
    knowledge (cue) sequence and conditions the decoder on it."""

    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 embed_size,
                 hidden_size,
                 padding_idx=None,
                 num_layers=1,
                 bidirectional=True,
                 attn_mode="mlp",
                 attn_hidden_size=None,
                 with_bridge=False,
                 tie_embedding=False,
                 dropout=0.0,
                 use_gpu=False,
                 extend_attn=False,
                 concat=False):
        super(GoldSeq2Seq, self).__init__()

        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.padding_idx = padding_idx
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.attn_mode = attn_mode
        self.attn_hidden_size = attn_hidden_size
        self.with_bridge = with_bridge
        self.tie_embedding = tie_embedding
        self.dropout = dropout
        self.use_gpu = use_gpu

        enc_embedder = Embedder(num_embeddings=self.src_vocab_size,
                                embedding_dim=self.embed_size,
                                padding_idx=self.padding_idx)

        self.encoder = RNNEncoder(input_size=self.embed_size,
                                  hidden_size=self.hidden_size,
                                  embedder=enc_embedder,
                                  num_layers=self.num_layers,
                                  bidirectional=self.bidirectional,
                                  dropout=self.dropout)

        if self.with_bridge:
            # Non-linear transform between encoder and decoder hidden states
            self.bridge = nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.Tanh()
            )

        if self.tie_embedding:
            assert self.src_vocab_size == self.tgt_vocab_size
            dec_embedder = enc_embedder
            knowledge_embedder = enc_embedder
        else:
            dec_embedder = Embedder(num_embeddings=self.tgt_vocab_size,
                                    embedding_dim=self.embed_size,
                                    padding_idx=self.padding_idx)
            knowledge_embedder = Embedder(num_embeddings=self.tgt_vocab_size,
                                          embedding_dim=self.embed_size,
                                          padding_idx=self.padding_idx)

        # Encode the knowledge (cue) sequence with its own embedder
        self.knowledge_encoder = RNNEncoder(input_size=self.embed_size,
                                            hidden_size=self.hidden_size,
                                            embedder=knowledge_embedder,
                                            num_layers=self.num_layers,
                                            bidirectional=self.bidirectional,
                                            dropout=self.dropout)

        self.decoder = RNNDecoder(input_size=self.embed_size,
                                  hidden_size=self.hidden_size,
                                  output_size=self.tgt_vocab_size,
                                  embedder=dec_embedder,
                                  num_layers=self.num_layers,
                                  attn_mode=self.attn_mode,
                                  memory_size=self.hidden_size,
                                  feature_size=None,
                                  dropout=self.dropout,
                                  extend_attn=extend_attn,
                                  concat=concat)

        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        # Zero out the loss weight for padding tokens
        if self.padding_idx is not None:
            self.weight = torch.ones(self.tgt_vocab_size)
            self.weight[self.padding_idx] = 0
        else:
            self.weight = None
        self.nll_loss = NLLLoss(weight=self.weight,
                                ignore_index=self.padding_idx,
                                reduction='mean')

        if self.use_gpu:
            self.cuda()

    def encode(self, inputs, hidden=None, is_training=False):
        outputs = Pack()
        # Strip BOS/EOS from the source and shorten lengths accordingly
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)

        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        # Encode the gold knowledge (cue) sequence
        cue_inputs = inputs.cue[0].squeeze(1)[:, 1:-1], inputs.cue[1].squeeze(1) - 2
        cue_outputs, cue_hidden = self.knowledge_encoder(cue_inputs, hidden)
        knowledge = cue_hidden[-1].unsqueeze(1)

        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            attn_memory=enc_outputs if self.attn_mode else None,
            memory_lengths=lengths if self.attn_mode else None,
            knowledge=cue_outputs if self.attn_mode else knowledge,
            knowledge_lengths=cue_inputs[1] if self.attn_mode else None)
        return outputs, dec_init_state

    def decode(self, input, state):
        log_prob, state, output = self.decoder.decode(input, state)
        return log_prob, state, output

    def forward(self, enc_inputs, dec_inputs, hidden=None, is_training=False):
        outputs, dec_init_state = self.encode(
            enc_inputs, hidden, is_training=is_training)
        log_probs, _ = self.decoder(dec_inputs, dec_init_state)
        outputs.add(logits=log_probs)
        return outputs

    def collect_metrics(self, outputs, target, epoch=-1):
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        logits = outputs.logits
        nll_loss = self.nll_loss(logits, target)
        # Number of non-padding tokens in the target
        num_words = target.ne(self.padding_idx).sum().item()
        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(nll=(nll_loss, num_words), acc=acc)

        loss += nll_loss

        metrics.add(loss=loss)
        return metrics

    def iterate(self, inputs, optimizer=None, grad_clip=None, is_training=False, epoch=-1):
        enc_inputs = inputs
        # Teacher forcing: feed tgt[:-1] to the decoder and predict tgt[1:]
        dec_inputs = inputs.tgt[0][:, :-1], inputs.tgt[1] - 1
        target = inputs.tgt[0][:, 1:]

        outputs = self.forward(enc_inputs, dec_inputs, is_training=is_training)
        metrics = self.collect_metrics(outputs, target)

        loss = metrics.loss
        if torch.isnan(loss):
            raise ValueError("nan loss encountered")

        if is_training:
            assert optimizer is not None
            optimizer.zero_grad()
            loss.backward()
            if grad_clip is not None and grad_clip > 0:
                clip_grad_norm_(parameters=self.parameters(),
                                max_norm=grad_clip)
            optimizer.step()
        return metrics
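A minimal usage sketch of GoldSeq2Seq (assumptions: `Pack` fields hold `(token_ids, lengths)` pairs, `cue` carries an extra sentence dimension, and all sizes below are hypothetical; the repository's real entry points are `run_seq2seq.py` and `run_knowledge.py`):

    import torch
    from torch import optim

    from dialnlp.models.gold_seq2seq import GoldSeq2Seq
    from dialnlp.utils.misc import Pack

    # Hypothetical sizes, for illustration only
    VOCAB, EMBED, HIDDEN, PAD = 30000, 300, 800, 0

    model = GoldSeq2Seq(src_vocab_size=VOCAB,
                        tgt_vocab_size=VOCAB,
                        embed_size=EMBED,
                        hidden_size=HIDDEN,
                        padding_idx=PAD,
                        tie_embedding=True,  # one shared embedder for src/tgt/cue
                        use_gpu=torch.cuda.is_available())
    optimizer = optim.Adam(model.parameters(), lr=2e-4)

    # One random batch: each field is (token_ids, lengths); cue has an extra
    # sentence dimension that encode() squeezes out.
    batch = Pack(
        src=(torch.randint(1, VOCAB, (2, 10)), torch.full((2,), 10, dtype=torch.long)),
        tgt=(torch.randint(1, VOCAB, (2, 12)), torch.full((2,), 12, dtype=torch.long)),
        cue=(torch.randint(1, VOCAB, (2, 1, 8)), torch.full((2, 1), 8, dtype=torch.long)))

    metrics = model.iterate(batch, optimizer=optimizer,
                            grad_clip=5.0, is_training=True)
    print(metrics.loss.item())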
