Skip to content

Commit

Permalink
Some more functions and good docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Dsantra92 committed Sep 6, 2024
1 parent c07289b commit 7c0566a
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 38 deletions.
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ makedocs(
assets = ["assets/favicon.ico"],
collapselevel = 3
),
pages = ["Home" => "index.md", "API Reference" => ["o3" => "api/o3.md"]]
pages = ["Home" => "index.md", "API Reference" => ["Irreps" => "api/irreps.md"]]
)

deploydocs(repo = "github.com/Dsantra92/e3nn.jl.git")
12 changes: 12 additions & 0 deletions docs/src/api/irreps.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## Irreducible Representation


```@autodocs
Modules = [e3nn.o3 ]
Pages = ["o3/irreps.jl"]
Private = false
```

```@docs
Base.iterate
```
9 changes: 0 additions & 9 deletions docs/src/api/o3.md

This file was deleted.

135 changes: 107 additions & 28 deletions src/o3/irreps.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,52 @@
using Random: GLOBAL_RNG, AbstractRNG
using Random: AbstractRNG

"""
Irrep(l::Int, p::Int)
Irrep(s::String)
Irreducible representation of ``O(3)``.
This struct does not contain any data; it is a structure that describes the representation.
It is typically used as an argument in other parts of the library to define the input and output representations of
functions.
# Arguments
- `l::Int`: non-negative integer, the degree of the representation, l = 0, 1, ...
- `p::Int`: the parity of the representation, either 1 (even) or -1 (odd)
- `s::String`: a string representation, e.g., "1o" for l=1, odd parity
# Examples
Create a scalar representation (l=0) of even parity:
```jldoctest
julia> Irrep(0, 1)
0e
```
Create a pseudotensor representation (l=2) of odd parity:
```jldoctest
julia> Irrep(2, -1)
2o
```
Create a vector representation (l=1) of the parity of the spherical harmonics (-1^l gives odd parity):
```jldoctest
julia> Irrep("1y")
1o
```
Other operations:
```jldoctest
julia> dim(Irrep("2o"))
5
julia> Irrep("2e") in Irrep("1o") * Irrep("1o")
true
julia> Irrep("1o") + Irrep("2o")
1x1o+1x2o
```
"""
struct Irrep
l::Int
p::Int
Expand All @@ -11,8 +58,8 @@ struct Irrep
end
end

function Irrep(ir::T) where {T <: AbstractString}
name = strip(ir)
function Irrep(l::T) where {T <: AbstractString}
name = strip(l)
try
l = parse(Int, name[1:(end - 1)])
(l >= 0) || throw(ArgumentError("l must be zero or positive integer, got $l"))
Expand All @@ -23,19 +70,19 @@ function Irrep(ir::T) where {T <: AbstractString}
end
end

function Irrep(ir::Tuple)
@assert length(ir) == 2
l, p = ir
function Irrep(l::Tuple)
@assert length(l) == 2
l, p = l
return Irrep(l, p)
end

Irrep(ir::Irrep) = ir
Irrep(l::Irrep) = l

dim(ir::Irrep) = 2 * ir.l + 1
dim(x::Irrep) = 2 * x.l + 1

isscalar(x::Irrep) = (x.l == 0) && (x.p == 1)

# Base.iterate(x::Irrep, args...) = iterate((x.l, x.p), args...)
Base.iterate(x::Irrep, args...) = iterate((x.l, x.p), args...)

function Base.:*(x1::Irrep, x2::Irrep)
p = x1.p * x2.p
Expand All @@ -48,17 +95,13 @@ function Base.:*(i::Int, x::Irrep)
return Irreps([(i, x)])
end

# pretty solid read: https://vladium.com/tutorials/study_julia_with_me/equality_vs_identity/
Base.:(==)(x1::Irrep, x2::Irrep) = (x1.l == x2.l) && (x1.p == x2.p)
Base.:(==)(lhs::Irrep, rhs::Union{String, Tuple}) = lhs == Irrep(rhs)
Base.:(==)(lhs::Union{String, Tuple}, rhs::Irrep) = Irrep(lhs) == rhs

function Base.:+(x1::Irrep, x2::Irrep)
return Irreps(x1) + Irreps(x2)
end

# Used for comparison
Base.isless(x1::Irrep, x2::Irrep) = Base.isless((x1.l, x1.p), (x2.l, x2.p))
function Base.isless(x1::Irrep, x2::Irrep)
Base.isless((x1.l, -x1.p * (-1)^x1.l), (x2.l, -x2.p * (-1)^x2.l))
end

function Base.show(io::IO, x::Irrep)
p = Dict(+1 => "e", -1 => "o")[x.p]
Expand Down Expand Up @@ -92,6 +135,9 @@ function Base.show(io::IO, mx::MulIrrep)
print(io, s)
end

"""
dim(x::MulIrrep)
"""
dim(mx::MulIrrep) = mx.mul * dim(mx.irrep)

Base.convert(::Type{Tuple}, x::MulIrrep) = (x.mul, x.irrep)
Expand Down Expand Up @@ -168,12 +214,15 @@ Base.getindex(xs::Irreps, idx::AbstractRange) = xs.irreps[idx] |> Irreps
Base.firstindex(xs::Irreps) = Base.firstindex(xs.irreps)
Base.lastindex(xs::Irreps) = Base.lastindex(xs.irreps)

# TODO: implement indexing by mul and dim if required

