Skip to content

Commit

Permalink
Refactor whitening for closer integration with StatsBase types (#144)
Browse files Browse the repository at this point in the history
* refactor whitening for closer integration with StatsBase types (part of #109)
* deprecate `indim` & `outdim`
  • Loading branch information
wildart authored Jun 1, 2021
1 parent fc3c5aa commit f00cba3
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 35 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ A Julia package for multivariate statistics and data analysis (e.g. dimensionali
[![Coverage Status](https://coveralls.io/repos/JuliaStats/MultivariateStats.jl/badge.svg?branch=master)](https://coveralls.io/r/JuliaStats/MultivariateStats.jl?branch=master)
[![Build Status](https://travis-ci.org/JuliaStats/MultivariateStats.jl.svg?branch=master)](https://travis-ci.org/JuliaStats/MultivariateStats.jl)
[![CI](https://github.com/JuliaStats/MultivariateStats.jl/actions/workflows/ci.yml/badge.svg)](https://github.com/JuliaStats/MultivariateStats.jl/actions/workflows/ci.yml)
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliastats.org/MultivariateStats.jl/stable)
[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliastats.org/MultivariateStats.jl/dev)

-------

Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ end
makedocs(
sitename = "MultivariateStats.jl",
modules = [MultivariateStats],
pages = ["Home"=>"index.md", "lda.md", "Development"=>"api.md"]
pages = ["Home"=>"index.md", "whiten.md", "lda.md", "Development"=>"api.md"]
)

deploydocs(
Expand Down
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ Note: `?` refers to a possible implementation that is missing or called differen
|length | + | | x | | | | | | | | |
|size | + | | | | | | | | | | |
| | | | | | | | | | | | |
|eee | | | | | | | | | | | |

- StatsBase.AbstractDataTransform
- Whitening
Expand Down
3 changes: 1 addition & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ end

[MultivariateStats.jl](https://github.com/JuliaStats/MultivariateStats.jl) is a Julia package for multivariate statistical analysis. It provides a rich set of useful analysis techniques, such as PCA, CCA, LDA, ICA, etc.


```@contents
Pages = ["lda.md", "api.md"]
Pages = ["whiten.md", "lda.md", "api.md"]
Depth = 2
```

Expand Down
35 changes: 35 additions & 0 deletions docs/src/whiten.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Data Transformation

## Whitening

A [whitening transformation](http://en.wikipedia.org/wiki/Whitening_transformation>) is a decorrelation transformation that transforms a set of random variables into a set of new random variables with identity covariance (uncorrelated with unit variances).

In particular, suppose a random vector has covariance ``\mathbf{C}``, then a whitening transform ``\mathbf{W}`` is one that satisfy:

```math
\mathbf{W}^T \mathbf{C} \mathbf{W} = \mathbf{I}
```

Note that ``\mathbf{W}`` is generally not unique. In particular, if ``\mathbf{W}`` is a whitening transform, so is any of its rotation ``\mathbf{W} \mathbf{R}`` with ``\mathbf{R}^T \mathbf{R} = \mathbf{I}``.

The package uses [`Whitening`](@ref) to represent a whitening transform.

```@docs
Whitening
```

Whitening transformation can be fitted to data using the `fit` method.

```@docs
fit(::Type{Whitening}, X::AbstractMatrix{T}; kwargs...) where {T<:Real}
transform(::Whitening, ::AbstractVecOrMat)
length(::Whitening)
mean(::Whitening)
size(::Whitening)
```

Additional methods
```@docs
cov_whitening
cov_whitening!
```
12 changes: 9 additions & 3 deletions src/MultivariateStats.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module MultivariateStats
using LinearAlgebra
using StatsBase: SimpleCovariance, CovarianceEstimator, pairwise, pairwise!
using StatsBase: SimpleCovariance, CovarianceEstimator, RegressionModel,
AbstractDataTransform, pairwise!
import Statistics: mean, var, cov, covm
import Base: length, size, show, dump
import StatsBase: RegressionModel, fit, predict, ConvergenceException, dof, coef, weights, pairwise
import StatsBase: fit, predict, predict!, ConvergenceException, dof_residual, coef, weights, dof, pairwise
import SparseArrays
import LinearAlgebra: eigvals

Expand Down Expand Up @@ -111,7 +112,6 @@ module MultivariateStats
faem, # Maximum likelihood probabilistic PCA
facm # EM algorithm for probabilistic PCA


## source files
include("common.jl")
include("lreg.jl")
Expand All @@ -125,4 +125,10 @@ module MultivariateStats
include("ica.jl")
include("fa.jl")

## deprecations
@deprecate indim(f::Whitening) length(f::Whitening)
@deprecate outdim(f::Whitening) length(f::Whitening)
# @deprecate transform(m, x; kwargs...) predict(m, x; kwargs...) #ex=false
# @deprecate transform(m; kwargs...) predict(m; kwargs...) #ex=false

end # module
6 changes: 3 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ decentralize(x::AbstractMatrix, m::AbstractVector) = (isempty(m) ? x : x .+ m)

# get a full mean vector

fullmean(d::Int, mv::Vector{T}) where T = (isempty(mv) ? zeros(T, d) : mv)
fullmean(d::Int, mv::AbstractVector{T}) where T = (isempty(mv) ? zeros(T, d) : mv)

preprocess_mean(X::AbstractMatrix{T}, m) where T<:Real =
(m === nothing ? vec(mean(X, dims=2)) : m == 0 ? T[] : m)
preprocess_mean(X::AbstractMatrix{T}, m; dims=2) where T<:Real =
(m === nothing ? vec(mean(X, dims=dims)) : m == 0 ? T[] : m)

# choose the first k values and columns
#
Expand Down
147 changes: 123 additions & 24 deletions src/whiten.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,151 @@
# Whitening

## Solve whitening based on covariance
#
# finds W, such that W'CW = I
#
"""
cov_whitening(C)
Derive the whitening transform coefficient matrix `W` given the covariance matrix `C`. Here, `C` can be either a square matrix, or an instance of `Cholesky`.
Internally, this function solves the whitening transform using Cholesky factorization. The rationale is as follows: let ``\\mathbf{C} = \\mathbf{U}^T \\mathbf{U}`` and ``\\mathbf{W} = \\mathbf{U}^{-1}``, then ``\\mathbf{W}^T \\mathbf{C} \\mathbf{W} = \\mathbf{I}``.
**Note:** The return matrix `W` is an upper triangular matrix.
"""
function cov_whitening(C::Cholesky{T}) where {T<:Real}
cf = C.UL
Matrix{T}(inv(istriu(cf) ? cf : cf'))
end

cov_whitening!(C::DenseMatrix{<:Real}) = cov_whitening(cholesky!(Hermitian(C, :U)))
cov_whitening(C::DenseMatrix{<:Real}) = cov_whitening!(copy(C))
"""
cov_whitening!(C)
In-place version of `cov_whitening(C)`, in which the input matrix `C` will be overwritten during computation. This can be more efficient when `C` is no longer used.
"""
cov_whitening!(C::AbstractMatrix{<:Real}) = cov_whitening(cholesky!(Hermitian(C, :U)))
cov_whitening(C::AbstractMatrix{<:Real}) = cov_whitening!(copy(C))

"""
cov_whitening!(C, regcoef)
In-place version of `cov_whitening(C, regcoef)`, in which the input matrix `C` will be overwritten during computation. This can be more efficient when `C` is no longer used.
"""
cov_whitening!(C::AbstractMatrix{<:Real}, regcoef::Real) = cov_whitening!(regularize_symmat!(C, regcoef))

"""
cov_whitening(C, regcoef)
cov_whitening!(C::DenseMatrix{<:Real}, regcoef::Real) = cov_whitening!(regularize_symmat!(C, regcoef))
cov_whitening(C::DenseMatrix{<:Real}, regcoef::Real) = cov_whitening!(copy(C), regcoef)
Derive a whitening transform based on a regularized covariance, as `C + (eigmax(C) * regcoef) * eye(d)`.
"""
cov_whitening(C::AbstractMatrix{<:Real}, regcoef::Real) = cov_whitening!(copy(C), regcoef)

## Whitening type

struct Whitening{T<:Real}
mean::Vector{T}
W::Matrix{T}
"""
A whitening transform representation.
"""
struct Whitening{T<:Real} <: AbstractDataTransform
mean::AbstractVector{T}
W::AbstractMatrix{T}

function Whitening{T}(mean::Vector{T}, W::Matrix{T}) where {T<:Real}
function Whitening{T}(mean::AbstractVector{T}, W::AbstractMatrix{T}) where {T<:Real}
d, d2 = size(W)
d == d2 || error("W must be a square matrix.")
isempty(mean) || length(mean) == d ||
throw(DimensionMismatch("Sizes of mean and W are inconsistent."))
return new(mean, W)
end
end
Whitening(mean::Vector{T}, W::Matrix{T}) where {T<:Real} = Whitening{T}(mean, W)
Whitening(mean::AbstractVector{T}, W::AbstractMatrix{T}) where {T<:Real} = Whitening{T}(mean, W)

"""
length(f)
Get the dimension of the whitening transform `f`.
"""
length(f::Whitening) = size(f.W, 1)

"""
size(f)
Dimensions of the coefficient matrix of the whitening transform `f`.
"""
size(f::Whitening) = size(f.W)

"""
mean(f)
Get the mean vector of the whitening transformation `f`.
indim(f::Whitening) = size(f.W, 1)
outdim(f::Whitening) = size(f.W, 2)
**Note:** if mean is empty, this function returns a zero vector of length [`outdim`](@ref) .
"""
mean(f::Whitening) = fullmean(indim(f), f.mean)

transform(f::Whitening, x::AbstractVecOrMat{<:Real}) = transpose(f.W) * centralize(x, f.mean)

## Fit whitening to data
"""
transform(f, x)
function fit(::Type{Whitening}, X::DenseMatrix{T};
Apply the whitening transform `f` to a vector or a matrix `x` with samples in columns, as ``\\mathbf{W}^T (\\mathbf{x} - \\boldsymbol{\\mu})``.
"""
function transform(f::Whitening, x::AbstractVecOrMat{<:Real})
s = size(x)
Z, dims = if length(s) == 1
length(f.mean) == s[1] || throw(DimensionMismatch("Inconsistent dimensions."))
x - f.mean, 2
else
dims = (s[1] == length(f.mean)) + 1
length(f.mean) == s[3-dims] || throw(DimensionMismatch("Inconsistent dimensions."))
x .- (dims == 2 ? f.mean : transpose(f.mean)), dims
end
if dims == 2
transpose(f.W) * Z
else
Z * f.W
end
end

"""
fit(::Type{Whitening}, X::AbstractMatrix{T}; kwargs...)
Estimate a whitening transform from the data given in `X`.
This function returns an instance of [`Whitening`](@ref)
**Keyword Arguments:**
- `regcoef`: The regularization coefficient. The covariance will be regularized as follows when `regcoef` is positive `C + (eigmax(C) * regcoef) * eye(d)`. Default values is `zero(T)`.
- `dims`: if `1` the transformation calculated from the row samples. fit standardization parameters in column-wise fashion;
if `2` the transformation calculated from the column samples. The default is `nothing`, which is equivalent to `dims=2` with a deprecation warning.
- `mean`: The mean vector, which can be either of:
- `0`: the input data has already been centralized
- `nothing`: this function will compute the mean (**default**)
- a pre-computed mean vector
**Note:** This function internally relies on [`cov_whitening`](@ref) to derive the transformation `W`.
"""
function fit(::Type{Whitening}, X::AbstractMatrix{T};
dims::Union{Integer,Nothing}=nothing,
mean=nothing, regcoef::Real=zero(T)) where {T<:Real}
n = size(X, 2)
n > 1 || error("X must contain more than one sample.")
mv = preprocess_mean(X, mean)
Z = centralize(X, mv)
if dims === nothing
Base.depwarn("fit(Whitening, x) is deprecated: use fit(Whitening, x, dims=2) instead", :fit)
dims = 2
end
if dims == 1
n = size(X,1)
n >= 2 || error("X must contain at least two rows.")
elseif dims == 2
n = size(X, 2)
n >= 2 || error("X must contain at least two columns.")
else
throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
end
mv = preprocess_mean(X, mean; dims=dims)
Z = centralize((dims==1 ? transpose(X) : X), mv)
C = rmul!(Z * transpose(Z), one(T) / (n - 1))
return Whitening(mv, cov_whitening!(C, regcoef))
end

# invsqrtm

function _invsqrtm!(C::Matrix{<:Real})
function _invsqrtm!(C::AbstractMatrix{<:Real})
n = size(C, 1)
size(C, 2) == n || error("C must be a square matrix.")
E = eigen!(Symmetric(C))
Expand All @@ -64,4 +158,9 @@ function _invsqrtm!(C::Matrix{<:Real})
return U * transpose(U)
end

invsqrtm(C::DenseMatrix{<:Real}) = _invsqrtm!(copy(C))
"""
invsqrtm(C)
Compute `inv(sqrtm(C))` through symmetric eigenvalue decomposition.
"""
invsqrtm(C::AbstractMatrix{<:Real}) = _invsqrtm!(copy(C))
25 changes: 24 additions & 1 deletion test/whiten.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using MultivariateStats
using LinearAlgebra
using LinearAlgebra, StatsBase, SparseArrays
using Test
import Statistics: mean, cov
import Random
Expand Down Expand Up @@ -55,6 +55,8 @@ import Random
W = f.W
@test isa(f, Whitening{Float64})
@test mean(f) === f.mean
@test length(f) == d
@test size(f) == (d,d)
@test istriu(W)
@test W'C * W Matrix(I, d, d)
@test transform(f, X) W' * (X .- f.mean)
Expand Down Expand Up @@ -92,4 +94,25 @@ import Random
# type consistency
@test eltype(mean(M)) == Float64
@test eltype(mean(MM)) == Float32

# sparse arrays
SX = sprand(Float32, d, n, 0.75)
SM = fit(Whitening, SX; mean=sprand(Float32, 3, 0.75))
Y = transform(SM, SX)
@test eltype(Y) == Float32

# different dimensions
@test_throws DomainError fit(Whitening, X'; dims=3)
M1 = fit(Whitening, X'; dims=1)
M2 = fit(Whitening, X; dims=2)
@test M1.W == M2.W
@test_throws DimensionMismatch transform(M1, rand(6,4))
@test_throws DimensionMismatch transform(M2, rand(4,6))
Y1 = transform(M1,X')
Y2 = transform(M2,X)
@test Y1' == Y2
@test_throws DimensionMismatch transform(M1, rand(7))
V1 = transform(M1,X[:,1])
V2 = transform(M2,X[:,1])
@test V1 == V2
end

0 comments on commit f00cba3

Please sign in to comment.