diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index f96ff023..8c93c8e3 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -97,9 +97,8 @@ def __init__( self._max_treedepths: np.ndarray = np.zeros( self.runset.chains, dtype=int ) - self._chain_time: List[Dict[str, float]] = [] - # info from CSV header and initial and final comment blocks + # info from CSV initial comments and header config = self._validate_csv_files() self._metadata: InferenceMetadata = InferenceMetadata(config) if not self._is_fixed_param: @@ -241,14 +240,6 @@ def max_treedepths(self) -> Optional[np.ndarray]: """ return self._max_treedepths if not self._is_fixed_param else None - @property - def time(self) -> List[Dict[str, float]]: - """ - List of per-chain time info scraped from CSV file. - Each chain has dict with keys "warmup", "sampling", "total". - """ - return self._chain_time - def draws( self, *, inc_warmup: bool = False, concat_chains: bool = False ) -> np.ndarray: @@ -310,7 +301,6 @@ def _validate_csv_files(self) -> Dict[str, Any]: save_warmup=self._save_warmup, thin=self._thin, ) - self._chain_time.append(dzero['time']) # type: ignore if not self._is_fixed_param: self._divergences[i] = dzero['ct_divergences'] self._max_treedepths[i] = dzero['ct_max_treedepth'] @@ -323,7 +313,6 @@ def _validate_csv_files(self) -> Dict[str, Any]: save_warmup=self._save_warmup, thin=self._thin, ) - self._chain_time.append(drest['time']) # type: ignore for key in dzero: # check args that matter for parsing, plus name, version if ( diff --git a/cmdstanpy/utils/stancsv.py b/cmdstanpy/utils/stancsv.py index a50d9657..b7a3b21c 100644 --- a/cmdstanpy/utils/stancsv.py +++ b/cmdstanpy/utils/stancsv.py @@ -79,7 +79,6 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]: lineno = scan_warmup_iters(fd, dict, lineno) lineno = scan_hmc_params(fd, dict, lineno) lineno = scan_sampling_iters(fd, dict, lineno, is_fixed_param) - lineno = scan_time(fd, dict, lineno) except ValueError as e: raise ValueError("Error in reading csv file: " + path) from e return dict @@ -382,66 +381,6 @@ def scan_sampling_iters( return lineno -def scan_time(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int: - """ - Scan time information from the trailing comment lines in a Stan CSV file. - - # Elapsed Time: 0.001332 seconds (Warm-up) - # 0.000249 seconds (Sampling) - # 0.001581 seconds (Total) - - - It extracts the time values and saves them in the config_dict: key 'time', - value a dictionary with keys 'warmup', 'sampling', and 'total'. - Returns the updated line number after reading the time info. - - :param fd: Open file descriptor at comment row following all sample data. - :param config_dict: Dictionary to which the time info is added. - :param lineno: Current line number - """ - time = {} - keys = ['warmup', 'sampling', 'total'] - while True: - pos = fd.tell() - line = fd.readline() - if not line: - break - lineno += 1 - stripped = line.strip() - if not stripped.startswith('#'): - fd.seek(pos) - lineno -= 1 - break - content = stripped.lstrip('#').strip() - if not content: - continue - tokens = content.split() - if len(tokens) < 3: - raise ValueError(f"Invalid time at line {lineno}: {content}") - if 'Warm-up' in content: - key = 'warmup' - time_str = tokens[2] - elif 'Sampling' in content: - key = 'sampling' - time_str = tokens[0] - elif 'Total' in content: - key = 'total' - time_str = tokens[0] - else: - raise ValueError(f"Invalid time at line {lineno}: {content}") - try: - t = float(time_str) - except ValueError as e: - raise ValueError(f"Invalid time at line {lineno}: {content}") from e - time[key] = t - - if not all(key in time for key in keys): - raise ValueError(f"Invalid time, stopped at {lineno}") - - config_dict['time'] = time - return lineno - - def read_metric(path: str) -> List[int]: """ Read metric file in JSON or Rdump format. diff --git a/test/test_sample.py b/test/test_sample.py index 3c987ddc..944ce7a6 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -1714,12 +1714,6 @@ def test_metadata() -> None: assert fit.column_names == col_names assert fit.metric_type == 'diag_e' - assert len(fit.time) == 4 - for i in range(4): - assert 'warmup' in fit.time[i].keys() - assert 'sampling' in fit.time[i].keys() - assert 'total' in fit.time[i].keys() - assert fit.metadata.cmdstan_config['num_samples'] == 100 assert fit.metadata.cmdstan_config['thin'] == 1 assert fit.metadata.cmdstan_config['algorithm'] == 'hmc' diff --git a/test/test_utils.py b/test/test_utils.py index 27e5db53..397f3316 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -702,57 +702,3 @@ def test_munge_varnames() -> None: var = 'y.2.3:1.2:5:6' assert stancsv.munge_varname(var) == 'y[2,3].1[2].5.6' - - -def test_scan_time_normal() -> None: - csv_content = ( - "# Elapsed Time: 0.005 seconds (Warm-up)\n" - "# 0 seconds (Sampling)\n" - "# 0.005 seconds (Total)\n" - ) - fd = io.StringIO(csv_content) - config_dict = {} - start_line = 0 - final_line = stancsv.scan_time(fd, config_dict, start_line) - assert final_line == 3 - expected = {'warmup': 0.005, 'sampling': 0.0, 'total': 0.005} - assert config_dict.get('time') == expected - - -def test_scan_time_no_timing() -> None: - csv_content = ( - "# merrily we roll along\n" - "# roll along\n" - "# very merrily we roll along\n" - ) - fd = io.StringIO(csv_content) - config_dict = {} - start_line = 0 - with pytest.raises(ValueError, match="Invalid time"): - stancsv.scan_time(fd, config_dict, start_line) - - -def test_scan_time_invalid_value() -> None: - csv_content = ( - "# Elapsed Time: abc seconds (Warm-up)\n" - "# 0.200 seconds (Sampling)\n" - "# 0.300 seconds (Total)\n" - ) - fd = io.StringIO(csv_content) - config_dict = {} - start_line = 0 - with pytest.raises(ValueError, match="Invalid time"): - stancsv.scan_time(fd, config_dict, start_line) - - -def test_scan_time_invalid_string() -> None: - csv_content = ( - "# Elapsed Time: 0.22 seconds (foo)\n" - "# 0.200 seconds (Sampling)\n" - "# 0.300 seconds (Total)\n" - ) - fd = io.StringIO(csv_content) - config_dict = {} - start_line = 0 - with pytest.raises(ValueError, match="Invalid time"): - stancsv.scan_time(fd, config_dict, start_line)