From 7a89b15af9c3de808017ade2f879b910bfb775d3 Mon Sep 17 00:00:00 2001 From: daniel Date: Sat, 31 Jul 2021 10:45:35 +0200 Subject: [PATCH 1/7] output for correlation analysis --- flamel.py | 2 ++ uncorrelate/statistical_inefficiency_dhdl.py | 9 +++++++-- uncorrelate/statistical_inefficiency_dhdl_all.py | 9 +++++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/flamel.py b/flamel.py index c45fc3a..4f79fde 100755 --- a/flamel.py +++ b/flamel.py @@ -133,8 +133,10 @@ def main(): uncorrelator.set_u_nks(u_nks) if do_dhdl: + print("Uncorrelating dH/dl ...") dhdls = uncorrelator.uncorrelate(dhdls, args.equiltime) if do_u_nks: + print("Uncorrelating reduced potentials ...") u_nks = uncorrelator.uncorrelate(u_nks, args.equiltime) # Step 3: Estimate Free energy differences diff --git a/uncorrelate/statistical_inefficiency_dhdl.py b/uncorrelate/statistical_inefficiency_dhdl.py index 2a29dcc..15d2eca 100644 --- a/uncorrelate/statistical_inefficiency_dhdl.py +++ b/uncorrelate/statistical_inefficiency_dhdl.py @@ -50,11 +50,16 @@ def uncorrelate(self, dfs, lower): dl.append(dli) uncorrelated_dfs = [] - for dhdl_, l, df in zip(self.dhdls, dl, dfs): + print("Number of correlated and uncorrelated samples (Method=%s):\n\n%6s %12s %12s %12s\n" % ("dHdl", "State", "N", "N_k", "N/N_k")) + for idx, (dhdl_, l, df) in enumerate(zip(self.dhdls, dl, dfs)): ind = np.array(l, dtype=bool) ind = np.array(ind, dtype=int) dhdl_sum = dhdl_.dot(ind) - uncorrelated_dfs.append(alchemlyb.preprocessing.statistical_inefficiency(df, dhdl_sum, lower, conservative=False)) + uncorrelated_df = alchemlyb.preprocessing.statistical_inefficiency(df, dhdl_sum, lower, conservative=False) + N, N_k = len(df), len(uncorrelated_df) + g = N/N_k + print("%6s %12s %12s %12.2f" % (idx, N, N_k, g)) + uncorrelated_dfs.append(uncorrelated_df) return pandas.concat(uncorrelated_dfs) diff --git a/uncorrelate/statistical_inefficiency_dhdl_all.py b/uncorrelate/statistical_inefficiency_dhdl_all.py index 
fff38a5..2e41ea1 100644 --- a/uncorrelate/statistical_inefficiency_dhdl_all.py +++ b/uncorrelate/statistical_inefficiency_dhdl_all.py @@ -29,9 +29,14 @@ def uncorrelate(self, dfs, lower): """ uncorrelated_dfs = [] - for dhdl_, df in zip(self.dhdls, dfs): + print("Number of correlated and uncorrelated samples (Method=%s):\n\n%6s %12s %12s %12s\n" % ("dHdl (all)", "State", "N", "N_k", "N/N_k")) + for idx, (dhdl_, df) in enumerate(zip(self.dhdls, dfs)): dhdl_sum = dhdl_.sum(axis=1) - uncorrelated_dfs.append(alchemlyb.preprocessing.statistical_inefficiency(df, dhdl_sum, lower, conservative=False)) + uncorrelated_df = alchemlyb.preprocessing.statistical_inefficiency(df, dhdl_sum, lower, conservative=False) + N, N_k = len(df), len(uncorrelated_df) + g = N/N_k + print("%6s %12s %12s %12.2f" % (idx, N, N_k, g)) + uncorrelated_dfs.append(uncorrelated_df) return pandas.concat(uncorrelated_dfs) From 970bef2d8696ce228ddfe152fc11ca711050c22d Mon Sep 17 00:00:00 2001 From: daniel Date: Sat, 31 Jul 2021 10:56:40 +0200 Subject: [PATCH 2/7] skip uncorrelation if threshold is 0 --- flamel.py | 31 ++++++++++++------- uncorrelate/statistical_inefficiency_de.py | 2 +- uncorrelate/statistical_inefficiency_dhdl.py | 2 +- .../statistical_inefficiency_dhdl_all.py | 2 +- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/flamel.py b/flamel.py index 4f79fde..dd8565d 100755 --- a/flamel.py +++ b/flamel.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import argparse +import pandas as pd def get_available_plugin_ids(type): @@ -98,6 +99,7 @@ def main(): parser.add_argument('-q', '--suffix', dest='suffix', help='Suffix for datafile sets, i.e. 
\'xvg\' (default).', default='xvg') parser.add_argument('-e', dest='estimators', type=str, default=None, help="Comma separated Estimator methods") parser.add_argument('-n', '--uncorr', dest='uncorr', help='The observable to be used for the autocorrelation analysis; either \'dhdl_all\' (obtained as a sum over all energy components) or \'dhdl\' (obtained as a sum over those energy components that are changing; default) or \'dE\'. In the latter case the energy differences dE_{i,i+1} (dE_{i,i-1} for the last lambda) are used.', default='dhdl') + parser.add_argument('-i', '--uncorr_threshold', dest='uncorr_threshold', help='Proceed with correlated samples (N) if the number of uncorrelated samples (N_k) is found to be less than this number. If 0 is given, the time series analysis will not be performed at all. Default: 50.', default=50, type=int) parser.add_argument('-r', '--decimal', dest='decimal', help='The number of decimal places the free energies are to be reported with. No worries, this is for the text output only; the full-precision data will be stored in \'results.pickle\'. Default: 3.', default=3, type=int) parser.add_argument('-o', '--output', dest='output', type=str, default=None, help="Output methods") parser.add_argument('-a', '--software', dest='software', help='Package\'s name the data files come from: Gromacs, Sire, Desmond, or AMBER. 
Default: Gromacs.', default='Gromacs') @@ -127,17 +129,24 @@ def main(): u_nks = parser.get_u_nks() # Step 2: Uncorrelate the data - if uncorrelator.needs_dhdls: - uncorrelator.set_dhdls(dhdls) - if uncorrelator.needs_u_nks: - uncorrelator.set_u_nks(u_nks) - - if do_dhdl: - print("Uncorrelating dH/dl ...") - dhdls = uncorrelator.uncorrelate(dhdls, args.equiltime) - if do_u_nks: - print("Uncorrelating reduced potentials ...") - u_nks = uncorrelator.uncorrelate(u_nks, args.equiltime) + if args.uncorr_threshold > 0: + if uncorrelator.needs_dhdls: + uncorrelator.set_dhdls(dhdls) + if uncorrelator.needs_u_nks: + uncorrelator.set_u_nks(u_nks) + + if do_dhdl: + print("Uncorrelating dH/dl ...") + dhdls = uncorrelator.uncorrelate(dhdls, args.equiltime) + if do_u_nks: + print("Uncorrelating reduced potentials ...") + u_nks = uncorrelator.uncorrelate(u_nks, args.equiltime) + + # concat data for estimators + if u_nks is not None: + u_nks = pd.concat(u_nks) + if dhdls is not None: + dhdls = pd.concat(dhdls) # Step 3: Estimate Free energy differences for estimator in estimators: diff --git a/uncorrelate/statistical_inefficiency_de.py b/uncorrelate/statistical_inefficiency_de.py index b882403..19297e9 100644 --- a/uncorrelate/statistical_inefficiency_de.py +++ b/uncorrelate/statistical_inefficiency_de.py @@ -33,7 +33,7 @@ def uncorrelate(self, dfs, lower): statinefs.append(statinef) i += 1 - return pandas.concat(uncorrelated_dfs) + return uncorrelated_dfs def get_plugin(*args): diff --git a/uncorrelate/statistical_inefficiency_dhdl.py b/uncorrelate/statistical_inefficiency_dhdl.py index 15d2eca..05673f3 100644 --- a/uncorrelate/statistical_inefficiency_dhdl.py +++ b/uncorrelate/statistical_inefficiency_dhdl.py @@ -61,7 +61,7 @@ def uncorrelate(self, dfs, lower): print("%6s %12s %12s %12.2f" % (idx, N, N_k, g)) uncorrelated_dfs.append(uncorrelated_df) - return pandas.concat(uncorrelated_dfs) + return uncorrelated_dfs def get_plugin(*args): diff --git 
a/uncorrelate/statistical_inefficiency_dhdl_all.py b/uncorrelate/statistical_inefficiency_dhdl_all.py index 2e41ea1..b180802 100644 --- a/uncorrelate/statistical_inefficiency_dhdl_all.py +++ b/uncorrelate/statistical_inefficiency_dhdl_all.py @@ -38,7 +38,7 @@ def uncorrelate(self, dfs, lower): print("%6s %12s %12s %12.2f" % (idx, N, N_k, g)) uncorrelated_dfs.append(uncorrelated_df) - return pandas.concat(uncorrelated_dfs) + return uncorrelated_dfs def get_plugin(*args): From d7840cb61f9bcd6d9db800a744914b10ff5eb3c8 Mon Sep 17 00:00:00 2001 From: daniel Date: Sat, 31 Jul 2021 11:04:16 +0200 Subject: [PATCH 3/7] uncorrelation thresholds for dhdl correlators --- flamel.py | 2 +- uncorrelate/statistical_inefficiency_dhdl.py | 12 ++++++++++-- uncorrelate/statistical_inefficiency_dhdl_all.py | 12 ++++++++++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/flamel.py b/flamel.py index dd8565d..e4345f9 100755 --- a/flamel.py +++ b/flamel.py @@ -107,7 +107,7 @@ def main(): args = parser.parse_args() parser = load_plugin_by_name('parser', args.software, args.temperature, args.prefix, args.suffix) - uncorrelator = load_plugin_by_name('uncorrelate', args.uncorr) + uncorrelator = load_plugin_by_name('uncorrelate', args.uncorr, args.uncorr_threshold) outputs = load_plugins('output', argsplit(args.output)) estimators = load_plugins('estimator', argsplit(args.estimators)) diff --git a/uncorrelate/statistical_inefficiency_dhdl.py b/uncorrelate/statistical_inefficiency_dhdl.py index 05673f3..0827be1 100644 --- a/uncorrelate/statistical_inefficiency_dhdl.py +++ b/uncorrelate/statistical_inefficiency_dhdl.py @@ -11,6 +11,10 @@ class StatisticalInefficiencyDhdl: needs_u_nks = False dhdl = None + uncorr_threshold = None + + def __init__(self, uncorr_threshold): + self.uncorr_threshold = uncorr_threshold def set_dhdls(self, dhdls): """ @@ -59,7 +63,11 @@ def uncorrelate(self, dfs, lower): N, N_k = len(df), len(uncorrelated_df) g = N/N_k print("%6s %12s %12s %12.2f" % (idx, 
N, N_k, g)) - uncorrelated_dfs.append(uncorrelated_df) + if N_k < self.uncorr_threshold: + print("WARNING: Only %d uncorrelated samples found at lambda number %d; proceeding with analysis using correlated samples..." % (N_k, idx)) + uncorrelated_dfs.append(df) + else: + uncorrelated_dfs.append(uncorrelated_df) return uncorrelated_dfs @@ -70,4 +78,4 @@ def get_plugin(*args): :return: Statitical inefficiency uncorrelator """ - return StatisticalInefficiencyDhdl() + return StatisticalInefficiencyDhdl(*args) diff --git a/uncorrelate/statistical_inefficiency_dhdl_all.py b/uncorrelate/statistical_inefficiency_dhdl_all.py index b180802..445a04e 100644 --- a/uncorrelate/statistical_inefficiency_dhdl_all.py +++ b/uncorrelate/statistical_inefficiency_dhdl_all.py @@ -11,6 +11,10 @@ class StatisticalInefficiencyDhdlAll: needs_u_nks = False dhdl = None + uncorr_threshold = None + + def __init__(self, uncorr_threshold): + self.uncorr_threshold = uncorr_threshold def set_dhdls(self, dhdls): """ @@ -36,7 +40,11 @@ def uncorrelate(self, dfs, lower): N, N_k = len(df), len(uncorrelated_df) g = N/N_k print("%6s %12s %12s %12.2f" % (idx, N, N_k, g)) - uncorrelated_dfs.append(uncorrelated_df) + if N_k < self.uncorr_threshold: + print("WARNING: Only %d uncorrelated samples found at lambda number %d; proceeding with analysis using correlated samples..." 
% (N_k, idx)) + uncorrelated_dfs.append(df) + else: + uncorrelated_dfs.append(uncorrelated_df) return uncorrelated_dfs @@ -47,4 +55,4 @@ def get_plugin(*args): :return: Statitical inefficiency uncorrelator using a sum of all dhdls """ - return StatisticalInefficiencyDhdlAll() + return StatisticalInefficiencyDhdlAll(*args) From efe08845f241935b4aba518eb32e6cae0a5602e1 Mon Sep 17 00:00:00 2001 From: daniel Date: Sat, 31 Jul 2021 11:05:29 +0200 Subject: [PATCH 4/7] update readme --- README.md | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/README.md b/README.md index d11a92f..86e8c9f 100644 --- a/README.md +++ b/README.md @@ -117,10 +117,4 @@ Alchemical Analysis with the same input files: Coulomb: -41.067 +- 0.180 -41.022 +- 0.129 -41.096 +- 0.170 vdWaals: 11.912 +- 0.160 11.954 +- 0.111 12.022 +- 0.139 TOTAL: -29.154 +- 0.241 -29.067 +- 0.170 -29.074 +- 0.220 -``` - -# Planed features: -- **Output of statistical inefficiencies** -alchemical-analysis offers information about the statistical inefficiencies of the input datasets. -- **Uncorrelation threshold** -In alchemical-analysis it is possible to specify a threshold for the number of samples to keep in the uncorrelation process. 
+``` \ No newline at end of file From 3b4cbe920f023a9091b16b85af6f27d8956e9dfb Mon Sep 17 00:00:00 2001 From: daniel Date: Mon, 2 Aug 2021 08:02:38 +0200 Subject: [PATCH 5/7] python 3 string formatting --- uncorrelate/statistical_inefficiency_dhdl.py | 7 ++++--- uncorrelate/statistical_inefficiency_dhdl_all.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/uncorrelate/statistical_inefficiency_dhdl.py b/uncorrelate/statistical_inefficiency_dhdl.py index 0827be1..9c74ef8 100644 --- a/uncorrelate/statistical_inefficiency_dhdl.py +++ b/uncorrelate/statistical_inefficiency_dhdl.py @@ -54,7 +54,8 @@ def uncorrelate(self, dfs, lower): dl.append(dli) uncorrelated_dfs = [] - print("Number of correlated and uncorrelated samples (Method=%s):\n\n%6s %12s %12s %12s\n" % ("dHdl", "State", "N", "N_k", "N/N_k")) + print("Number of correlated and uncorrelated samples (Method=dHdl (all)):") + print("\n State N N_k N/N_k\n") for idx, (dhdl_, l, df) in enumerate(zip(self.dhdls, dl, dfs)): ind = np.array(l, dtype=bool) ind = np.array(ind, dtype=int) @@ -62,9 +63,9 @@ def uncorrelate(self, dfs, lower): uncorrelated_df = alchemlyb.preprocessing.statistical_inefficiency(df, dhdl_sum, lower, conservative=False) N, N_k = len(df), len(uncorrelated_df) g = N/N_k - print("%6s %12s %12s %12.2f" % (idx, N, N_k, g)) + print(f"{idx:>6} {N:>12} {N_k:>12} {g:>12.2f}") if N_k < self.uncorr_threshold: - print("WARNING: Only %d uncorrelated samples found at lambda number %d; proceeding with analysis using correlated samples..." 
% (N_k, idx)) + print(f"WARNING: Only {N_k} uncorrelated samples found at lambda number {idx}; proceeding with analysis using correlated samples...") uncorrelated_dfs.append(df) else: uncorrelated_dfs.append(uncorrelated_df) diff --git a/uncorrelate/statistical_inefficiency_dhdl_all.py b/uncorrelate/statistical_inefficiency_dhdl_all.py index 445a04e..2e31ec5 100644 --- a/uncorrelate/statistical_inefficiency_dhdl_all.py +++ b/uncorrelate/statistical_inefficiency_dhdl_all.py @@ -33,15 +33,16 @@ def uncorrelate(self, dfs, lower): """ uncorrelated_dfs = [] - print("Number of correlated and uncorrelated samples (Method=%s):\n\n%6s %12s %12s %12s\n" % ("dHdl (all)", "State", "N", "N_k", "N/N_k")) + print("Number of correlated and uncorrelated samples (Method=dHdl (all)):") + print("\n State N N_k N/N_k\n") for idx, (dhdl_, df) in enumerate(zip(self.dhdls, dfs)): dhdl_sum = dhdl_.sum(axis=1) uncorrelated_df = alchemlyb.preprocessing.statistical_inefficiency(df, dhdl_sum, lower, conservative=False) N, N_k = len(df), len(uncorrelated_df) g = N/N_k - print("%6s %12s %12s %12.2f" % (idx, N, N_k, g)) + print(f"{idx:>6} {N:>12} {N_k:>12} {g:>12.2f}") if N_k < self.uncorr_threshold: - print("WARNING: Only %d uncorrelated samples found at lambda number %d; proceeding with analysis using correlated samples..." 
% (N_k, idx)) + print(f"WARNING: Only {N_k} uncorrelated samples found at lambda number {idx}; proceeding with analysis using correlated samples...") uncorrelated_dfs.append(df) else: uncorrelated_dfs.append(uncorrelated_df) From 5c41cc299e449eec24fc70e187590c973012695f Mon Sep 17 00:00:00 2001 From: daniel Date: Mon, 2 Aug 2021 08:07:04 +0200 Subject: [PATCH 6/7] warnings use warnings.warn --- uncorrelate/statistical_inefficiency_dhdl.py | 3 ++- uncorrelate/statistical_inefficiency_dhdl_all.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/uncorrelate/statistical_inefficiency_dhdl.py b/uncorrelate/statistical_inefficiency_dhdl.py index 9c74ef8..24eefc9 100644 --- a/uncorrelate/statistical_inefficiency_dhdl.py +++ b/uncorrelate/statistical_inefficiency_dhdl.py @@ -1,6 +1,7 @@ import alchemlyb.preprocessing import pandas import numpy as np +from warnings import warn # Todo: Use interface here @@ -65,7 +66,7 @@ def uncorrelate(self, dfs, lower): g = N/N_k print(f"{idx:>6} {N:>12} {N_k:>12} {g:>12.2f}") if N_k < self.uncorr_threshold: - print(f"WARNING: Only {N_k} uncorrelated samples found at lambda number {idx}; proceeding with analysis using correlated samples...") + warn(f"Only {N_k} uncorrelated samples found at lambda number {idx}; proceeding with analysis using correlated samples...") uncorrelated_dfs.append(df) else: uncorrelated_dfs.append(uncorrelated_df) diff --git a/uncorrelate/statistical_inefficiency_dhdl_all.py b/uncorrelate/statistical_inefficiency_dhdl_all.py index 2e31ec5..e19c4e3 100644 --- a/uncorrelate/statistical_inefficiency_dhdl_all.py +++ b/uncorrelate/statistical_inefficiency_dhdl_all.py @@ -1,6 +1,7 @@ import alchemlyb.preprocessing import pandas import numpy as np +from warnings import warn # Todo: Use interface here @@ -42,7 +43,7 @@ def uncorrelate(self, dfs, lower): g = N/N_k print(f"{idx:>6} {N:>12} {N_k:>12} {g:>12.2f}") if N_k < self.uncorr_threshold: - print(f"WARNING: Only {N_k} uncorrelated samples found at 
lambda number {idx}; proceeding with analysis using correlated samples...") + warn(f"Only {N_k} uncorrelated samples found at lambda number {idx}; proceeding with analysis using correlated samples...") uncorrelated_dfs.append(df) else: uncorrelated_dfs.append(uncorrelated_df) From 639a5d51c513bda8c0a9f51cda6e1eb23cabd741 Mon Sep 17 00:00:00 2001 From: daniel Date: Mon, 2 Aug 2021 08:29:38 +0200 Subject: [PATCH 7/7] baseclass for statistical inefficiency --- uncorrelate/statistical_inefficiency.py | 22 ++++++++++++++++++ uncorrelate/statistical_inefficiency_dhdl.py | 23 ++++++------------- .../statistical_inefficiency_dhdl_all.py | 21 +++++------------ 3 files changed, 35 insertions(+), 31 deletions(-) create mode 100644 uncorrelate/statistical_inefficiency.py diff --git a/uncorrelate/statistical_inefficiency.py b/uncorrelate/statistical_inefficiency.py new file mode 100644 index 0000000..bb9b6a3 --- /dev/null +++ b/uncorrelate/statistical_inefficiency.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod +from warnings import warn + +class StatisticalInefficiency(ABC): + + uncorr_threshold = None + + def __init__(self, uncorr_threshold): + self.uncorr_threshold = uncorr_threshold + + def check_sample_size(self, idx, df, uncorrelated_df): + N, N_k = len(df), len(uncorrelated_df) + g = N/N_k + print(f"{idx:>6} {N:>12} {N_k:>12} {g:>12.2f}") + if N_k < self.uncorr_threshold: + warn(f"Only {N_k} uncorrelated samples found at lambda number {idx}; proceeding with analysis using correlated samples...") + return False + return True + + @abstractmethod + def uncorrelate(self, dfs, lower): + pass diff --git a/uncorrelate/statistical_inefficiency_dhdl.py b/uncorrelate/statistical_inefficiency_dhdl.py index 24eefc9..55ead0f 100644 --- a/uncorrelate/statistical_inefficiency_dhdl.py +++ b/uncorrelate/statistical_inefficiency_dhdl.py @@ -1,21 +1,16 @@ import alchemlyb.preprocessing import pandas import numpy as np -from warnings import warn +from 
uncorrelate.statistical_inefficiency import StatisticalInefficiency -# Todo: Use interface here -class StatisticalInefficiencyDhdl: +class StatisticalInefficiencyDhdl(StatisticalInefficiency): name = 'dhdl' needs_dhdls = True needs_u_nks = False dhdl = None - uncorr_threshold = None - - def __init__(self, uncorr_threshold): - self.uncorr_threshold = uncorr_threshold def set_dhdls(self, dhdls): """ @@ -55,22 +50,18 @@ def uncorrelate(self, dfs, lower): dl.append(dli) uncorrelated_dfs = [] - print("Number of correlated and uncorrelated samples (Method=dHdl (all)):") + print(f"Number of correlated and uncorrelated samples (Method={self.name}):") print("\n State N N_k N/N_k\n") for idx, (dhdl_, l, df) in enumerate(zip(self.dhdls, dl, dfs)): ind = np.array(l, dtype=bool) ind = np.array(ind, dtype=int) dhdl_sum = dhdl_.dot(ind) uncorrelated_df = alchemlyb.preprocessing.statistical_inefficiency(df, dhdl_sum, lower, conservative=False) - N, N_k = len(df), len(uncorrelated_df) - g = N/N_k - print(f"{idx:>6} {N:>12} {N_k:>12} {g:>12.2f}") - if N_k < self.uncorr_threshold: - warn(f"Only {N_k} uncorrelated samples found at lambda number {idx}; proceeding with analysis using correlated samples...") - uncorrelated_dfs.append(df) - else: + if self.check_sample_size(idx, df, uncorrelated_df): uncorrelated_dfs.append(uncorrelated_df) - + else: + uncorrelated_dfs.append(df) + return uncorrelated_dfs diff --git a/uncorrelate/statistical_inefficiency_dhdl_all.py b/uncorrelate/statistical_inefficiency_dhdl_all.py index e19c4e3..98c70b4 100644 --- a/uncorrelate/statistical_inefficiency_dhdl_all.py +++ b/uncorrelate/statistical_inefficiency_dhdl_all.py @@ -1,21 +1,16 @@ import alchemlyb.preprocessing import pandas import numpy as np -from warnings import warn +from uncorrelate.statistical_inefficiency import StatisticalInefficiency -# Todo: Use interface here -class StatisticalInefficiencyDhdlAll: +class StatisticalInefficiencyDhdlAll(StatisticalInefficiency): name = 'dhdl_all' 
needs_dhdls = True needs_u_nks = False dhdl = None - uncorr_threshold = None - - def __init__(self, uncorr_threshold): - self.uncorr_threshold = uncorr_threshold def set_dhdls(self, dhdls): """ @@ -34,19 +29,15 @@ def uncorrelate(self, dfs, lower): """ uncorrelated_dfs = [] - print("Number of correlated and uncorrelated samples (Method=dHdl (all)):") + print(f"Number of correlated and uncorrelated samples (Method={self.name}):") print("\n State N N_k N/N_k\n") for idx, (dhdl_, df) in enumerate(zip(self.dhdls, dfs)): dhdl_sum = dhdl_.sum(axis=1) uncorrelated_df = alchemlyb.preprocessing.statistical_inefficiency(df, dhdl_sum, lower, conservative=False) - N, N_k = len(df), len(uncorrelated_df) - g = N/N_k - print(f"{idx:>6} {N:>12} {N_k:>12} {g:>12.2f}") - if N_k < self.uncorr_threshold: - warn(f"Only {N_k} uncorrelated samples found at lambda number {idx}; proceeding with analysis using correlated samples...") - uncorrelated_dfs.append(df) - else: + if self.check_sample_size(idx, df, uncorrelated_df): uncorrelated_dfs.append(uncorrelated_df) + else: + uncorrelated_dfs.append(df) return uncorrelated_dfs