Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KernelTensorSum #507

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ TransformedKernel
ScaledKernel
KernelSum
KernelProduct
KernelIndependentSum
KernelTensorProduct
NormalizedKernel
```
Expand Down
4 changes: 3 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ export LinearKernel, PolynomialKernel
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
export PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct, KernelTensorProduct
export KernelSum, KernelProduct, KernelIndependentSum, KernelTensorProduct
export TransformedKernel, ScaledKernel, NormalizedKernel
export GibbsKernel
export ⊕

export Transform,
SelectTransform,
Expand Down Expand Up @@ -108,6 +109,7 @@ include("kernels/normalizedkernel.jl")
include("matrix/kernelmatrix.jl")
include("kernels/kernelsum.jl")
include("kernels/kernelproduct.jl")
include("kernels/kernelindependentsum.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/overloads.jl")
include("kernels/neuralkernelnetwork.jl")
Expand Down
110 changes: 110 additions & 0 deletions src/kernels/kernelindependentsum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
KernelIndependentSum

Independent sum of kernels.

# Definition

For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
sum of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\sum_{i=1}^n k_i(x_i, x'_i).
```

# Construction

The simplest way to specify a `KernelIndependentSum` is to use the `⊕` operator (can be typed by `\\oplus<tab>`).
```jldoctest independentsum
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);

julia> kernelmatrix(k1 ⊕ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) + kernelmatrix(k2, X[:, 2])
true
```

You can also specify a `KernelIndependentSum` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelIndependentSum` is concretely typed but might
lead to large compilation times if the number of kernels is large.
```jldoctest independentsum
julia> KernelIndependentSum(k1, k2) == k1 ⊕ k2
true

julia> KernelIndependentSum((k1, k2)) == k1 ⊕ k2
true

