This repository has been archived by the owner on Nov 17, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change-Id: I0d0b8682af66032b072a148fc204496c043a410d
- Loading branch information
lianrongzhong
committed
May 26, 2020
1 parent
3bbc94d
commit 6581eba
Showing
24 changed files
with
78,893 additions
and
12,050 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
######################################################################## | ||
# | ||
# Copyright (c) 2018 Baidu.com, Inc. All Rights Reserved | ||
# | ||
######################################################################## | ||
|
||
""" | ||
File: knowledge_seq2eq.py | ||
Author: lianrongzhong([email protected]) | ||
Date: 2018/11/19 17:43:46 | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
from torch.nn.utils import clip_grad_norm_ | ||
import torch.nn.functional as F | ||
import pdb | ||
|
||
from dialnlp.models.base_model import BaseModel | ||
from dialnlp.modules.embedder import Embedder | ||
from dialnlp.modules.encoders.rnn_encoder import RNNEncoder | ||
from dialnlp.modules.decoders.attention_decoder import RNNDecoder | ||
from dialnlp.utils.criterions import NLLLoss | ||
from dialnlp.utils.misc import Pack | ||
from dialnlp.utils.metrics import accuracy | ||
from dialnlp.utils.metrics import attn_accuracy | ||
from dialnlp.utils.metrics import perplexity | ||
from dialnlp.modules.attention import Attention | ||
|
||
class GoldSeq2Seq(BaseModel): | ||
def __init__(self, | ||
src_vocab_size, | ||
tgt_vocab_size, | ||
embed_size, | ||
hidden_size, | ||
padding_idx=None, | ||
num_layers=1, | ||
bidirectional=True, | ||
attn_mode="mlp", | ||
attn_hidden_size=None, | ||
with_bridge=False, | ||
tie_embedding=False, | ||
dropout=0.0, | ||
use_gpu=False, | ||
extend_attn=False, | ||
concat=False): | ||
super(GoldSeq2Seq, self).__init__() | ||
|
||
self.src_vocab_size = src_vocab_size | ||
self.tgt_vocab_size = tgt_vocab_size | ||
self.embed_size = embed_size | ||
self.hidden_size = hidden_size | ||
self.padding_idx = padding_idx | ||
self.num_layers = num_layers | ||
self.bidirectional = bidirectional | ||
self.attn_mode = attn_mode | ||
self.attn_hidden_size = attn_hidden_size | ||
self.with_bridge = with_bridge | ||
self.tie_embedding = tie_embedding | ||
self.dropout = dropout | ||
self.use_gpu = use_gpu | ||
|
||
enc_embedder = Embedder(num_embeddings=self.src_vocab_size, | ||
embedding_dim=self.embed_size, | ||
padding_idx=self.padding_idx) | ||
|
||
self.encoder = RNNEncoder(input_size=self.embed_size, | ||
hidden_size=self.hidden_size, | ||
embedder=enc_embedder, | ||
num_layers=self.num_layers, | ||
bidirectional=self.bidirectional, | ||
dropout=self.dropout) | ||
|
||
if self.with_bridge: | ||
self.bridge = nn.Sequential( | ||
nn.Linear(self.hidden_size, self.hidden_size), | ||
nn.Tanh() | ||
) | ||
|
||
if self.tie_embedding: | ||
assert self.src_vocab_size == self.tgt_vocab_size | ||
dec_embedder = enc_embedder | ||
knowledge_embedder = enc_embedder | ||
else: | ||
dec_embedder = Embedder(num_embeddings=self.tgt_vocab_size, | ||
embedding_dim=self.embed_size, | ||
padding_idx=self.padding_idx) | ||
knowledge_embedder = Embedder(num_embeddings=self.tgt_vocab_size, | ||
embedding_dim=self.embed_size, | ||
padding_idx=self.padding_idx) | ||
|
||
self.knowledge_encoder = RNNEncoder(input_size=self.embed_size, | ||
hidden_size=self.hidden_size, | ||
embedder=enc_embedder, | ||
num_layers=self.num_layers, | ||
bidirectional=self.bidirectional, | ||
dropout=self.dropout) | ||
|
||
self.decoder = RNNDecoder(input_size=self.embed_size, | ||
hidden_size=self.hidden_size, | ||
output_size=self.tgt_vocab_size, | ||
embedder=dec_embedder, | ||
num_layers=self.num_layers, | ||
attn_mode=self.attn_mode, | ||
memory_size=self.hidden_size, | ||
feature_size=None, | ||
dropout=self.dropout, | ||
extend_attn=extend_attn, | ||
concat=concat) | ||
|
||
self.log_softmax = nn.LogSoftmax(dim=-1) | ||
self.softmax = nn.Softmax(dim=-1) | ||
self.sigmoid = nn.Sigmoid() | ||
|
||
if self.padding_idx is not None: | ||
self.weight = torch.ones(self.tgt_vocab_size) | ||
self.weight[self.padding_idx] = 0 | ||
else: | ||
self.weight = None | ||
self.nll_loss = NLLLoss(weight=self.weight, | ||
ignore_index=self.padding_idx, | ||
reduction='mean') | ||
|
||
if self.use_gpu: | ||
self.cuda() | ||
|
||
def encode(self, inputs, hidden=None, is_training=False): | ||
outputs = Pack() | ||
enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1]-2 | ||
enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden) | ||
|
||
if self.with_bridge: | ||
enc_hidden = self.bridge(enc_hidden) | ||
|
||
# knowledge | ||
cue_inputs = inputs.cue[0].squeeze(1)[:, 1:-1], inputs.cue[1].squeeze(1)-2 | ||
cue_outputs, cue_hidden = self.knowledge_encoder(cue_inputs, hidden) | ||
knowledge = cue_hidden[-1].unsqueeze(1) | ||
|
||
dec_init_state = self.decoder.initialize_state( | ||
hidden=enc_hidden, | ||
attn_memory=enc_outputs if self.attn_mode else None, | ||
memory_lengths=lengths if self.attn_mode else None, | ||
knowledge=cue_outputs if self.attn_mode else knowledge, | ||
knowledge_lengths=cue_inputs[1] if self.attn_mode else None) | ||
#knowledge=knowledge) | ||
return outputs, dec_init_state | ||
|
||
def decode(self, input, state): | ||
log_prob, state, output = self.decoder.decode(input, state) | ||
return log_prob, state, output | ||
|
||
def forward(self, enc_inputs, dec_inputs, hidden=None, is_training=False): | ||
outputs, dec_init_state = self.encode( | ||
enc_inputs, hidden, is_training=is_training) | ||
log_probs, _ = self.decoder(dec_inputs, dec_init_state) | ||
outputs.add(logits=log_probs) | ||
return outputs | ||
|
||
def collect_metrics(self, outputs, target, epoch=-1): | ||
num_samples = target.size(0) | ||
metrics = Pack(num_samples=num_samples) | ||
loss = 0 | ||
|
||
logits = outputs.logits | ||
nll_loss = self.nll_loss(logits, target) | ||
num_words = target.ne(self.padding_idx).sum().item() | ||
acc = accuracy(logits, target, padding_idx=self.padding_idx) | ||
metrics.add(nll=(nll_loss, num_words), acc=acc) | ||
|
||
loss += nll_loss | ||
|
||
metrics.add(loss=loss) | ||
return metrics | ||
|
||
def iterate(self, inputs, optimizer=None, grad_clip=None, is_training=False, epoch=-1): | ||
enc_inputs = inputs | ||
dec_inputs = inputs.tgt[0][:, :-1], inputs.tgt[1]-1 | ||
target = inputs.tgt[0][:, 1:] | ||
|
||
outputs = self.forward(enc_inputs, dec_inputs, is_training=is_training) | ||
metrics = self.collect_metrics(outputs, target) | ||
|
||
loss = metrics.loss | ||
if torch.isnan(loss): | ||
raise ValueError("nan loss encountered") | ||
|
||
if is_training: | ||
assert optimizer is not None | ||
optimizer.zero_grad() | ||
loss.backward() | ||
if grad_clip is not None and grad_clip > 0: | ||
clip_grad_norm_(parameters=self.parameters(), | ||
max_norm=grad_clip) | ||
optimizer.step() | ||
return metrics |
Oops, something went wrong.