Base.in(x::Irrep, xs::Irreps) = x [mx.irrep for mx in xs.irreps]
function Base.count(x::Irrep, xs::Irreps)
sum([mx.mul for mx in xs.irreps if mx.irrep == x], init = 0)
end
Base.iterate(xs::Irreps, state = 1) = state > length(xs) ? nothing :
(xs[state], state + 1)
function Base.iterate(xs::Irreps, state = 1)
state > length(xs) ? nothing : (xs[state], state + 1)
end

"""
Representation of spherical harmonics.
Expand Down Expand Up @@ -204,7 +253,9 @@ Simplify the representaions.
"""
function simplify(xs::Irreps)::Irreps
out = []
for (mul, irrep) in xs
xs = Base.sort(xs)
for mul_ir in xs
mul, irrep = mul_ir.mul, mul_ir.irrep
if length(out) != 0 && out[end][2] == irrep
out[end] = (out[end][1] + mul, irrep)
elseif mul > 0
Expand All @@ -220,16 +271,11 @@ Remove any irreps with multiplicities of zero.
remove_zero_multiplicities(xs::Irreps) = [(mul, irreps) for (mul, irreps) in xs if mul > 0] |>
Irreps

"""
Sort the representations.
"""
function Base.sort(xs::Irreps)
out = [(x.irrep, i, x.mul) for (i, x) in enumerate(xs)]
out = sort(out)
inv = [i for (_, i, _) in out]
p = sortperm(inv)
function Base.sort(xs::Irreps)::Irreps
out = [(mx.irrep, i, mx.mul) for (i, mx) in enumerate(xs)]
out = Base.sort(out)
sorted_irreps = Irreps([(mul, irrep) for (irrep, _, mul) in out])
return (irreps = sorted_irreps, perm = p, inv = inv)
return sorted_irreps
end

dim(xs::Irreps) = sum([mx.mul * dim(mx.irrep) for mx in xs], init = 0)
Expand All @@ -238,6 +284,39 @@ num_irreps(xs::Irreps) = sum([mx.mul for mx in xs], init = 0)

ls(xs::Irreps) = [mx.irrep.l for mx in xs for _ in 1:(mx.mul)]

"""
Base.iterate(::Type{Irrep}, state=(0, true))
Iterator through all the irreps of O(3).
# Examples
```julia-repl
julia> first(Irrep)
0e
julia> collect(Iterators.take(Irrep, 6)) # set lmax as 6
6-element Vector{Any}:
0e
0o
1o
1e
2e
2o
```
"""
function Base.iterate(::Type{Irrep}, state = (0, true))
l, is_positive = state
if is_positive
return (Irrep(l, (-1)^l), (l, false))
else
return (Irrep(l, -(-1)^l), (l + 1, true))
end
end

Base.IteratorSize(::Type{Irrep}) = Base.IsInfinite()
Base.IteratorEltype(::Type{Irrep}) = Base.HasEltype()
Base.eltype(::Type{Irrep}) = Irrep

function lmax(xs::Irreps)::Int
if length(xs) == 0
throw(ArgumentError("Cannot get lmax of empty Irreps"))
Expand Down
34 changes: 34 additions & 0 deletions test/o3/irreps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ using Test
@test ir.l == 5
@test ir.p == -1

# This is a weird way to do this,
# Iterators.take is a better way to do this
# Just wanted to test if the iterator works infinitely
iter = Irrep
for x in range(0, 500)
irrep, iter = Iterators.peel(iter)
@test irrep.l == x // 2 |> trunc
@test irrep.p in (-1, 1)
@test dim(irrep) == (2 * trunc(x // 2) + 1)
end

irreps = Irreps("4x1e + 6x2e + 12x2o")
@test Irreps(repr(irreps)) == irreps
end
Expand All @@ -45,6 +56,7 @@ using Test

@test 2 * Irreps("2x2e + 4x1o") == Irreps("2x2e + 4x1o + 2x2e + 4x1o")
@test Irreps("2x2e + 4x1o") * 2 == Irreps("2x2e + 4x1o + 2x2e + 4x1o")
@test Irreps("2x2e + 4x1o") * 2 |> simplify == Irreps("8x1o + 4x2e") # note the ordering
end

@testset "empty" begin
Expand All @@ -56,6 +68,13 @@ using Test
@test num_irreps(er) == 0
end

@testset "getitem" begin
irreps = Irreps("16x1e + 3e + 2e + 5o")
@test irreps[1] == MulIrrep(16, Irrep("1e"))
@test irreps[4] == MulIrrep(1, Irrep("5o"))
@test irreps[length(irreps)] == MulIrrep(1, Irrep("5o"))
end

@testset "cat" begin
irreps = Irreps("4x1e + 6x2e + 12x2o") + Irreps("1x1e + 2x2e + 12x4o")
@test length(irreps) == 6
Expand All @@ -71,6 +90,21 @@ using Test
@test num_irreps(irreps) == 4 + 6 + 12 + 1 + 2 + 12
end

@testset "ordering" begin
n_test = 100

last = nothing
for (i, irrep) in enumerate(Iterators.take(Irrep, n_test))
if !isnothing(last)
@test last < irrep
end
if i == n_test
break
end
last = irrep
end
end

@testset "contains" begin
@test Irrep("2e") Irreps("3x0e + 2x2e + 1x3o")
@test Irrep("2o") Irreps("3x0e + 2x2e + 1x3o")
Expand Down

0 comments on commit 7c0566a

Please sign in to comment.