Refact2 #400 (Open)

wants to merge 4 commits into base: dev
104 changes: 103 additions & 1 deletion refact_data_pipeline/finetune_datasource.py
@@ -131,7 +131,7 @@ def files_len(self) -> int:
    def set_epoch_callback(self, callback):
        assert len(self._pipeline) > 0
        file_reader = self._pipeline[0]
-       assert type(file_reader) is ReadFileByFile
+       assert isinstance(file_reader, ReadFileByFile)
        file_reader.set_epoch_callback(callback)

    def set_random_state(self, seed):
@@ -190,3 +190,105 @@ def _build_pipeline(self, files: List[Dict[str, Any]]):
        dp = pp.DensePacker(fim, self._ds_options)
        shf = pp.Shuffle(dp, self._ds_options)
        return [read_by_file, fim, dp, shf]


class ReadJSONLFileByFile(ReadFileByFile):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._samples_limit = self.dataopts.get("samples_limit", 0)

    def __iter__(self):
        sample_num = 0
        file_num = 0
        epoch = 0

        quit_flag = False
        while not quit_flag:
            for j in self.inner_filter:
                with jsonlines.open(os.path.join(self.basedir, j["path"])) as r:
                    for data in r:
                        if not data["middle"]:
                            continue
                        yield {
                            **data,
                            "path": j["path"],
                            "stats": {
                                "sample_num": sample_num,
                                "file_num": file_num,
                                "epoch": epoch,
                            },
                        }
                        sample_num += 1
                        if self._samples_limit and sample_num >= self._samples_limit:
                            quit_flag = True
                            break
                file_num += 1
                if quit_flag:
                    break
            epoch += 1
            self.epoch_callback(epoch)
            if epoch == self.quit_on_epoch:
                quit_flag = True
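
Each line of the JSONL files read above is expected to carry at least a "prompt" and a non-empty "middle" field (records with an empty "middle" are skipped), and samples_limit, if set, caps the total number of yielded records. Given the RAGFIM stage below, a plausible record could look like this one-line sketch, where the <PREFIX>/<SUFFIX>/<INFILL> markers and field contents are illustrative assumptions rather than the PR's actual serialization:

# Hypothetical JSONL record (one line of the file); values are invented.
# "prompt" holds retrieved context followed by the FIM scaffold; "middle"
# is the span the model must fill in.
{"prompt": "<retrieved context><PREFIX>def add(a, b):\n<SUFFIX>\n<INFILL>", "middle": "    return a + b"}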


class RAGFIM(PipelineNode):
    def __init__(
        self,
        inner_filter,
        dataopts: DatasetOpts,
    ):
        self.enc = dataopts.encoding
        super().__init__(dataopts)
        self.inner_filter = inner_filter
        self.n_ctx = dataopts.get("n_ctx", 2048)
        self.debug = bool(dataopts.get("debug", 0))
        self.special_tokens = [
            self.enc.PREFIX,
            self.enc.SUFFIX,
            self.enc.INFILL,
            self.enc.EOT,
        ]
        assert len(set(self.special_tokens)) == len(self.special_tokens)

    def __iter__(self):
        stats: Dict[str, Any] = {
            "fim_out": 0,
        }
        for sample in self.inner_filter:
            prompt_tokens = self.enc.encode(sample["prompt"])
            context_tokens = []
            fim_tokens = []
            is_context_part = True
            for t in prompt_tokens:
                is_context_part &= (t not in self.special_tokens)
                if is_context_part:
                    context_tokens.append(t)
                else:
                    fim_tokens.append(t)

            assert prompt_tokens == context_tokens + fim_tokens

            context_mask = [0] * len(context_tokens)
            fim_tokens = fim_tokens + self.enc.encode(sample["middle"])
            fim_mask = [0 if t in self.special_tokens else 1 for t in fim_tokens]
            tokens = context_tokens + fim_tokens + [self.enc.EOT]
            mask = context_mask + fim_mask + [1]

            if len(tokens) > self.n_ctx:
                continue

            stats["fim_out"] += 1
            yield {
                "tokens": tokens,
                "mask": mask,
                "stats": {**sample["stats"], **stats},
            }
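
The split-and-mask step above can be traced on dummy token ids. A minimal self-contained sketch, with invented ids standing in for the tokenizer's PREFIX/SUFFIX/INFILL/EOT specials: everything before the first special token is retrieved context (no loss), everything from there on is the FIM scaffold plus the middle, with loss only on non-special tokens and the final EOT.

# Minimal sketch of RAGFIM's context/FIM split and loss mask, on dummy ids.
PREFIX, SUFFIX, INFILL, EOT = 100, 101, 102, 103
special = {PREFIX, SUFFIX, INFILL, EOT}

prompt = [7, 8, 9, PREFIX, 1, 2, SUFFIX, 3, INFILL]  # context, then FIM scaffold
middle = [4, 5]                                      # tokens the model should predict

context, fim, in_context = [], [], True
for t in prompt:
    in_context &= t not in special  # flips to False at the first special token
    (context if in_context else fim).append(t)

fim = fim + middle
tokens = context + fim + [EOT]
# Loss is off (0) for context and special tokens, on (1) for everything else.
mask = [0] * len(context) + [0 if t in special else 1 for t in fim] + [1]

assert tokens == [7, 8, 9, PREFIX, 1, 2, SUFFIX, 3, INFILL, 4, 5, EOT]
assert mask == [0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1]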


class RefactRAGFIMDataset(RefactDataset):
    def _build_pipeline(self, files: List[Dict[str, Any]]):
        read_by_file = ReadJSONLFileByFile(self.basedir, files, self._ds_options)
        fim = RAGFIM(read_by_file, self._ds_options)
        dp = pp.DensePacker(fim, self._ds_options)
        shf = pp.Shuffle(dp, self._ds_options)
        return [read_by_file, fim, dp, shf]
@@ -3,14 +3,14 @@
_fim_train_ds_pipeline = {
    "ds_opts": "n_ctx={n_ctx},debug=0,seed=42,shuffle_depth=256,"
               "fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01",
-   "ds_name": "RefactFIMCodeDataset"
+   "ds_name": "RefactRAGFIMDataset"
}

_fim_test_ds_pipeline = {
    "ds_opts": "n_ctx={n_ctx},debug=0,seed=42,shuffle_depth=0,quit_on_epoch=1,"
               "fim_probability=0.9,fim_drop_residual=1,random_trim_context_prob=0.01,"
-              "pack_single=1,pack_complete=0,pack_buffer_size=50",
-   "ds_name": "RefactFIMCodeDataset"
+              "pack_single=1,pack_complete=0,pack_buffer_size=50,samples_limit=16",
+   "ds_name": "RefactRAGFIMDataset"
}
_bigcode_tokenizer_mapping = {
    "eot_idx": 0,
31 changes: 30 additions & 1 deletion self_hosting_machinery/finetune/scripts/finetune_train.py
@@ -167,6 +167,34 @@ def gpu_filter_and_build_config(
    return _build_finetune_config_by_heuristics(run_id, finetune_cfg, model_config, **kwargs)


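# Unlike gpu_filter_and_build_config above, this variant skips the GPU
# loss-based filtering pass: rank 0 only copies the already-filtered
# train/test file lists into the run directory while other ranks wait.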
def no_filter_build_config(
        pname: str,
        run_id: str,
        model_name: str,
        model_info: Dict[str, Any],
        model_config: Dict[str, Any],
        model_ctx_size: int,
        **kwargs) -> Dict[str, Any]:
    if model_ctx_size > 0:
        model_info["T"] = model_ctx_size
    finetune_cfg = {
        **base_config(model_name=model_name, model_info=model_info),
        **kwargs,
    }
    traces.log("locking \"%s\" for filtering" % pname)
    if dist.get_rank() == 0:
        with filelock.FileLock(env.PP_PROJECT_LOCK(pname)):
            traces.log("locked \"%s\" successfully" % pname)
            traces.log("completed filtering, now copy files to run \"%s\"" % run_id)
            _copy_source_files(
                env.PP_TRAIN_FILTERED_FILEPATH(pname), env.PERRUN_TRAIN_FILTERED_FILEPATH(run_id), pname, run_id)
            _copy_source_files(
                env.PP_TEST_FILTERED_FILEPATH(pname), env.PERRUN_TEST_FILTERED_FILEPATH(run_id), pname, run_id)
    dist.barrier()

    return _build_finetune_config_by_heuristics(run_id, finetune_cfg, model_config, **kwargs)
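
The function uses the usual torch.distributed idiom of letting rank 0 perform the filesystem work while every rank joins the same barrier. A minimal sketch of that pattern under the assumption that the process group is already initialized (the helper name is hypothetical):

import torch.distributed as dist

def rank0_copy_then_sync(copy_fn):
    # Assumes dist.init_process_group(...) has already been called.
    # Only one process touches the filesystem...
    if dist.get_rank() == 0:
        copy_fn()
    # ...and every rank (including rank 0) must reach the barrier,
    # otherwise the group deadlocks.
    dist.barrier()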


def _copy_source_files(jsonl_src, jsonl_dst, pname, run_id):
    for d in jsonlines.open(jsonl_src):
        try:
@@ -357,7 +385,8 @@ def catch_sigusr1(signum, frame):
    status_tracker.update_status("working")
    _log_everywhere("Dest dir is %s" % traces.context().path)

-   finetune_cfg = gpu_filter_and_build_config(model_config=model_config, model_info=model_info, **vars(args))
+   # finetune_cfg = gpu_filter_and_build_config(model_config=model_config, model_info=model_info, **vars(args))
+   finetune_cfg = no_filter_build_config(model_config=model_config, model_info=model_info, **vars(args))

    _log_everywhere(f"Building the model {finetune_cfg['model_name']}")
    model_context = ModelContext(