Skip to content

Commit

Permalink
Modified the parameters used to call the Gaussian and exponential HMC…
Browse files Browse the repository at this point in the history
… walks in bindings. Added tests to check they work as expected
  • Loading branch information
guillexm committed Jul 24, 2023
1 parent 81a903d commit 40ab1f9
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
7 changes: 3 additions & 4 deletions dingo/bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ double HPolytopeCPP::apply_sampling(int walk_len,
starting_point = inner_point2;
std::list<Point> rand_points;
NT variance = 1.0;
NT a = NT(1)/(NT(2)*variance);
int dim = 50;
Point c(dim);
c = GetDirection<Point>::apply(dim, rng, false);

if (method == 1) { // cdhr
uniform_sampling<CDHRWalk>(rand_points, HP, rng, walk_len, number_of_points,
Expand All @@ -133,10 +129,13 @@ double HPolytopeCPP::apply_sampling(int walk_len,
starting_point, number_of_points_to_burn);
}
else if (method == 8) { // gaussian sampling with gaussian HMC exact walk {
NT a = NT(1)/(NT(2)*variance);
gaussian_sampling<GaussianHamiltonianMonteCarloExactWalk>(rand_points, HP, rng, walk_len, number_of_points, a,
starting_point, number_of_points_to_burn);
}
else if (method == 9) { // exponential sampling with exponential HMC exact walk {
Point c(d);
c = GetDirection<Point>::apply(d, rng, false);
exponential_sampling<ExponentialHamiltonianMonteCarloExactWalk>(rand_points, HP, rng, walk_len, number_of_points, c, variance,
starting_point, number_of_points_to_burn);
}
Expand Down
75 changes: 75 additions & 0 deletions tests/sampling_no_multiphase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# dingo : a python library for metabolic networks sampling and analysis
# dingo is part of GeomScale project

# Copyright (c) 2022 Apostolos Chalkis
# Copyright (c) 2022 Vissarion Fisikopoulos
# Copyright (c) 2022 Haris Zafeiropoulos

# Licensed under GNU LGPL.3, see LICENCE file

import unittest
import os
from dingo import MetabolicNetwork, PolytopeSampler


class TestSampling(unittest.TestCase):

def test_sample_json(self):

input_file_json = os.getcwd() + "/ext_data/e_coli_core.json"
model = MetabolicNetwork.from_json( input_file_json )
sampler = PolytopeSampler(model)

#gaussian hmc sampling
steady_states = sampler.generate_steady_states_no_multiphase(method = 'gaussian_hmc_walk')

self.assertTrue( steady_states.shape[0] == 95 )
self.assertTrue( abs( steady_states[12].mean() - 2.504 ) < 1e-03 )

#exponential hmc sampling
steady_states = sampler.generate_steady_states_no_multiphase(method = 'exponential_hmc_walk')

self.assertTrue( steady_states.shape[0] == 95 )
self.assertTrue( abs( steady_states[12].mean() - 2.504 ) < 1e-03 )

def test_sample_mat(self):

input_file_mat = os.getcwd() + "/ext_data/e_coli_core.mat"
model = MetabolicNetwork.from_mat(input_file_mat)
sampler = PolytopeSampler(model)

#gaussian hmc sampling
steady_states = sampler.generate_steady_states_no_multiphase(method = 'gaussian_hmc_walk')

self.assertTrue( steady_states.shape[0] == 95 )
self.assertTrue( abs( steady_states[12].mean() - 2.504 ) < 1e-03 )

#exponential hmc sampling
steady_states = sampler.generate_steady_states_no_multiphase(method = 'exponential_hmc_walk')

self.assertTrue( steady_states.shape[0] == 95 )
self.assertTrue( abs( steady_states[12].mean() - 2.504 ) < 1e-03 )


def test_sample_sbml(self):

input_file_sbml = os.getcwd() + "/ext_data/e_coli_core.xml"
model = MetabolicNetwork.from_sbml( input_file_sbml )
sampler = PolytopeSampler(model)

#gaussian hmc sampling
steady_states = sampler.generate_steady_states_no_multiphase(method = 'gaussian_hmc_walk')

self.assertTrue( steady_states.shape[0] == 95 )
self.assertTrue( abs( steady_states[12].mean() - 2.504 ) < 1e-03 )

#exponential hmc sampling
steady_states = sampler.generate_steady_states_no_multiphase(method = 'exponential_hmc_walk')

self.assertTrue( steady_states.shape[0] == 95 )
self.assertTrue( abs( steady_states[12].mean() - 2.504 ) < 1e-03 )



if __name__ == "__main__":
unittest.main()

0 comments on commit 40ab1f9

Please sign in to comment.