Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make compatible with AtomsBase 0.4.x #51

Merged
merged 6 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
name = "ExtXYZ"
uuid = "352459e4-ddd7-4360-8937-99dcb397b478"
authors = ["James Kermode <[email protected]> and contributors"]
version = "0.1.15-DEV"
version = "0.2.0-dev"

[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
AtomsBaseTesting = "ed7c10db-df7e-4efa-a7be-4f4190f7f227"
PeriodicTable = "7b2266bf-644c-5ea3-82d8-af4bbd25a884"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
UnitfulAtomic = "a7773ee8-282e-5fa2-be4e-bd808c38a91a"
extxyz_jll = "6ecdc6fc-93a8-5528-aee3-ac7ae1c60be7"

[compat]
AtomsBase = "0.3"
AtomsBase = "0.4"
PeriodicTable = "1"
StaticArrays = "1.5"
Unitful = "1"
UnitfulAtomic = "1"
julia = "1"
extxyz_jll = "0.1.3"
julia = "1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
117 changes: 65 additions & 52 deletions src/atoms.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using AtomsBase
using Unitful
using UnitfulAtomic
using StaticArrays

import AtomsBase: AbstractSystem
export Atoms

const D = 3 # TODO generalise to arbitrary spatial dimensions
Expand All @@ -23,34 +25,38 @@ end

function Atoms(system::AbstractSystem{D})
n_atoms = length(system)
s = species(system, :)
atomic_symbols = [Symbol(element(atomic_number(at)).symbol) for at in system]
if atomic_symbols != atomic_symbol(system)
atomic_numbers = atomic_number.(s)
if atomic_symbols != Symbol.(s)
@warn("Mismatch between atomic numbers and atomic symbols, which is not supported " *
"in ExtXYZ. Atomic numbers take preference.")
end
atom_data = Dict{Symbol,Any}(
:atomic_symbol => atomic_symbols,
:atomic_number => atomic_number(system),
:atomic_mass => atomic_mass(system)
:atomic_number => Int.(atomic_number(system, :)), # gets messy if not Int
:species => s,
:mass => mass(system, :)
)
atom_data[:position] = map(1:n_atoms) do at
pos = zeros(3)u"Å"
pos[1:D] = position(system, at)
pos
SVector{D, eltype(pos)}(pos) # AtomsBase 0.4 requires SVector
end
atom_data[:velocity] = map(1:n_atoms) do at
vel = zeros(3) * uVelocity
if !ismissing(velocity(system)) && !ismissing(velocity(system, at))
if !ismissing(velocity(system, :)) && !ismissing(velocity(system, at))
vel[1:D] = velocity(system, at)
end
vel
SVector{D, eltype(vel)}(vel) # AtomsBase 0.4 requires SVector
end

for k in atomkeys(system)
if k in (:atomic_symbol, :atomic_number, :atomic_mass, :velocity, :position)
if k in (:species, :atomic_symbol, :atomic_number, :mass, :velocity, :position)
continue # Already done
end
atoms_base_keys = (:charge, :covalent_radius, :vdw_radius,
# atomic_mass is deprecated for but is sometimes still used
atoms_base_keys = (:charge, :atomic_mass, :covalent_radius, :vdw_radius,
:magnetic_moment, :pseudopotential)
v = system[1, k]
if k in atoms_base_keys || v isa ExtxyzType || v isa AbstractVector{<: ExtxyzType}
Expand All @@ -65,20 +71,14 @@ function Atoms(system::AbstractSystem{D})
end
end

box = map(1:3) do i
v = zeros(3)u"Å"
i ≤ D && (v[1:D] = bounding_box(system)[i])
v
end
system_data = Dict{Symbol,Any}(
:bounding_box => box,
:boundary_conditions => boundary_conditions(system)
:bounding_box => bounding_box(system),
:periodicity => periodicity(system)
)

# Extract extra system properties
system_data = Dict{Symbol,Any}()
for (k, v) in pairs(system)
atoms_base_keys = (:charge, :multiplicity, :boundary_conditions, :bounding_box)
atoms_base_keys = (:charge, :multiplicity, :periodicity, :bounding_box)
if k in atoms_base_keys || v isa ExtxyzType || v isa AbstractArray{<: ExtxyzType}
# These are either Unitful quantities, which are uniformly supported
# across all of AtomsBase or the value has a type that Extxyz can write
Expand All @@ -104,24 +104,29 @@ function Atoms(dict::Dict{String, Any})
elseif haskey(arrays, "species")
Z = [element(Symbol(spec)).number for spec in arrays["species"]]
else
error("Cannot determine atomic numbers. Either 'Z' or 'species' must " *
error("Cannot determine atomic numbers. Either 'Z' or 'S' must " *
"be present in arrays")
end
@assert length(Z) == dict["N_atoms"]

atomic_symbols = [Symbol(element(num).symbol) for num in Z]
atom_data = Dict{Symbol, Any}(
:position => collect(eachcol(arrays["pos"]))u"Å",
:atomic_number => Z,
:atomic_symbol => atomic_symbols,
:species => AtomsBase.ChemicalSpecies.(atomic_symbols)
)
if haskey(arrays, "species")
atom_data[:atomic_symbol] = Symbol.(arrays["species"])
else
atom_data[:atomic_symbol] = [Symbol(element(num).symbol) for num in Z]
end
# TODO; Instead of the following, should there be a consistency check
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this would be preferable

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I filed an issue

# between S and Z?
# if haskey(arrays, "species")
# atom_data[:atomic_symbol] = Symbol.(arrays["species"])
# else
# atom_data[:atomic_symbol] = [Symbol(element(num).symbol) for num in Z]
# end
if haskey(arrays, "mass")
atom_data[:atomic_mass] = arrays["mass"]u"u"
atom_data[:mass] = arrays["mass"]u"u"
else
atom_data[:atomic_mass] = [element(num).atomic_mass for num in Z]
atom_data[:mass] = [element(num).atomic_mass for num in Z]
end
if haskey(arrays, "velocities")
atom_data[:velocity] = collect(eachcol(arrays["velocities"])) * uVelocity
Expand All @@ -130,7 +135,7 @@ function Atoms(dict::Dict{String, Any})
end

for key in keys(arrays)
key in ("mass", "species", "Z", "pos", "velocities") && continue # Already done
key in ("mass", "species", "Z", "atomic_symbol", "pos", "velocities") && continue # Already done
if key in ("vdw_radius", "covalent_radius") # Add length unit
atom_data[Symbol(key)] = arrays[key] * u"Å"
elseif key in ("charge", ) # Add charge unit
Expand All @@ -146,16 +151,17 @@ function Atoms(dict::Dict{String, Any})
if haskey(dict, "cell")
system_data[:bounding_box] = collect(eachrow(dict["cell"]))u"Å"
if haskey(dict, "pbc")
system_data[:boundary_conditions] = [p ? Periodic() : DirichletZero()
for p in dict["pbc"]]
system_data[:periodicity] = tuple(dict["pbc"]...)
else
@warn "'pbc' not contained in dict. Defaulting to all-periodic boundary. "
system_data[:boundary_conditions] = fill(Periodic(), 3)
system_data[:periodicity] = (true, true, true)
end
else # Infinite system
haskey(dict, "pbc") && @warn "'pbc' ignored since no 'cell' entry found in dict."
system_data[:boundary_conditions] = fill(DirichletZero(), 3)
system_data[:bounding_box] = infinite_box(3)
system_data[:periodicity] = (false, false, false)
system_data[:bounding_box] = ( SVector(Inf, 0.0, 0.0) * u"Å",
SVector(0.0, Inf, 0.0) * u"Å",
SVector(0.0, 0.0, Inf) * u"Å" )
end

for key in keys(info)
Expand All @@ -168,6 +174,7 @@ function Atoms(dict::Dict{String, Any})

Atoms(NamedTuple(atom_data), NamedTuple(system_data))
end

read_dict(dict::Dict{String,Any}) = Atoms(dict)

function write_dict(atoms::Atoms)
Expand All @@ -179,8 +186,8 @@ function write_dict(atoms::Atoms)
@warn("Mismatch between atomic numbers and atomic symbols, which is not supported " *
"in ExtXYZ. Atomic numbers take preference.")
end
if atoms.atom_data.atomic_mass != [element(Z).atomic_mass for Z in arrays["Z"]]
arrays["mass"] = ustrip.(u"u", atoms.atom_data.atomic_mass)
if atoms.atom_data.mass != [element(Z).atomic_mass for Z in arrays["Z"]]
arrays["mass"] = ustrip.(u"u", atoms.atom_data.mass)
end

arrays["velocities"] = zeros(D, length(atoms))
Expand All @@ -193,7 +200,8 @@ function write_dict(atoms::Atoms)
end

for (k, v) in pairs(atoms.atom_data)
k in (:atomic_mass, :atomic_symbol, :atomic_number, :position, :velocity) && continue
k in (:mass, :atomic_mass, :atomic_symbol, :atomic_number, :position,
:velocity, :species) && continue
if k in (:vdw_radius, :covalent_radius) # Remove length unit
arrays[string(k)] = ustrip.(u"Å", v)
elseif k in (:charge, )
Expand All @@ -207,10 +215,7 @@ function write_dict(atoms::Atoms)
end
end

pbc = zeros(Bool, D)
for (i, bc) in enumerate(atoms.system_data.boundary_conditions)
pbc[i] = bc isa Periodic
end
pbc = atoms.system_data.periodicity
cell = zeros(D, D)
for (i, bvector) in enumerate(atoms.system_data.bounding_box)
cell[i, :] = ustrip.(u"Å", bvector)
Expand All @@ -219,7 +224,7 @@ function write_dict(atoms::Atoms)
# Deal with other system keys
info = Dict{String,Any}()
for (k, v) in pairs(atoms.system_data)
k in (:boundary_conditions, :bounding_box) && continue # Already dealt with
k in (:periodicity, :bounding_box) && continue # Already dealt with
if k in (:charge, )
info[string(k)] = ustrip(u"e_au", atoms.system_data[k])
elseif v isa ExtxyzType
Expand All @@ -236,7 +241,7 @@ function write_dict(atoms::Atoms)
end
end
dict = Dict("N_atoms" => length(atoms),
"pbc" => pbc,
"pbc" => [pbc...],
"info" => info,
"arrays" => arrays)
all(cell .!= Inf) && (dict["cell"] = cell) # only write cell if its finite
Expand All @@ -249,30 +254,38 @@ write_dict(system::AbstractSystem{D}) = write_dict(Atoms(system))
Base.length(sys::Atoms) = length(sys.atom_data.position)
Base.size(sys::Atoms) = (length(sys), )
AtomsBase.bounding_box(sys::Atoms) = sys.system_data.bounding_box
AtomsBase.boundary_conditions(sys::Atoms) = sys.system_data.boundary_conditions
AtomsBase.periodicity(sys::Atoms) = sys.system_data.periodicity

# AtomsBase now requires a cell object to be returned instead of bounding_box
# and boundary conditions. But this can just be constructed on the fly.
AtomsBase.cell(sys::Atoms) = AtomsBase.PeriodicCell(;
cell_vectors = sys.system_data.bounding_box,
periodicity = sys.system_data.periodicity )

AtomsBase.species_type(::FS) where {FS <: Atoms} = AtomView{FS}
Base.getindex(sys::Atoms, x::Symbol) = getindex(sys.system_data, x)
Base.haskey(sys::Atoms, x::Symbol) = haskey(sys.system_data, x)
Base.keys(sys::Atoms) = keys(sys.system_data)

Base.getindex(sys::Atoms, i::Integer) = AtomView(sys, i)
Base.getindex(sys::Atoms, i::Integer, x::Symbol) = getindex(sys.atom_data, x)[i]
Base.getindex(sys::Atoms, i::AbstractVector{<: Integer}, x::Symbol) = getindex(sys.atom_data, x)[i]
Base.getindex(sys::Atoms, ::Colon, x::Symbol) = getindex(sys.atom_data, x)

AtomsBase.atomkeys(sys::Atoms) = keys(sys.atom_data)
AtomsBase.hasatomkey(sys::Atoms, x::Symbol) = haskey(sys.atom_data, x)

AtomsBase.position(s::Atoms) = Base.getindex(s, :, :position)
AtomsBase.position(s::Atoms, i::Integer) = Base.getindex(s, i, :position)
AtomsBase.velocity(s::Atoms) = Base.getindex(s, :, :velocity)
AtomsBase.velocity(s::Atoms, i::Integer) = Base.getindex(s, i, :velocity)
AtomsBase.atomic_mass(s::Atoms) = Base.getindex(s, :, :atomic_mass)
AtomsBase.atomic_mass(s::Atoms, i::Integer) = Base.getindex(s, i, :atomic_mass)
AtomsBase.atomic_symbol(s::Atoms) = Base.getindex(s, :, :atomic_symbol)
AtomsBase.atomic_symbol(s::Atoms, i::Integer) = Base.getindex(s, i, :atomic_symbol)
AtomsBase.atomic_number(s::Atoms) = Base.getindex(s, :, :atomic_number)
AtomsBase.atomic_number(s::Atoms, i::Integer) = Base.getindex(s, i, :atomic_number)
const _IDX = Union{Colon, Integer, AbstractArray{<: Integer}}
AtomsBase.position(s::Atoms, i::_IDX) = getindex(s, i, :position)
AtomsBase.velocity(s::Atoms, i::_IDX) = getindex(s, i, :velocity)
AtomsBase.mass(s::Atoms, i::_IDX) = getindex(s, i, :mass)
AtomsBase.atomic_symbol(s::Atoms, i::_IDX) = getindex(s, i, :atomic_symbol)
AtomsBase.atomic_number(s::Atoms, i::_IDX) = getindex(s, i, :atomic_number)

# AtomsBase now requires the `species` function to be implemented. Since
# ExtXYZ requires that atoms are uniquely identified by their atomic number, we
# will use the atomic number as the species identifier.
AtomsBase.species(s::Atoms, i::_IDX) =
AtomsBase.ChemicalSpecies.(AtomsBase.atomic_symbol(s, i))

# --------- FileIO compatible interface (hence not exported)

Expand Down
2 changes: 1 addition & 1 deletion src/fileio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ function write_frame_dicts(fp::Ptr{Cvoid}, nat, info, arrays; verbose=false)
nat = Cint(nat)
cinfo = convert(Ptr{DictEntry}, info; transpose_arrays=true)

# ensure "species" goes in column 1 and "pos" goes in column 2
# ensure "species" (symbol!) goes in column 1 and "pos" goes in column 2
ordered_keys = collect(keys(arrays))
species_idx = findfirst(isequal("species"), ordered_keys)
ordered_keys[1], ordered_keys[species_idx] = ordered_keys[species_idx], ordered_keys[1]
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[compat]
AtomsBaseTesting = "0.1"
AtomsBaseTesting = "0.2"

[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
Expand Down
Loading
Loading