julia> KernelIndependentSum([k1, k2]) == k1 ⊕ k2
true
```
"""
struct KernelIndependentSum{K} <: Kernel
kernels::K
end

# Varargs convenience constructor: collect the kernels into a tuple so the
# resulting `KernelIndependentSum` is concretely typed.
KernelIndependentSum(k::Kernel, ks::Kernel...) = KernelIndependentSum((k, ks...))

@functor KernelIndependentSum

Base.length(kernel::KernelIndependentSum) = length(kernel.kernels)

# Evaluate the independent sum at a pair of inputs: each component kernel is
# applied to its own feature group, and the results are summed.
function (kernel::KernelIndependentSum)(x, y)
    nkernels = length(kernel)
    nx = length(x)
    ny = length(y)
    # Both inputs must provide exactly one feature (group) per component kernel.
    if nx != nkernels || ny != nkernels
        throw(
            DimensionMismatch(
                "number of kernels ($nkernels) and number of features (x=$nx, y=$ny) are not consistent",
            ),
        )
    end
    return sum(ki(xi, yi) for (ki, xi, yi) in zip(kernel.kernels, x, y))
end

# Check that the number of feature groups in `x` and `y` matches the number of
# component kernels; errors otherwise. Returns `true` on success.
function validate_domain(k::KernelIndependentSum, x::AbstractVector, y::AbstractVector)
    # Fixed: the error message previously had a misplaced parenthesis,
    # "(x=$dx), y=$dy)" instead of "(x=$dx, y=$dy)".
    return (dx = dim(x)) == (dy = dim(y)) == (nkernels = length(k)) || error(
        "number of kernels ($nkernels) and group of features (x=$dx, y=$dy) are not consistent",
    )
end

# Single-input variant: validate `x` against itself.
validate_domain(k::KernelIndependentSum, x::AbstractVector) = validate_domain(k, x, x)

# Kernel matrix over `x`: sum of the component kernel matrices, each computed
# on its own slice of the features.
function kernelmatrix(k::KernelIndependentSum, x::AbstractVector)
    validate_domain(k, x)
    return sum(kernelmatrix(ki, xi) for (ki, xi) in zip(k.kernels, slices(x)))
end

# Cross kernel matrix between `x` and `y`: sum of the component cross matrices.
function kernelmatrix(k::KernelIndependentSum, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x, y)
    return sum(
        kernelmatrix(ki, xi, yi) for (ki, xi, yi) in zip(k.kernels, slices(x), slices(y))
    )
end

# Diagonal of the kernel matrix over `x`: sum of the component diagonals.
function kernelmatrix_diag(k::KernelIndependentSum, x::AbstractVector)
    validate_domain(k, x)
    return sum(kernelmatrix_diag(ki, xi) for (ki, xi) in zip(k.kernels, slices(x)))
end

# Diagonal of the cross kernel matrix between `x` and `y`.
function kernelmatrix_diag(k::KernelIndependentSum, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x, y)
    return sum(
        kernelmatrix_diag(ki, xi, yi) for
        (ki, xi, yi) in zip(k.kernels, slices(x), slices(y))
    )
end

# Two independent sums are equal iff they hold the same number of kernels and
# the component kernels compare equal pairwise (tuple- and vector-backed sums
# with equal components are therefore considered equal).
function Base.:(==)(x::KernelIndependentSum, y::KernelIndependentSum)
    length(x.kernels) == length(y.kernels) || return false
    return all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
end

Base.show(io::IO, kernel::KernelIndependentSum) = printshifted(io, kernel, 0)

# Print a multi-line description of the composite kernel, indenting each
# component by `shift + 1` tabs and letting components indent further.
# NOTE(review): the label says "Tensor sum" although the type is named
# KernelIndependentSum; the current tests pin this exact string, so it is
# reproduced unchanged here — rename both together if desired.
function printshifted(io::IO, kernel::KernelIndependentSum, shift::Int)
    print(io, "Tensor sum of ", length(kernel), " kernels:")
    for component in kernel.kernels
        print(io, "\n", "\t"^(shift + 1))
        printshifted(io, component, shift + 2)
    end
end
3 changes: 3 additions & 0 deletions src/kernels/overloads.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
function ⊕ end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems too generic to be defined and exported from KernelFunctions. Is it not part of TensorCore or some other lightweight interface package? We would also want a non-Unicode alias, as for other keyword arguments and functions.

Copy link
Author

@martincornejo martincornejo May 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about that, but it is not part of TensorCore or, as far as I know, any other lightweight package (https://juliahub.com/ui/Search?q=%E2%8A%95&type=symbols). It is a helper constructor for the new KernelTensorSum/KernelIndependentSum, so the non-Unicode function is already available.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestions welcome on how to improve this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no non-Unicode alternative similar to +, *, or tensor yet as far as I can tell?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! I see, you're right

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolve?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems too generic to be defined and exported from KernelFunctions.

This problem is not fixed yet, is it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But since TensorCore.jl does not define `⊕`, what should we do? Here are the packages that use `⊕`. Kronecker.jl is one, but I guess we do not want to add this as a dependency, only to re-use the symbol.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make a PR to TensorCore. I think the operator should not be owned by KernelFunctions.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


for (M, op, T) in (
(:Base, :+, :KernelSum),
(:Base, :*, :KernelProduct),
(:TensorCore, :tensor, :KernelTensorProduct),
(:KernelFunctions, :⊕, :KernelIndependentSum),
)
@eval begin
$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)
Expand Down
67 changes: 67 additions & 0 deletions test/kernels/kernelindependentsum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
@testset "kerneltensorsum" begin
rng = MersenneTwister(123456)
u1 = rand(rng, 10)
u2 = rand(rng, 10)
v1 = rand(rng, 5)
v2 = rand(rng, 5)

# kernels
k1 = SqExponentialKernel()
k2 = ExponentialKernel()
kernel1 = KernelIndependentSum(k1, k2)
kernel2 = KernelIndependentSum([k1, k2])

@test kernel1 == kernel2
@test kernel1.kernels == (k1, k2) === KernelIndependentSum((k1, k2)).kernels
for (_k1, _k2) in Iterators.product(
(k1, KernelIndependentSum((k1,)), KernelIndependentSum([k1])),
(k2, KernelIndependentSum((k2,)), KernelIndependentSum([k2])),
)
@test kernel1 == _k1 ⊕ _k2
end
@test length(kernel1) == length(kernel2) == 2
@test string(kernel1) == (
"Tensor sum of 2 kernels:\n" *
"\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" *
"\tExponential Kernel (metric = Euclidean(0.0))"
)
@test_throws DimensionMismatch kernel1(rand(3), rand(3))

@testset "val" begin
for (x, y) in (((v1, u1), (v2, u2)), ([v1, u1], [v2, u2]))
val = k1(x[1], y[1]) + k2(x[2], y[2])

@test kernel1(x, y) == kernel2(x, y) == val
end
end

# Standardised tests.
TestUtils.test_interface(kernel1, ColVecs{Float64})
TestUtils.test_interface(kernel1, RowVecs{Float64})
TestUtils.test_interface(
KernelIndependentSum(WhiteKernel(), ConstantKernel(; c=1.1)), ColVecs{String}
)
test_ADs(
x -> KernelIndependentSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))),
rand(1);
dims=[2, 2],
)
types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}]
test_interface_ad_perf(2.1, StableRNG(123456), types) do c
KernelIndependentSum(SqExponentialKernel(), LinearKernel(; c=c))
end
test_params(KernelIndependentSum(k1, k2), (k1, k2))

@testset "single kernel" begin
kernel = KernelIndependentSum(k1)
@test length(kernel) == 1

@testset "eval" begin
for (x, y) in (((v1,), (v2,)), ([v1], [v2]))
val = k1(x[1], y[1])

@test kernel(x, y) == val
end
end
end
end
9 changes: 7 additions & 2 deletions test/kernels/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
k2 = SqExponentialKernel()
k3 = RationalQuadraticKernel()

for (op, T) in ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct))
if T === KernelTensorProduct
for (op, T) in (
(+, KernelSum),
(*, KernelProduct),
(⊗, KernelTensorProduct),
(⊕, KernelIndependentSum),
)
if T === KernelTensorProduct || T === KernelIndependentSum
v2_1 = rand(rng, 2)
v2_2 = rand(rng, 2)
v3_1 = rand(rng, 3)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ include("test_utils.jl")
include("kernels/kernelproduct.jl")
include("kernels/kernelsum.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/kernelindependentsum.jl")
include("kernels/overloads.jl")
include("kernels/scaledkernel.jl")
include("kernels/transformedkernel.jl")
Expand Down