Skip to content

Commit

Permalink
Merge pull request #95 from TuringLang/dw/threadid
Browse files Browse the repository at this point in the history
Remove use of `threadid`
  • Loading branch information
cpfiffer authored Feb 21, 2022
2 parents fe972e8 + bb7ced2 commit 24f88b1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 33 deletions.
15 changes: 3 additions & 12 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.0'
- '1.3'
- '1'
- nightly
os:
Expand All @@ -31,7 +31,7 @@ jobs:
arch: x86
- os: macOS-latest
arch: x86
- version: '1.0'
- version: '1.3'
num_threads: 2
include:
- version: '1'
Expand All @@ -45,16 +45,7 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
env:
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "3.2.1"
version = "3.2.2"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand All @@ -25,7 +25,7 @@ ProgressLogging = "0.1"
StatsBase = "0.32, 0.33"
TerminalLoggers = "0.1"
Transducers = "0.4.30"
julia = "1"
julia = "1.3"

[extras]
Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1"
Expand Down
43 changes: 24 additions & 19 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,15 @@ function mcmcsample(
end

# Copy the random number generator, model, and sample for each thread
# NOTE: As of May 17, 2020, this relies on Julia's thread scheduling functionality
# that distributes a for loop into equal-sized blocks and allocates them
# to each thread. If this changes, we may need to rethink things here.
nchunks = min(nchains, Threads.nthreads())
chunksize = cld(nchains, nchunks)
interval = 1:min(nchains, Threads.nthreads())
rngs = [deepcopy(rng) for _ in interval]
models = [deepcopy(model) for _ in interval]
samplers = [deepcopy(sampler) for _ in interval]

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
# Create a seed for each chunk using the provided random number generator.
seeds = rand(rng, UInt, nchunks)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -340,20 +339,26 @@ function mcmcsample(

Distributed.@async begin
try
Threads.@threads for i in 1:nchains
# Obtain the ID of the current thread.
id = Threads.threadid()

# Seed the thread-specific random number generator with the pre-made seed.
subrng = rngs[id]
Random.seed!(subrng, seeds[i])

# Sample a chain and save it to the vector.
chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N;
progress = false, kwargs...)

# Update the progress bar.
progress && put!(channel, true)
Distributed.@sync for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers)
Threads.@spawn begin
# Seed the chunk-specific random number generator with the pre-made seed.
Random.seed!(_rng, seed)

chainidxs = if i == nchunks
((i - 1) * chunksize + 1):nchains
else
((i - 1) * chunksize + 1):(i * chunksize)
end

for chainidx in chainidxs
# Sample a chain and save it to the vector.
chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N;
progress = false, kwargs...)

# Update the progress bar.
progress && put!(channel, true)
end
end
end
finally
# Stop updating the progress bar.
Expand Down

2 comments on commit 24f88b1

@cpfiffer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/55133

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.2.2 -m "<description of version>" 24f88b1dac45effa78f5f3eedc89767dbc221165
git push origin v3.2.2

Please sign in to comment.