Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifying different bins for visualization and computation. #190

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dolma"
version = "1.0.12"
version = "1.0.13"
description = "Data filters"
license = { text = "Apache-2.0" }
readme = "README.md"
Expand Down
20 changes: 15 additions & 5 deletions python/dolma/cli/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@
from dolma.core.paths import glob_path


@dataclass
class BinsConfig:
    """Histogram bin settings for the analyzer.

    Separates the (fine-grained) bin count used while computing
    histograms from the (coarse) bin count used when rendering them.
    """

    # Fine-grained bin count applied during histogram computation.
    compute: int = field(default=1_000, help="Number of bins to use to compute the histograms.")
    # Coarse bin count applied when the histograms are displayed.
    visualization: int = field(default=10, help="Number of bins to use when visualizing the histograms.")


@dataclass
class AnalyzerConfig:
attributes: List[str] = field(
Expand All @@ -22,10 +34,7 @@ class AnalyzerConfig:
"If not provided, the report will be printed to stdout."
),
)
bins: int = field(
default=1_000,
help="Number of bins to use for the histograms.",
)
bins: BinsConfig = field(default=BinsConfig(), help="Configuration for the bins to use for the histograms.")
processes: int = field(
default=1,
help="Number of parallel processes to use.",
Expand Down Expand Up @@ -80,7 +89,8 @@ def run(cls, parsed_config: AnalyzerConfig):
metadata_path=work_dirs.input,
debug=parsed_config.debug,
seed=parsed_config.seed,
num_bins=parsed_config.bins,
compute_bins=parsed_config.bins.compute,
visualize_bins=parsed_config.bins.visualization,
num_processes=parsed_config.processes,
name_regex=parsed_config.regex,
show_total=parsed_config.total,
Expand Down
29 changes: 18 additions & 11 deletions python/dolma/cli/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class MixerConfig:
default=False,
help="If true, only print the configuration and exit without running the mixer.",
)
skip_checks: bool = field(
default=False,
help="If true, skip checks on paths (e.g. validation, globbing). Useful in case many paths are being evaluated.",
)


class MixerCli(BaseCli):
Expand Down Expand Up @@ -141,19 +145,22 @@ def run(cls, parsed_config: MixerConfig):
if "span_replacement" not in stream_config_dict and "filter" not in stream_config_dict:
raise DolmaConfigError("Either `filter` or `span_replacement` must be specified")

# perform some path validation to make sure we don't call the mixer with invalid config
total_matching_documents = 0
for document in stream_config.documents:
if not parsed_config.skip_checks:
# perform some path validation to make sure we don't call the mixer with invalid config
total_matching_documents = 0
for document in stream_config.documents:

current_matching_documents = sum(1 for _ in glob_path(document))
if current_matching_documents == 0:
# only raise a warning if no documents are found for a single path
logger.warning("No documents found for path %s", document)
total_matching_documents += current_matching_documents
current_matching_documents = sum(1 for _ in glob_path(document))
if current_matching_documents == 0:
# only raise a warning if no documents are found for a single path
logger.warning("No documents found for path %s", document)
total_matching_documents += current_matching_documents

if total_matching_documents == 0:
# but raise an error if no documents are found for all paths
raise DolmaConfigError(f"No documents found for the paths for {stream_config.name} config.")
if total_matching_documents == 0:
# but raise an error if no documents are found for all paths
raise DolmaConfigError(
f"No documents found for the paths for {stream_config.name} config."
)

# populate the stream config dict
stream_config_dict["name"] = stream_config.name
Expand Down
12 changes: 7 additions & 5 deletions python/dolma/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def create_and_run_analyzer(
report: Optional[str] = None,
debug: bool = False,
seed: int = 0,
num_bins: int = 1000,
compute_bins: int = 1_000,
visualize_bins: int = 10,
num_processes: int = 1,
name_regex: Optional[str] = None,
show_total: bool = False,
Expand All @@ -300,7 +301,8 @@ def create_and_run_analyzer(
report (Optional[str], optional): Path to the report directory. Defaults to None.
debug (bool, optional): Enable debug mode. Defaults to False.
seed (int, optional): Seed value for randomization. Defaults to 0.
num_bins (int, optional): Number of bins for analysis. Defaults to 1000.
compute_bins (int, optional): Number of bins for analysis. Defaults to 1_000.
visualize_bins (int, optional): Number of bins for visualization. Defaults to 10.
num_processes (int, optional): Number of processes to use for analysis. Defaults to 1.
name_regex (Optional[str], optional): Regular expression for filtering attribute names. Defaults to None.
show_total (bool, optional): Show total summary. Defaults to False.
Expand Down Expand Up @@ -328,8 +330,8 @@ def create_and_run_analyzer(
retries_on_error=0,
num_processes=num_processes,
)
analyzer(num_bins=num_bins, name_regex=name_regex)
analyzer(num_bins=compute_bins, name_regex=name_regex)

summaries = aggregate_summaries(summaries_path=summaries_path, num_bins=num_bins)
visualize_summaries(summaries=summaries, show_total=show_total)
summaries = aggregate_summaries(summaries_path=summaries_path, num_bins=compute_bins)
visualize_summaries(summaries=summaries, show_total=show_total, num_viz_bins=visualize_bins)
write_output(summaries=summaries, report=report)
Loading