-
-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add repeat vector #2409
base: master
Are you sure you want to change the base?
Add repeat vector #2409
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -864,3 +864,33 @@ EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean) | |
function Base.show(io::IO, m::EmbeddingBag) | ||
print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")") | ||
end | ||
|
||
|
||
""" | ||
RepeatVector(n::Int) | ||
|
||
Repeat the input `n` times along the last dimension. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should specify that an extra dimension is added first. |
||
|
||
# Examples | ||
```jldoctest | ||
julia> rv = RepeatVector(3) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in a jldoctest block each line should also have the corresponding output, otherwise the doctests will fail |
||
julia> rv([1, 2, 3]) | ||
3×3 Matrix{Int64}: | ||
1 1 1 | ||
2 2 2 | ||
3 3 3 | ||
``` | ||
""" | ||
struct RepeatVector | ||
n::Int | ||
end | ||
|
||
@layer RepeatVector | ||
|
||
function (rv::RepeatVector)(x::AbstractArray{T}) where {T} | ||
expanded = reshape(x, (size(x)..., 1)) | ||
repeated = repeat(expanded, outer = (1, rv.n, 1)) | ||
return repeated | ||
Comment on lines
+891
to
+893
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this could just be return repeat(x, size(x)..., rv.n) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would do this: julia> RepeatVector(2)(rand(3)) |> size
(9, 2)
julia> RepeatVector(2)(rand(3,5)) |> size
(9, 25, 2)
julia> RepeatVector(2)(rand(3,5,7)) |> size
(9, 25, 49, 2) Perhaps you meant: julia> function (rv::RepeatVector)(x::AbstractArray{T}) where {T}
repeated = repeat(x, ntuple(_->1, ndims(x))..., rv.n)
end
julia> RepeatVector(2)(rand(3)) |> size
(3, 2)
julia> RepeatVector(2)(rand(3,5)) |> size
(3, 5, 2)
julia> RepeatVector(2)(rand(3,5,7)) |> size
(3, 5, 7, 2) |
||
end | ||
|
||
Base.show(io::IO, rv::RepeatVector) = print(io, "RepeatVector($(rv.n))") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RepeatArray
is maybe a better name.