From 4f36f7650c1870078016be51a6fe46a667abf775 Mon Sep 17 00:00:00 2001
From: mitya
Date: Tue, 16 Apr 2024 14:59:04 +0300
Subject: [PATCH 1/4] attempt to train on jsonl preprocessed samples

---
 refact_data_pipeline/finetune_datasource.py  | 77 ++++++++++++++++++-
 .../configuration/supported_models.py        |  4 +-
 .../finetune/scripts/finetune_train.py       | 31 +++++++-
 3 files changed, 108 insertions(+), 4 deletions(-)

diff --git a/refact_data_pipeline/finetune_datasource.py b/refact_data_pipeline/finetune_datasource.py
index cb86a18c..a0fdef33 100644
--- a/refact_data_pipeline/finetune_datasource.py
+++ b/refact_data_pipeline/finetune_datasource.py
@@ -131,7 +131,7 @@ def files_len(self) -> int:
     def set_epoch_callback(self, callback):
         assert len(self._pipeline) > 0
         file_reader = self._pipeline[0]
-        assert type(file_reader) is ReadFileByFile
+        assert isinstance(file_reader, ReadFileByFile)
         file_reader.set_epoch_callback(callback)
 
     def set_random_state(self, seed):
@@ -190,3 +190,78 @@ def _build_pipeline(self, files: List[Dict[str, Any]]):
         dp = pp.DensePacker(fim, self._ds_options)
         shf = pp.Shuffle(dp, self._ds_options)
         return [read_by_file, fim, dp, shf]
+
+
+class ReadJSONLFileByFile(ReadFileByFile):
+    def __iter__(self):
+        sample_num = 0
+        file_num = 0
+        epoch = 0
+        while True:
+            for j in self.inner_filter:
+                with jsonlines.open(os.path.join(self.basedir, j["path"])) as r:
+                    for data in r:
+                        if not data["middle"]:
+                            continue
+                        yield {
+                            **data,
+                            "path": j["path"],
+                            "stats": {
+                                "sample_num": sample_num,
+                                "file_num": file_num,
+                                "epoch": epoch,
+                            },
+                        }
+                        sample_num += 1
+                file_num += 1
+            epoch += 1
+            self.epoch_callback(epoch)
+            if epoch == self.quit_on_epoch:
+                break
+
+
+class RAGFIM(PipelineNode):
+    def __init__(
+            self,
+            inner_filter,
+            dataopts: DatasetOpts,
+    ):
+        self.enc = dataopts.encoding
+        super().__init__(dataopts)
+        self.inner_filter = inner_filter
+        self.n_ctx = dataopts.get("n_ctx", 2048)
+        self.debug = bool(dataopts.get("debug", 0))
+        self.special_tokens = [
+            self.enc.PREFIX,
+            self.enc.SUFFIX,
+            self.enc.INFILL,
+            self.enc.EOT,
+        ]
+        assert len(set(self.special_tokens)) == len(self.special_tokens)
+
+    def __iter__(self):
+        stats: Dict[str, Any] = {
+            "fim_out": 0,
+        }
+        for sample in self.inner_filter:
+            tokens = self.enc.encode(sample["prompt"]) + self.enc.encode(sample["middle"])
+            mask = [0 if t in self.special_tokens else 1 for t in tokens]
+
+            if len(tokens) + 1 > self.n_ctx:
+                continue
+
+            stats["fim_out"] += 1
+            yield {
+                "tokens": tokens + [self.enc.EOT],
+                "mask": mask + [1],
+                "stats": {**sample["stats"], **stats},
+            }
+
+
+class RefactRAGFIMDataset(RefactDataset):
+    def _build_pipeline(self, files: List[Dict[str, Any]]):
+        read_by_file = ReadJSONLFileByFile(self.basedir, files, self._ds_options)
+        fim = RAGFIM(read_by_file, self._ds_options)
+        dp = pp.DensePacker(fim, self._ds_options)
+        shf = pp.Shuffle(dp, self._ds_options)
+        return [read_by_file, fim, dp, shf]
diff --git a/self_hosting_machinery/finetune/configuration/supported_models.py b/self_hosting_machinery/finetune/configuration/supported_models.py
index 97badd59..1cabe19d 100644
--- a/self_hosting_machinery/finetune/configuration/supported_models.py
+++ b/self_hosting_machinery/finetune/configuration/supported_models.py
@@ -3,14 +3,14 @@
 _fim_train_ds_pipeline = {
     "ds_opts": "n_ctx={n_ctx},debug=0,seed=42,shuffle_depth=256,"
                "fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01",
-    "ds_name": "RefactFIMCodeDataset"
+    "ds_name": "RefactRAGFIMDataset"
 }
 
 _fim_test_ds_pipeline = {
     "ds_opts": "n_ctx={n_ctx},debug=0,seed=42,shuffle_depth=0,quit_on_epoch=1,"
                "fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01,"
                "pack_single=1,pack_complete=0,pack_buffer_size=50",
-    "ds_name": "RefactFIMCodeDataset"
+    "ds_name": "RefactRAGFIMDataset"
 }
 _bigcode_tokenizer_mapping = {
     "eot_idx": 0,
diff --git a/self_hosting_machinery/finetune/scripts/finetune_train.py b/self_hosting_machinery/finetune/scripts/finetune_train.py
index 7496c77d..41ab6213 100644
--- a/self_hosting_machinery/finetune/scripts/finetune_train.py
+++ b/self_hosting_machinery/finetune/scripts/finetune_train.py
@@ -167,6 +167,34 @@ def gpu_filter_and_build_config(
     return _build_finetune_config_by_heuristics(run_id, finetune_cfg, model_config, **kwargs)
 
 
+def no_filter_build_config(
+        pname: str,
+        run_id: str,
+        model_name: str,
+        model_info: Dict[str, Any],
+        model_config: Dict[str, Any],
+        model_ctx_size: int,
+        **kwargs) -> Dict[str, Any]:
+    if model_ctx_size > 0:
+        model_info["T"] = model_ctx_size
+    finetune_cfg = {
+        **base_config(model_name=model_name, model_info=model_info),
+        **kwargs,
+    }
+    traces.log("locking \"%s\" for filtering" % pname)
+    if dist.get_rank() == 0:
+        with filelock.FileLock(env.PP_PROJECT_LOCK(pname)):
+            traces.log("locked \"%s\" successfully" % pname)
+            traces.log("completed filtering, now copy files to run \"%s\"" % run_id)
+            _copy_source_files(
+                env.PP_TRAIN_FILTERED_FILEPATH(pname), env.PERRUN_TRAIN_FILTERED_FILEPATH(run_id), pname, run_id)
+            _copy_source_files(
+                env.PP_TEST_FILTERED_FILEPATH(pname), env.PERRUN_TEST_FILTERED_FILEPATH(run_id), pname, run_id)
+    dist.barrier()
+
+    return _build_finetune_config_by_heuristics(run_id, finetune_cfg, model_config, **kwargs)
+
+
 def _copy_source_files(jsonl_src, jsonl_dst, pname, run_id):
     for d in jsonlines.open(jsonl_src):
         try:
@@ -357,7 +385,8 @@ def catch_sigusr1(signum, frame):
     status_tracker.update_status("working")
     _log_everywhere("Dest dir is %s" % traces.context().path)
 
-    finetune_cfg = gpu_filter_and_build_config(model_config=model_config, model_info=model_info, **vars(args))
+    # finetune_cfg = gpu_filter_and_build_config(model_config=model_config, model_info=model_info, **vars(args))
+    finetune_cfg = no_filter_build_config(model_config=model_config, model_info=model_info, **vars(args))
 
     _log_everywhere(f"Building the model {finetune_cfg['model_name']}")
     model_context = ModelContext(
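
Note: the pipeline introduced above trains on preprocessed JSONL sample files
instead of raw source files. A minimal sketch of the record shape it appears to
assume is shown below; the file name and the literal special-token strings are
illustrative assumptions, and only the "prompt" and "middle" fields are actually
read by ReadJSONLFileByFile and RAGFIM (records with an empty "middle" are
skipped, any extra keys are passed through unchanged via **data).

    # Hypothetical helper, not part of the patch: writes one sample record.
    import jsonlines

    record = {
        # ready-made FIM prompt; the concrete special-token strings depend on
        # the model's tokenizer and are placeholders here
        "prompt": "<fim_prefix>def add(a, b):\n    <fim_suffix>\n<fim_middle>",
        "middle": "return a + b",  # target completion the model is trained to produce
    }

    with jsonlines.open("train_rag_fim.jsonl", mode="w") as writer:  # hypothetical file name
        writer.write(record)
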
"n_ctx={n_ctx},debug=0,seed=42,shuffle_depth=0,quit_on_epoch=1," "fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01," "pack_single=1,pack_complete=0,pack_buffer_size=50", - "ds_name": "RefactFIMCodeDataset" + "ds_name": "RefactRAGFIMDataset" } _bigcode_tokenizer_mapping = { "eot_idx": 0, diff --git a/self_hosting_machinery/finetune/scripts/finetune_train.py b/self_hosting_machinery/finetune/scripts/finetune_train.py index 7496c77d..41ab6213 100644 --- a/self_hosting_machinery/finetune/scripts/finetune_train.py +++ b/self_hosting_machinery/finetune/scripts/finetune_train.py @@ -167,6 +167,34 @@ def gpu_filter_and_build_config( return _build_finetune_config_by_heuristics(run_id, finetune_cfg, model_config, **kwargs) +def no_filter_build_config( + pname: str, + run_id: str, + model_name: str, + model_info: Dict[str, Any], + model_config: Dict[str, Any], + model_ctx_size: int, + **kwargs) -> Dict[str, Any]: + if model_ctx_size > 0: + model_info["T"] = model_ctx_size + finetune_cfg = { + **base_config(model_name=model_name, model_info=model_info), + **kwargs, + } + traces.log("locking \"%s\" for filtering" % pname) + if dist.get_rank() == 0: + with filelock.FileLock(env.PP_PROJECT_LOCK(pname)): + traces.log("locked \"%s\" successfully" % pname) + traces.log("completed filtering, now copy files to run \"%s\"" % run_id) + _copy_source_files( + env.PP_TRAIN_FILTERED_FILEPATH(pname), env.PERRUN_TRAIN_FILTERED_FILEPATH(run_id), pname, run_id) + _copy_source_files( + env.PP_TEST_FILTERED_FILEPATH(pname), env.PERRUN_TEST_FILTERED_FILEPATH(run_id), pname, run_id) + dist.barrier() + + return _build_finetune_config_by_heuristics(run_id, finetune_cfg, model_config, **kwargs) + + def _copy_source_files(jsonl_src, jsonl_dst, pname, run_id): for d in jsonlines.open(jsonl_src): try: @@ -357,7 +385,8 @@ def catch_sigusr1(signum, frame): status_tracker.update_status("working") _log_everywhere("Dest dir is %s" % traces.context().path) - finetune_cfg = gpu_filter_and_build_config(model_config=model_config, model_info=model_info, **vars(args)) + # finetune_cfg = gpu_filter_and_build_config(model_config=model_config, model_info=model_info, **vars(args)) + finetune_cfg = no_filter_build_config(model_config=model_config, model_info=model_info, **vars(args)) _log_everywhere(f"Building the model {finetune_cfg['model_name']}") model_context = ModelContext( From be002cd66d4214ce656e2ef0527cd2ebe5e87956 Mon Sep 17 00:00:00 2001 From: mitya Date: Mon, 22 Apr 2024 15:47:55 +0300 Subject: [PATCH 2/4] single iteration over all dataset --- refact_data_pipeline/finetune_datasource.py | 34 ++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/refact_data_pipeline/finetune_datasource.py b/refact_data_pipeline/finetune_datasource.py index a0fdef33..1c1e233a 100644 --- a/refact_data_pipeline/finetune_datasource.py +++ b/refact_data_pipeline/finetune_datasource.py @@ -197,23 +197,23 @@ def __iter__(self): sample_num = 0 file_num = 0 epoch = 0 - while True: - for j in self.inner_filter: - with jsonlines.open(os.path.join(self.basedir, j["path"])) as r: - for data in r: - if not data["middle"]: - continue - yield { - **data, - "path": j["path"], - "stats": { - "sample_num": sample_num, - "file_num": file_num, - "epoch": epoch, - }, - } - sample_num += 1 - file_num += 1 + + for j in self.inner_filter: + with jsonlines.open(os.path.join(self.basedir, j["path"])) as r: + for data in r: + if not data["middle"]: + continue + yield { + **data, + "path": j["path"], + "stats": { + 
"sample_num": sample_num, + "file_num": file_num, + "epoch": epoch, + }, + } + sample_num += 1 + file_num += 1 epoch += 1 self.epoch_callback(epoch) if epoch == self.quit_on_epoch: From a28005a7c227c8a37360104c080fff0fddfe8541 Mon Sep 17 00:00:00 2001 From: mitya Date: Mon, 29 Apr 2024 15:23:29 +0300 Subject: [PATCH 3/4] 0 mask for context part --- refact_data_pipeline/finetune_datasource.py | 28 ++++++++++++++++----- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/refact_data_pipeline/finetune_datasource.py b/refact_data_pipeline/finetune_datasource.py index 1c1e233a..b15180db 100644 --- a/refact_data_pipeline/finetune_datasource.py +++ b/refact_data_pipeline/finetune_datasource.py @@ -244,16 +244,32 @@ def __iter__(self): "fim_out": 0, } for sample in self.inner_filter: - tokens = self.enc.encode(sample["prompt"]) + self.enc.encode(sample["middle"]) - mask = [0 if t in self.special_tokens else 1 for t in tokens] - - if len(tokens) + 1 > self.n_ctx: + prompt_tokens = self.enc.encode(sample["prompt"]) + context_tokens = [] + fim_tokens = [] + is_context_part = True + for t in prompt_tokens: + is_context_part &= (t not in self.special_tokens) + if is_context_part: + context_tokens.append(t) + else: + fim_tokens.append(t) + + assert prompt_tokens == context_tokens + fim_tokens + + context_mask = [0] * len(context_tokens) + fim_tokens = fim_tokens + self.enc.encode(sample["middle"]) + fim_mask = [0 if t in self.special_tokens else 1 for t in fim_tokens] + tokens = context_tokens + fim_tokens + [self.enc.EOT] + mask = context_mask + fim_mask + [1] + + if len(tokens) > self.n_ctx: continue stats["fim_out"] += 1 yield { - "tokens": tokens + [self.enc.EOT], - "mask": mask + [1], + "tokens": tokens, + "mask": mask, "stats": {**sample["stats"], **stats}, } From 8e53a9a33314cbd4e0125f30520736a3ed52db9a Mon Sep 17 00:00:00 2001 From: mitya Date: Wed, 1 May 2024 21:15:11 +0300 Subject: [PATCH 4/4] samples limit and full dataset epoch --- refact_data_pipeline/finetune_datasource.py | 45 ++++++++++++------- .../configuration/supported_models.py | 2 +- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/refact_data_pipeline/finetune_datasource.py b/refact_data_pipeline/finetune_datasource.py index b15180db..2863ee32 100644 --- a/refact_data_pipeline/finetune_datasource.py +++ b/refact_data_pipeline/finetune_datasource.py @@ -193,31 +193,42 @@ def _build_pipeline(self, files: List[Dict[str, Any]]): class ReadJSONLFileByFile(ReadFileByFile): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._samples_limit = self.dataopts.get("samples_limit", 0) + def __iter__(self): sample_num = 0 file_num = 0 epoch = 0 - for j in self.inner_filter: - with jsonlines.open(os.path.join(self.basedir, j["path"])) as r: - for data in r: - if not data["middle"]: - continue - yield { - **data, - "path": j["path"], - "stats": { - "sample_num": sample_num, - "file_num": file_num, - "epoch": epoch, - }, - } - sample_num += 1 - file_num += 1 + quit_flag = False + while not quit_flag: + for j in self.inner_filter: + with jsonlines.open(os.path.join(self.basedir, j["path"])) as r: + for data in r: + if not data["middle"]: + continue + yield { + **data, + "path": j["path"], + "stats": { + "sample_num": sample_num, + "file_num": file_num, + "epoch": epoch, + }, + } + sample_num += 1 + if self._samples_limit and sample_num >= self._samples_limit: + quit_flag = True + break + file_num += 1 + if quit_flag: + break epoch += 1 self.epoch_callback(epoch) if epoch == self.quit_on_epoch: 
From 8e53a9a33314cbd4e0125f30520736a3ed52db9a Mon Sep 17 00:00:00 2001
From: mitya
Date: Wed, 1 May 2024 21:15:11 +0300
Subject: [PATCH 4/4] samples limit and full dataset epoch

---
 refact_data_pipeline/finetune_datasource.py  | 45 ++++++++++++-------
 .../configuration/supported_models.py        |  2 +-
 2 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/refact_data_pipeline/finetune_datasource.py b/refact_data_pipeline/finetune_datasource.py
index b15180db..2863ee32 100644
--- a/refact_data_pipeline/finetune_datasource.py
+++ b/refact_data_pipeline/finetune_datasource.py
@@ -193,31 +193,42 @@ def _build_pipeline(self, files: List[Dict[str, Any]]):
 
 
 class ReadJSONLFileByFile(ReadFileByFile):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._samples_limit = self.dataopts.get("samples_limit", 0)
+
     def __iter__(self):
         sample_num = 0
         file_num = 0
         epoch = 0
 
-        for j in self.inner_filter:
-            with jsonlines.open(os.path.join(self.basedir, j["path"])) as r:
-                for data in r:
-                    if not data["middle"]:
-                        continue
-                    yield {
-                        **data,
-                        "path": j["path"],
-                        "stats": {
-                            "sample_num": sample_num,
-                            "file_num": file_num,
-                            "epoch": epoch,
-                        },
-                    }
-                    sample_num += 1
-            file_num += 1
+        quit_flag = False
+        while not quit_flag:
+            for j in self.inner_filter:
+                with jsonlines.open(os.path.join(self.basedir, j["path"])) as r:
+                    for data in r:
+                        if not data["middle"]:
+                            continue
+                        yield {
+                            **data,
+                            "path": j["path"],
+                            "stats": {
+                                "sample_num": sample_num,
+                                "file_num": file_num,
+                                "epoch": epoch,
+                            },
+                        }
+                        sample_num += 1
+                        if self._samples_limit and sample_num >= self._samples_limit:
+                            quit_flag = True
+                            break
+                file_num += 1
+                if quit_flag:
+                    break
             epoch += 1
             self.epoch_callback(epoch)
             if epoch == self.quit_on_epoch:
-                break
+                quit_flag = True
 
 
 class RAGFIM(PipelineNode):
diff --git a/self_hosting_machinery/finetune/configuration/supported_models.py b/self_hosting_machinery/finetune/configuration/supported_models.py
index 1cabe19d..cd06c06e 100644
--- a/self_hosting_machinery/finetune/configuration/supported_models.py
+++ b/self_hosting_machinery/finetune/configuration/supported_models.py
@@ -9,7 +9,7 @@
 _fim_test_ds_pipeline = {
     "ds_opts": "n_ctx={n_ctx},debug=0,seed=42,shuffle_depth=0,quit_on_epoch=1,"
                "fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01,"
-               "pack_single=1,pack_complete=0,pack_buffer_size=50",
+               "pack_single=1,pack_complete=0,pack_buffer_size=50,samples_limit=16",
     "ds_name": "RefactRAGFIMDataset"
 }
 _bigcode_tokenizer_mapping = {
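
Note: with this last patch the test pipeline's reader stops after
samples_limit=16 records or after its single pass (quit_on_epoch=1), whichever
comes first, while the training pipeline, which does not set quit_on_epoch,
keeps cycling and now reports an epoch only once every JSONL file has been
consumed. A sketch of how such an option travels from the ds_opts string to
ReadJSONLFileByFile._samples_limit; the real parsing is done by DatasetOpts in
refact_data_pipeline, and the toy parser below only assumes the comma-separated
key=value format visible above (n_ctx is given a concrete value for
illustration):

    # Toy stand-in for DatasetOpts, only to show where samples_limit=16 ends up.
    ds_opts = ("n_ctx=2048,debug=0,seed=42,shuffle_depth=0,quit_on_epoch=1,"
               "fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01,"
               "pack_single=1,pack_complete=0,pack_buffer_size=50,samples_limit=16")

    opts = dict(pair.split("=", 1) for pair in ds_opts.split(","))

    samples_limit = int(opts.get("samples_limit", 0))  # 16 -> reader yields at most 16 samples
    quit_on_epoch = int(opts.get("quit_on_epoch", 0))  # 1  -> and makes at most one pass over the files
    print(samples_limit, quit_on_epoch)
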