Commit 6eb427b: Add LineStatsCC

Muennighoff authored Nov 2, 2023
1 parent 90a07ac

Showing 1 changed file with 115 additions and 10 deletions: scripts/dolma_stats.py
@@ -541,24 +541,19 @@ def process_single(
"gopher_count": 0,
"gopher_length": 0,
"gopher_matches": 0,
"gopher_spans": [],
"decontamination_count": 0,
"decontamination_length": 0,
"decontamination_matches": 0,
"decontamination_spans": [],
"dedupe_paragraphs_count": 0,
"dedupe_paragraphs_length": 0,
"dedupe_paragraphs_matches": 0,
"dedupe_paragraphs_spans": [],
"hatespeech_nsfw_count": 0,
"hatespeech_nsfw_length": 0,
"hatespeech_nsfw_matches": 0,
"hatespeech_nsfw_spans": [],
"pii_count": 0,
"pii_length": 0,
"pii_matches_le_5": 0,
"pii_matches_gt_5": 0,
"pii_spans": [],
}
documents = 0
interval = 10_000
@@ -585,14 +580,12 @@ def process_single(
stats["gopher_count"] += len(gopher_removal)
stats["gopher_length"] += sum(s[1] - s[0] for s in gopher_removal)
stats["gopher_matches"] += 1 if gopher_removal else 0
stats["gopher_spans"] = gopher_removal

# Deduplication stats
decontamination_removal = attrs.get("bff_duplicate_paragraph_spans_decontamination", [])
stats["decontamination_count"] += len(decontamination_removal)
stats["decontamination_length"] += sum(s[1] - s[0] for s in decontamination_removal)
stats["decontamination_matches"] += 1 if decontamination_removal else 0
stats["decontamination_spans"] = decontamination_removal

# jigsaw stats
jigsaw_match: List[Tuple[int, int, float]] = []
@@ -611,7 +604,6 @@ def process_single(
stats["hatespeech_nsfw_count"] += len(jigsaw_match)
stats["hatespeech_nsfw_length"] += sum(s[1] - s[0] for s in jigsaw_match)
stats["hatespeech_nsfw_matches"] += 1 if jigsaw_match else 0
stats["hatespeech_nsfw_spans"] = jigsaw_match

# PII stats
pii_removal = (
@@ -623,14 +615,12 @@ def process_single(
stats["pii_length"] += sum(s[1] - s[0] for s in pii_removal)
stats["pii_matches_le_5"] += 1 if 0 < len(pii_removal) <= 5 else 0
stats["pii_matches_gt_5"] += 1 if len(pii_removal) > 5 else 0
stats["pii_spans"] = pii_removal

# Duplicates stats
dups = [p for p in attrs.get("bff_duplicate_paragraph_spans", []) if p[1] - p[0] > 0]
stats["dedupe_paragraphs_count"] += len(dups)
stats["dedupe_paragraphs_length"] += sum(s[1] - s[0] for s in dups)
stats["dedupe_paragraphs_matches"] += 1 if dups else 0
stats["dedupe_paragraphs_spans"] = dups

documents += 1

@@ -649,6 +639,121 @@ class v15_cc_c4_cleaned(cc_v1_c4_cleaned):
stats = "s3://ai2-llm/stats/olmo-mix/v15/cc/v1_c4_cleaned/**/*.gz"
decontamination_key: str = 'perplexity_suite_v3_option2'

@Registry.add
class LineStatsCC(cc_v1_c4_cleaned):
# Selection of documents:
# import random; print([random.randint(0, 1334) for _ in range(10)])
documents = [
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0700.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0724.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0788.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-1286.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0600.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0752.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0239.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-1270.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0786.json.gz",
"s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0857.json.gz",
]
stats = [
"./cc_en_head-0700-stats.json.gz",
"./cc_en_head-0724-stats.json.gz",
"./cc_en_head-0788-stats.json.gz",
"./cc_en_head-1286-stats.json.gz",
"./cc_en_head-0600-stats.json.gz",
"./cc_en_head-0752-stats.json.gz",
"./cc_en_head-0239-stats.json.gz",
"./cc_en_head-1270-stats.json.gz",
"./cc_en_head-0786-stats.json.gz",
"./cc_en_head-0857-stats.json.gz",
]
decontamination_key: str = 'decontamination'
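    # Note: this key selects the attribute directory read below, i.e.
    # /attributes/decontamination/ instead of /attributes/perplexity_suite_v3_option2/.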

@classmethod
def cli(cls, num_workers: int = 1, debug: bool = False, **process_single_kwargs: Any) -> None:
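        # Unlike the base cli, stats here is the explicit list of local output
        # paths above rather than an s3 glob (cf. v15_cc_c4_cleaned).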
cls._run_parallel_processor(
stats_root=cls.stats,
num_workers=num_workers,
debug=debug,
**process_single_kwargs,
)

@classmethod
def process_single(
cls, source_path: str, destination_path: str, queue: "Queue[Union[Tuple[int, ...], None]]", **kwargs: Any
):
attributes = [
source_path.replace("/documents/", "/attributes/gopher_rules/"),
source_path.replace("/documents/", f"/attributes/{cls.decontamination_key}/"),
source_path.replace("/documents/", "/attributes/hatespeech_nsfw_cc_v3/"),
source_path.replace("/documents/", "/attributes/pii_detection/"),
source_path.replace("/documents/", "/attributes/dedupe_paragraphs/"),
]
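        # Each attribute file mirrors the document shard line-for-line, which is
        # what lets the streams be zipped together below.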

doc_decoder = msgspec.json.Decoder(InputSpec)
attr_decoder = msgspec.json.Decoder(OutputSpec)
documents = 0
interval = 10_000

with ExitStack() as stack:
doc_file = stack.enter_context(smart_open.open(source_path, "rb"))
out_file = stack.enter_context(smart_open.open(destination_path, "wt"))

try:
atts_files = [stack.enter_context(smart_open.open(path, "rb")) for path in attributes]
except Exception:
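                # If any attribute file is missing or unreadable, skip this
                # shard entirely rather than emit partial stats.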
return

for doc_line, *attr_lines in zip(doc_file, *atts_files):
doc = doc_decoder.decode(doc_line)
attrs = {}
for line in attr_lines:
attrs.update(attr_decoder.decode(line).attributes)
out_line = {}

# Gopher stats
gopher_removal = cls.gopher_rules(attrs)
out_line["gopher_spans"] = gopher_removal

# Deduplication stats
decontamination_removal = attrs.get("bff_duplicate_paragraph_spans_decontamination", [])
out_line["decontamination_spans"] = decontamination_removal

# jigsaw stats
jigsaw_match: List[Tuple[int, int, float]] = []
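                # Collect sentence spans whose NSFW or toxicity score exceeds 0.4;
                # bisect.insort keeps the combined list sorted by start offset.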
nsfw = attrs.get("hatespeech_nsfw_cc_v3__jigsaw_nsfw_sencence_v2____label__nsfw", [])
for span in nsfw:
if span[2] > 0.4:
bisect.insort(jigsaw_match, (span[0], span[1], 1.0))

toxic = attrs.get("hatespeech_nsfw_cc_v3__jigsaw_hatespeech_sentence_v2____label__toxic", [])
for span in toxic:
if span[2] > 0.4:
bisect.insort(jigsaw_match, (span[0], span[1], 1.0))

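                # _merge_spans coalesces the overlapping spans collected from
                # the two taggers above into a single span list.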
jigsaw_match = cls._merge_spans(jigsaw_match)
out_line["hatespeech_spans"] = jigsaw_match

# PII stats
pii_removal = (
attrs.get("pii_detection__pii_regex_with_counts_fast_v2__EMAIL_ADDRESS", [])
+ attrs.get("pii_detection__pii_regex_with_counts_fast_v2__PHONE_NUMBER", [])
+ attrs.get("pii_detection__pii_regex_with_counts_fast_v2__IP_ADDRESS", [])
)
out_line["pii_spans"] = pii_removal

# Duplicates stats
dups = [p for p in attrs.get("bff_duplicate_paragraph_spans", []) if p[1] - p[0] > 0]
out_line["dedupe_paragraphs_spans"] = dups

documents += 1

if documents % interval == 0:
cls.increment_progressbar(queue, documents=interval)

out_file.write(json.dumps(out_line) + "\n")

cls.increment_progressbar(queue, files=1, documents=documents % interval)

class C4InputSpec(InputSpec):
metadata: Dict[str, Any] = msgspec.field(default_factory=dict)
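The earlier process_single (above) folds every tagger's spans into running counters; LineStatsCC instead writes one JSON object per document containing the raw spans. A minimal sketch of consuming that output, using only the stdlib and the keys written above (aggregate_line_stats is a hypothetical helper, not part of this commit):

import gzip
import json

def aggregate_line_stats(path: str) -> dict:
    # Recompute the aggregate gopher counters that the old inline code kept,
    # this time from the per-document span dump written by LineStatsCC.
    totals = {"gopher_count": 0, "gopher_length": 0, "gopher_matches": 0}
    with gzip.open(path, "rt") as f:
        for line in f:
            doc = json.loads(line)
            spans = doc.get("gopher_spans", [])
            totals["gopher_count"] += len(spans)
            totals["gopher_length"] += sum(s[1] - s[0] for s in spans)
            totals["gopher_matches"] += 1 if spans else 0
    return totals

print(aggregate_line_stats("./cc_en_head-0700-stats.json.gz"))

The same pattern extends to the other keys written per line (decontamination_spans, hatespeech_spans, pii_spans, dedupe_paragraphs_spans).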
