Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat vector #2409

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.14.14
* New layer `RepeatVector` which works like
RepeatVector in keras

## v0.14.13
* New macro `Flux.@layer` which should be used in place of `@functor`.
This also adds `show` methods for pretty printing.
Expand Down
30 changes: 30 additions & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,33 @@ EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean)
function Base.show(io::IO, m::EmbeddingBag)
print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end


"""
RepeatVector(n::Int)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RepeatArray is maybe a better name.


Repeat the input `n` times along the last dimension.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should specify that an extra dimension is added first.


# Examples
```jldoctest
julia> rv = RepeatVector(3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in a jldoctest block each line should also have the corresponding output, otherwise the doctests will fail

julia> rv([1, 2, 3])
3×3 Matrix{Int64}:
1 1 1
2 2 2
3 3 3
```
"""
struct RepeatVector
n::Int
end

@layer RepeatVector

function (rv::RepeatVector)(x::AbstractArray{T}) where {T}
expanded = reshape(x, (size(x)..., 1))
repeated = repeat(expanded, outer = (1, rv.n, 1))
return repeated
Comment on lines +891 to +893
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could just be

return repeat(x, size(x)..., rv.n)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would do this:

julia> RepeatVector(2)(rand(3)) |> size
(9, 2)

julia> RepeatVector(2)(rand(3,5)) |> size
(9, 25, 2)

julia> RepeatVector(2)(rand(3,5,7)) |> size
(9, 25, 49, 2)

Perhaps you meant:

julia> function (rv::RepeatVector)(x::AbstractArray{T}) where {T}
          repeated = repeat(x, ntuple(_->1, ndims(x))..., rv.n)
       end

julia> RepeatVector(2)(rand(3)) |> size
(3, 2)

julia> RepeatVector(2)(rand(3,5)) |> size
(3, 5, 2)

julia> RepeatVector(2)(rand(3,5,7)) |> size
(3, 5, 7, 2)

end

Base.show(io::IO, rv::RepeatVector) = print(io, "RepeatVector($(rv.n))")
Loading