end_class support for Autoregressive #125

Open
urchade opened this issue Feb 4, 2022 · 1 comment

@urchade
Contributor

urchade commented Feb 4, 2022

end_class is not used for the Autoregressive module:

@CarlossShi

Hi, I also need an autoregressive model that supports end_class. Here's my approach: CarlossShi@b5a56e8. I use the variable active to record whether each sequence has already output end_class. Once no sequence is still active, the for loop breaks early to save time. I am not very familiar with NLP, so I am not sure whether this is common practice. In addition, some problems remain to be solved:

  • It may be necessary to add an output indicating the effective length of each sentence.
  • The _beam_search method does not work as expected (see example below).
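A minimal, framework-free sketch of the idea described above (an `active` flag per sequence, plus an early break once no sequence is alive) might look like the following. It is not the code from the linked commit and does not use the torch_struct API: `greedy_with_end_class`, `make_toy_step`, and the step-function interface are hypothetical stand-ins for the `RNN_AR` model, and finished sequences are assumed to be padded with `end_class`. It also records each sequence's effective length, which is one possible answer to the first point in the list.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical interface standing in for RNN_AR: given the previous token of each
# sequence in the batch, return one row of next-token scores per sequence.
StepFn = Callable[[List[int]], List[List[float]]]


def greedy_with_end_class(
    step: StepFn, batch: int, n_length: int, start_class: int, end_class: int
) -> Tuple[List[List[int]], List[int]]:
    """Greedy decoding that stops extending a sequence once it emits end_class.

    Returns the decoded tokens (finished sequences are padded with end_class)
    and the effective length of each sequence, including its end_class token.
    """
    prev = [start_class] * batch
    paths: List[List[int]] = [[] for _ in range(batch)]
    active = [True] * batch       # True until the sequence emits end_class
    lengths = [n_length] * batch  # sequences that never end keep length n_length
    for i in range(n_length):
        if not any(active):       # no sequence is alive: stop early
            break
        scores = step(prev)
        for b in range(batch):
            if active[b]:
                row = scores[b]
                token = max(range(len(row)), key=row.__getitem__)  # greedy argmax
            else:
                token = end_class  # pad sequences that have already finished
            paths[b].append(token)
            if active[b] and token == end_class:
                active[b] = False
                lengths[b] = i + 1
        prev = [path[-1] for path in paths]
    return paths, lengths


def make_toy_step(end_at: Sequence[int], n_classes: int, end_class: int) -> StepFn:
    """Hypothetical toy model: sequence b prefers class 2 before step end_at[b], then end_class."""
    step_index = [0]

    def step(prev: List[int]) -> List[List[float]]:
        scores = []
        for b in range(len(prev)):
            row = [0.0] * n_classes
            row[end_class if step_index[0] >= end_at[b] else 2] = 1.0
            scores.append(row)
        step_index[0] += 1
        return scores

    return step


if __name__ == "__main__":
    paths, lengths = greedy_with_end_class(
        make_toy_step(end_at=[1, 3], n_classes=4, end_class=1),
        batch=2, n_length=10, start_class=0, end_class=1,
    )
    print(paths)    # [[2, 1, 1, 1], [2, 2, 2, 1]]
    print(lengths)  # [2, 4]
```

Because the loop exits as soon as every flag is cleared, the returned paths are only as long as the slowest sequence needs (4 of the 10 allowed steps in the toy run), so the truncation happens during decoding rather than afterwards.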

Following the official documentation (Autoregressive / Beam Search), I wrote some examples:

```python
import torch
import matplotlib.pyplot as plt
import sys
sys.path.insert(1, 'pytorch-struct')  # import torch_struct from a local pytorch-struct checkout
import torch_struct

# batch size, RNN layers, hidden size, number of classes, sequence length,
# K: sample shape (also used as the beam size for beam_topk)
batch, layer, H, C, N, K = 3, 1, 5, 4, 10, 2
init = (torch.rand(batch, layer, H),
        torch.rand(batch, layer, H))  # RNN_AR.forward only uses the first tensor


def t(a):
    """Swap the first two dimensions (batch <-> layer) of every tensor in a."""
    return [x.transpose(0, 1) for x in a]


def show_ar(chain):
    """Plot a one-hot path of shape (N, C) as an image: rows are classes, columns are steps."""
    plt.imshow(chain.detach().transpose(0, 1))


class RNN_AR(torch.nn.Module):
    def __init__(self, sparse=True):
        super().__init__()
        self.sparse = sparse
        self.rnn = torch.nn.RNN(H, H, batch_first=True)
        self.proj = torch.nn.Linear(H, C)
        if sparse:
            self.embed = torch.nn.Embedding(C, H)
        else:
            self.embed = torch.nn.Linear(C, H)

    def forward(self, inputs, state):
        """
        @param inputs: {Tensor: (batch, 1)}
        @param state:  e.g. ({Tensor: (batch, layer, H)}, {Tensor: (batch, layer, H)})
        @return: {Tensor: (batch, 1, C)}, [{Tensor: (batch, layer, H)}]
        """
        if not self.sparse and inputs.dim() == 2:
            inputs = torch.nn.functional.one_hot(inputs, C).float()
        inputs = self.embed(inputs)  # {Tensor: (batch, 1, H)}
        # torch.nn.RNN expects the hidden state as (layer, batch, H) even with batch_first=True
        out, state = self.rnn(inputs, t(state)[0])  # out: {Tensor: (batch, 1, H)}, state: {Tensor: (layer, batch, H)}
        out = self.proj(out)  # {Tensor: (batch, 1, C)}
        return out, t((state,))  # t((state,))[0]: {Tensor: (batch, layer, H)}


dist = torch_struct.Autoregressive(RNN_AR(), init, C, N, end_class=1)

path, scores, logits = dist.greedy_max()  # path, logits: {Tensor: (batch, N, C)}, scores: {Tensor: (batch,)}
for b in range(batch):
    plt.subplot(1, batch, b + 1)
    plt.axis('off')
    show_ar(path[b])
plt.suptitle('dist.greedy_max()')
plt.show()

out = dist.sample(torch.Size([K]))  # {Tensor: (K, batch, N, C)}
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.sample(torch.Size([K]))')
plt.show()

out = dist.beam_topk(K)  # {Tensor: (K, batch, N, C)}, first output of _beam_search
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.beam_topk(K)')
plt.show()
```
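Rather than judging from the plots whether the columns after end_class are cut off, the one-hot outputs can be reduced to per-sequence effective lengths. A minimal sketch, assuming each batch of paths has been converted to nested Python lists (for example with `path.tolist()`, giving a batch × N × C structure); `effective_lengths` and `truncate_batch` are hypothetical helpers, not part of torch_struct.

```python
from typing import List, Sequence

# One path: N steps x C classes of one-hot scores (hypothetical alias, e.g. path[b].tolist()).
OneHotPath = Sequence[Sequence[float]]


def effective_lengths(paths: Sequence[OneHotPath], end_class: int) -> List[int]:
    """Steps up to and including the first end_class for each path (N if it never ends)."""
    lengths = []
    for seq in paths:
        tokens = [max(range(len(row)), key=row.__getitem__) for row in seq]  # argmax per step
        lengths.append(tokens.index(end_class) + 1 if end_class in tokens else len(tokens))
    return lengths


def truncate_batch(paths: Sequence[OneHotPath], end_class: int) -> List[List[List[float]]]:
    """Drop the trailing steps once every path in the batch has emitted end_class."""
    keep = max(effective_lengths(paths, end_class), default=0)
    return [[list(row) for row in seq[:keep]] for seq in paths]
```

For the sample and beam outputs of shape (K, batch, N, C), this would be applied once per k, e.g. `effective_lengths(out[k].tolist(), end_class=1)`.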

The output images are as follows.
[Figures: dist.greedy_max() | dist.sample(torch.Size([K])) | dist.beam_topk(K)]

In the example above, end_class is set to 1. I expect that once every sentence has reached end_class (i.e. each array has a yellow square in its second row), the remaining columns are truncated. The sample method seems to work as expected, but _beam_search does not. I'm not very familiar with beam search, so I am stuck here.
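How torch_struct's `_beam_search` handles end_class internally is not covered here, so the following is not a patch for it. It is a minimal, framework-free sketch of one common way to make beam search respect an end token, assuming a single sequence (no batch dimension) and a model that returns next-token log-probabilities: once a hypothesis emits end_class it stays in the beam with its score frozen and is never extended, and the search stops as soon as all K surviving hypotheses have ended. `beam_search_with_end_class` and `toy_log_probs` are hypothetical names.

```python
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical model interface: log-probabilities of the next token given the prefix.
LogProbFn = Callable[[Sequence[int]], List[float]]


def beam_search_with_end_class(
    log_probs: LogProbFn, k: int, n_length: int, start_class: int, end_class: int
) -> List[Tuple[List[int], float]]:
    """Beam search for one sequence; hypotheses that emit end_class are frozen."""
    # Each beam entry: (tokens emitted so far, total log-probability, finished?)
    beams: List[Tuple[List[int], float, bool]] = [([], 0.0, False)]
    for _ in range(n_length):
        if all(done for _, _, done in beams):
            break  # every surviving hypothesis has ended: no more steps needed
        candidates: List[Tuple[List[int], float, bool]] = []
        for tokens, score, done in beams:
            if done:
                # Keep finished hypotheses as they are: not extended, score unchanged.
                candidates.append((tokens, score, True))
                continue
            for c, lp in enumerate(log_probs([start_class] + tokens)):
                candidates.append((tokens + [c], score + lp, c == end_class))
        candidates.sort(key=lambda cand: cand[1], reverse=True)  # best score first
        beams = candidates[:k]
    return [(tokens, score) for tokens, score, _ in beams]


def toy_log_probs(prefix: Sequence[int]) -> List[float]:
    """Hypothetical 3-class model (0 = start, 1 = end, 2 = word)."""
    probs = [0.1, 0.3, 0.6] if len(prefix) == 1 else [0.1, 0.7, 0.2]
    return [math.log(p) for p in probs]


if __name__ == "__main__":
    results = beam_search_with_end_class(toy_log_probs, k=2, n_length=10, start_class=0, end_class=1)
    for tokens, score in results:
        print(tokens, round(math.exp(score), 2))
    # [2, 1] 0.42
    # [1] 0.3
```

With K = 2, both surviving hypotheses end after at most two steps, so the search stops long before n_length. Without length normalization this scoring tends to favour short hypotheses, since every extra step adds a negative log-probability.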

I hope this helps; any further support would be greatly appreciated.
