diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 1281e0d359e..96c8a195e97 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -17,6 +17,7 @@ Operators for neural networks
     BatchNorm3d,
     COCOReader,
     CTCLoss,
+    RNNTLoss,
     CoinFlip,
     ConstantPad1d,
     ConstantPad2d,
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 86cfc08fcf1..0b36ea1c0dc 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -1105,13 +1105,17 @@ class RNNTlossFunctor {
     CHECK_EQ_OR_RETURN(labels->dtype()->data_type(), DataType::kInt32) << "labels must be int32";
     CHECK_EQ_OR_RETURN(act_lens->dtype()->data_type(), DataType::kInt32) << "act_lens must be int32";
     CHECK_EQ_OR_RETURN(label_lens->dtype()->data_type(), DataType::kInt32) << "label_lens must be int32";
+
+    CHECK_EQ_OR_RETURN(acts->shape()->NumAxes(), 4) << "acts must be a 4-D tensor";
+    CHECK_EQ_OR_RETURN(labels->shape()->NumAxes(), 2) << "labels must be a 2-D tensor";
+    CHECK_EQ_OR_RETURN(act_lens->shape()->NumAxes(), 1) << "act_lens must be a 1-D tensor";
+    CHECK_EQ_OR_RETURN(label_lens->shape()->NumAxes(), 1) << "label_lens must be a 1-D tensor";
+
     MutableAttrMap attrs;
     JUST(attrs.SetAttr("blank_label", blank_label));
     JUST(attrs.SetAttr("num_threads", num_threads));
-    // std::shared_ptr<Tensor> acts_trans = JUST(LogSoftmax(acts, -1));
-
     return OpInterpUtil::Dispatch<Tensor>(*op_, {acts, labels, act_lens, label_lens}, attrs);
   }
diff --git a/oneflow/user/ops/rnnt_op.cpp b/oneflow/user/ops/rnnt_op.cpp
index 20915e3d70b..ac48ae6006b 100644
--- a/oneflow/user/ops/rnnt_op.cpp
+++ b/oneflow/user/ops/rnnt_op.cpp
@@ -50,19 +50,6 @@ REGISTER_USER_OP("RNNTloss")
       return Maybe<void>::Ok();
     });
 
-// REGISTER_USER_OP("RNNTloss_grad")
-//     .Input("grads")
-//     .Input("dy")
-//     .Output("dx")
-//     .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-//       CHECK_EQ_OR_RETURN(ctx->InputDType("prob", 0), ctx->InputDType("dy", 0));
-//       *ctx->OutputDType("dx", 0) = ctx->InputDType("grads", 0);
-//       return Maybe<void>::Ok();
-//     })
-//     .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-//       ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-//       return Maybe<void>::Ok();
-//     });
 
 REGISTER_USER_OP_GRAD("RNNTloss")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
diff --git a/python/oneflow/nn/modules/loss.py b/python/oneflow/nn/modules/loss.py
index c48296404d5..6293b13b70d 100644
--- a/python/oneflow/nn/modules/loss.py
+++ b/python/oneflow/nn/modules/loss.py
@@ -690,12 +690,56 @@ def forward(
 
 
 class RNNTLoss(_Loss):
+    """The RNN Transducer loss.
+
+    The documentation is referenced from:
+    https://github.com/HawkAaron/warp-transducer/tree/master/pytorch_binding
+
+    This loss implements a probabilistic sequence transduction system, based entirely
+    on RNNs, that is in principle able to transform any input sequence into any finite,
+    discrete output sequence.
+
+    Args:
+        blank (int, optional): blank label. Default: ``0``.
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the output losses will be divided by the batch size,
+            ``'sum'``: the output losses will be summed.
+            Default: ``'mean'``
+
+    Shape:
+        - acts: Tensor of size :math:`(B, T, U+1, V)`, where
+          :math:`B` is the minibatch size,
+          :math:`T` is the time index,
+          :math:`U` is the label sequence length (the +1 accounts for the prepended blank label),
+          and :math:`V` indexes over the activations for each symbol in the alphabet.
+        - labels: 2-D Tensor containing all the targets of the batch, zero padded.
+        - act_lens: 1-D Tensor of ints, the length of each output sequence from the network for each example in the minibatch.
+        - label_lens: 1-D Tensor of ints, the length of each label sequence for each example in the minibatch.
+
+    Reference:
+        A. Graves et al.: Sequence Transduction with Recurrent Neural Networks
+        http://www.cs.toronto.edu/~graves/icml_2012.pdf
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow as flow
+
+        >>> acts = flow.rand(2, 2, 3, 5)
+        >>> labels = flow.tensor([[1, 2], [2, 2]], dtype=flow.int)
+        >>> acts_length = flow.tensor([2, 2], dtype=flow.int)
+        >>> label_length = flow.tensor([2, 2], dtype=flow.int)
+        >>> rnnt = flow.nn.RNNTLoss(blank=0, reduction="mean")
+        >>> loss = rnnt(acts, labels, acts_length, label_length)
+
+    """
+
     def __init__(
-        self, blank: int = 0, reduction: str = "mean", thread: int=0
+        self, blank: int = 0, reduction: str = "mean"
     ) -> None:
+        super(RNNTLoss, self).__init__(reduction)
         self.blank = blank
-        self.thread = thread
+        self.thread = 0
 
     def forward(
         self,
@@ -706,7 +750,6 @@ def forward(
     ) -> Tensor:
         if not acts.is_cuda:
             acts = flow._C.log_softmax(acts, -1)
-            print("cpu")
         loss = flow._C.RNNTloss(acts, labels, act_lens, label_lens, self.blank, self.thread)
         if self.reduction in ['sum', 'mean']:
            loss = loss.sum().unsqueeze(-1)
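
For reviewers, a minimal sanity-check sketch of how the module added in this patch would be driven end to end, including the backward pass. Assumptions beyond the diff itself: `flow.rand(..., requires_grad=True)` and the usual `.backward()` / `.grad` autograd API; per `forward` above, the CPU path applies `log_softmax` internally, so `acts` here are raw joint-network logits.

    import oneflow as flow

    # Joint-network logits: batch B=2, T=2 time steps, U+1=3, vocabulary V=5.
    # blank=0 matches the prepended-blank convention described in the docstring.
    acts = flow.rand(2, 2, 3, 5, requires_grad=True)
    labels = flow.tensor([[1, 2], [2, 2]], dtype=flow.int)
    act_lens = flow.tensor([2, 2], dtype=flow.int)
    label_lens = flow.tensor([2, 2], dtype=flow.int)

    rnnt = flow.nn.RNNTLoss(blank=0, reduction="mean")
    loss = rnnt(acts, labels, act_lens, label_lens)  # shape (1,) after reduction

    # Exercises the backward op registered via REGISTER_USER_OP_GRAD("RNNTloss").
    loss.sum().backward()
    assert acts.grad.shape == acts.shape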