Commit

complete
lcylcy committed Nov 3, 2021
1 parent 68b3dd2 commit f89205b
Showing 4 changed files with 53 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -17,6 +17,7 @@ Operators for neural networks
BatchNorm3d,
COCOReader,
CTCLoss,
RNNTLoss,
CoinFlip,
ConstantPad1d,
ConstantPad2d,
8 changes: 6 additions & 2 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -1105,13 +1105,17 @@ class RNNTlossFunctor {
CHECK_EQ_OR_RETURN(labels->dtype()->data_type(), DataType::kInt32) << "labels must be int32";
CHECK_EQ_OR_RETURN(act_lens->dtype()->data_type(), DataType::kInt32) << "act_lens must be int32";
CHECK_EQ_OR_RETURN(label_lens->dtype()->data_type(), DataType::kInt32) << "label_lens must be int32";

CHECK_EQ_OR_RETURN(acts->shape()->NumAxes(), 4) << "the dim of acts must be 4";
CHECK_EQ_OR_RETURN(labels->shape()->NumAxes(), 2) << "the dim of labels must be 2";
CHECK_EQ_OR_RETURN(act_lens->shape()->NumAxes(), 1) << "the dim of act_lens must be 1";
CHECK_EQ_OR_RETURN(label_lens->shape()->NumAxes(), 1) << "the dim of label_lens must be 1";

MutableAttrMap attrs;

JUST(attrs.SetAttr<int32_t>("blank_label", blank_label));
JUST(attrs.SetAttr<int32_t>("num_threads", num_threads));

// std::shared_ptr<one::Tensor> acts_trans = JUST(LogSoftmax(acts,-1));

return OpInterpUtil::Dispatch<Tensor>(*op_, {acts, labels, act_lens, label_lens}, attrs);
}

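For reference, a minimal Python sketch of inputs that satisfy the dtype and rank checks above; the shape names (B, T, U, V) are illustrative, not from this commit:

import oneflow as flow

# Illustrative shapes: B examples, T frames, U labels, V vocabulary entries.
B, T, U, V = 2, 3, 2, 5
acts = flow.rand(B, T, U + 1, V)                          # 4-D float activations
labels = flow.tensor([[1, 2], [2, 2]], dtype=flow.int32)  # 2-D, zero-padded targets
act_lens = flow.tensor([T, T], dtype=flow.int32)          # 1-D, frames per example
label_lens = flow.tensor([U, U], dtype=flow.int32)        # 1-D, labels per example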
13 changes: 0 additions & 13 deletions oneflow/user/ops/rnnt_op.cpp
@@ -50,19 +50,6 @@ REGISTER_USER_OP("RNNTloss")
return Maybe<void>::Ok();
});

// REGISTER_USER_OP("RNNTloss_grad")
// .Input("grads")
// .Input("dy")
// .Output("dx")
// .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
// CHECK_EQ_OR_RETURN(ctx->InputDType("prob", 0), ctx->InputDType("dy", 0));
// *ctx->OutputDType("dx", 0) = ctx->InputDType("grads", 0);
// return Maybe<void>::Ok();
// })
// .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
// ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
// return Maybe<void>::Ok();
// });

REGISTER_USER_OP_GRAD("RNNTloss")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
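Since the commit keeps the REGISTER_USER_OP_GRAD("RNNTloss") wiring (truncated above), the backward path can be smoke-tested from Python. This is a hedged sketch reusing the docstring's example shapes, not code from this commit; it assumes flow.rand accepts requires_grad:

import oneflow as flow

acts = flow.rand(2, 2, 3, 5, requires_grad=True)
labels = flow.tensor([[1, 2], [2, 2]], dtype=flow.int32)
act_lens = flow.tensor([2, 2], dtype=flow.int32)
label_lens = flow.tensor([2, 2], dtype=flow.int32)

loss = flow.nn.RNNTLoss(blank=0, reduction="mean")(acts, labels, act_lens, label_lens)
loss.sum().backward()    # exercises the registered backward op
print(acts.grad.shape)   # expected to match acts: (2, 2, 3, 5)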
49 changes: 46 additions & 3 deletions python/oneflow/nn/modules/loss.py
@@ -690,12 +690,56 @@ def forward(


class RNNTLoss(_Loss):
"""The RNN Tranducer loss.
The documentation is referenced from:
https://github.com/HawkAaron/warp-transducer/tree/master/pytorch_binding
This loss introduces probabilistic sequence transduction system, based entirely
on RNNs, that is in principle able to transform any input sequence into any finite,
discrete output sequence
Args:
blank (int, optional): blank label. Default :math:`0`.
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the output losses will be divided by the bathsize and. Default: ``'mean'``
Shape:
- acts: Tensor of size :math:`(B, T, U+1, V)`
B is the minibatch index
T is the time index
U is the label sequence length (+1 means blank label prepanded)
V indexes over activations for each symbol in the alphabet
- labels: 2 dimensional Tensor containing all the targets of the batch with zero padded.
- act_lens: A 1-D Tensor of ints, the length of each label for each example in the minibatch.
- label_lens: A 1-D Tensor of ints, the length of each label for each example in the minibatch.
Reference:
A. Graves et al.: Sequence Transduction with Recurrent Neural Networks
http://www.cs.toronto.edu/~graves/icml_2012.pdf
For example:
.. code-block:: python
>>> import oneflow as flow
>>> acts = flow.rand(2,2,3,5)
>>> labels = flow.tensor([[1, 2],[2,2]],dtype=flow.int)
>>> acts_length = flow.tensor([2,2],dtype=flow.int)
>>> label_length = flow.tensor([2,2],dtype=flow.int)
>>> rnnt = flow.nn.RNNTLoss(blank=0,reduction="mean")
>>> rnnt(acts,labels,acts_length,label_length)
tensor([5.1914], dtype=oneflow.float32)
"""

def __init__(
self, blank: int = 0, reduction: str = "mean", thread: int=0
self, blank: int = 0, reduction: str = "mean"
) -> None:

super(RNNTLoss, self).__init__(reduction)
self.blank = blank
self.thread = thread
self.thread = 0

def forward(
self,
@@ -706,7 +750,6 @@ def forward(
) -> Tensor:
if not acts.is_cuda:
acts = flow._C.log_softmax(acts, -1)
print("cpu")
loss = flow._C.RNNTloss(acts, labels, act_lens, label_lens, self.blank, self.thread)
if self.reduction in ['sum', 'mean']:
loss = loss.sum().unsqueeze(-1)
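For intuition about what the dispatched RNNTloss op computes, here is a minimal NumPy sketch of the forward (alpha) recursion from Graves 2012 for a single example; rnnt_loss_single is an illustrative helper, not part of OneFlow or this commit:

import numpy as np

def rnnt_loss_single(log_probs, labels, blank=0):
    # log_probs: (T, U+1, V) log-softmax network outputs for one example;
    # labels: length-U int sequence. Returns the negative log-likelihood.
    T, U_plus_1, _ = log_probs.shape
    U = U_plus_1 - 1
    alpha = np.full((T, U + 1), -np.inf)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Arrive by emitting blank from (t-1, u) ...
            stay = -np.inf if t == 0 else alpha[t - 1, u] + log_probs[t - 1, u, blank]
            # ... or by emitting label u from (t, u-1).
            emit = -np.inf if u == 0 else alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]]
            alpha[t, u] = np.logaddexp(stay, emit)
    # Terminate by emitting a final blank from (T-1, U).
    return -(alpha[T - 1, U] + log_probs[T - 1, U, blank])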
