end_class support for Autoregressive #125
Hi, I'm also in need of an autoregressive model with `end_class` support.

Following the official documentation (Autoregressive / Beam Search), I made some examples:

```python
import torch
import matplotlib.pyplot as plt
import sys

sys.path.insert(1, 'pytorch-struct')
import torch_struct

batch, layer, H, C, N, K = 3, 1, 5, 4, 10, 2  # K: sample shape
init = (torch.rand(batch, layer, H),
        torch.rand(batch, layer, H))


def t(a):
    # Swap the batch and layer dimensions of every tensor in `a`.
    return [x.transpose(0, 1) for x in a]


def show_ar(chain):
    plt.imshow(chain.detach().transpose(0, 1))


class RNN_AR(torch.nn.Module):
    def __init__(self, sparse=True):
        super().__init__()
        self.sparse = sparse
        self.rnn = torch.nn.RNN(H, H, batch_first=True)
        self.proj = torch.nn.Linear(H, C)
        if sparse:
            self.embed = torch.nn.Embedding(C, H)
        else:
            self.embed = torch.nn.Linear(C, H)

    def forward(self, inputs, state):
        """
        @param inputs: {Tensor: (batch, 1)}
        @param state: e.g. ({Tensor: (batch, layer, H)}, {Tensor: (batch, layer, H)})
        @return: {Tensor: (batch, layer, C)}, [{Tensor: (batch, layer, H)}]
        """
        if not self.sparse and inputs.dim() == 2:
            inputs = torch.nn.functional.one_hot(inputs, C).float()
        inputs = self.embed(inputs)  # {Tensor: (batch, 1, H)}
        # out: {Tensor: (batch, layer, H)}; t(state)[0] & state: {Tensor: (layer, batch, H)}
        out, state = self.rnn(inputs, t(state)[0])
        out = self.proj(out)  # {Tensor: (batch, layer, C)}
        return out, t((state,))  # t((state,))[0]: {Tensor: (batch, layer, H)}


dist = torch_struct.Autoregressive(RNN_AR(), init, C, N, end_class=1)

# path, logits: {Tensor: (batch, N, C)}; scores: {Tensor: (batch,)}
path, scores, logits = dist.greedy_max()
for b in range(batch):
    plt.subplot(1, batch, b + 1)
    plt.axis('off')
    show_ar(path[b])
plt.suptitle('dist.greedy_max()')
plt.show()

out = dist.sample(torch.Size([K]))  # {Tensor: (K, batch, N, C)}
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.sample(torch.Size([K]))')
plt.show()

out = dist.beam_topk(K)  # {Tensor: (K, batch, N, C)}, first output of _beam_search
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.beam_topk(K)')
plt.show()
```

The output images are as follows:

*(Figures: `dist.greedy_max()`, `dist.sample(torch.Size([K]))`, and `dist.beam_topk(K)` paths, one panel per batch element; images omitted.)*
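As a quick diagnostic (my own addition, not part of the original example), one can check programmatically whether the greedy paths ever treat class 1 as terminal. Reading the one-hot paths back as class indices:

```python
# Convert one-hot greedy paths (batch, N, C) back to class indices (batch, N).
tokens = path.argmax(-1)
for b in range(batch):
    hits = (tokens[b] == 1).nonzero(as_tuple=True)[0]
    if len(hits) > 0:
        first = hits[0].item()
        # If end_class were respected, everything after `first` would stay 1.
        absorbing = bool((tokens[b, first:] == 1).all())
        print(f"batch {b}: first end_class at step {first}, absorbing afterwards: {absorbing}")
    else:
        print(f"batch {b}: end_class never emitted")
```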
In the example above, `end_class` is not used by the Autoregressive module:
pytorch-struct/torch_struct/autoregressive.py (line 49 in 7146de5)
Hope that helps, and any further support would be greatly appreciated.