Merge pull request #65 from microsoft/sequence_packing
[WIP] Sequence packing
pclucas14 authored Aug 2, 2024
2 parents e004593 + 5c918e6 commit f3244db
Showing 20 changed files with 724 additions and 44 deletions.
13 changes: 8 additions & 5 deletions .github/workflows/codeql.yml
@@ -43,11 +43,14 @@ jobs:
- name: Set up python
uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.11'
cache: 'pip'
# flash-attn requires torch to be installed
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install --upgrade pip
pip install -e '.[dev]'
pip install -e '.[flash-attn]'
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
@@ -57,11 +60,11 @@ jobs:
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.

# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality


# Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
@@ -70,7 +73,7 @@
# ℹ️ Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun

# If the Autobuild fails above, remove it and uncomment the following three lines.
# modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.

# - run: |
3 changes: 3 additions & 0 deletions .github/workflows/tests.yml
@@ -26,9 +26,12 @@ jobs:
with:
python-version: '3.11'
cache: 'pip'
# flash-attn requires torch to be installed
- name: Install dependencies
run: |
pip install --upgrade pip
pip install -e '.[dev]'
pip install -e '.[flash-attn]'
# - name: Lint with flake8
# run: |
# # stop the build if there are Python syntax errors or undefined names
25 changes: 19 additions & 6 deletions mttl/config.py
@@ -55,7 +55,16 @@ def __init__(self, filenames=None, kwargs=None, raise_error=True, silent=False):
self.post_init(silent=silent)

def post_init(self, silent=False):
pass
if self.attn_implementation == "eager" and self.pack_sequences:
logger.warning(
"Eager attention is not compatible with packed sequences"
+ ", tokens across examples will not be masked"
)
elif self.attn_implementation == "flash_attention_2" and self.pack_sequences:
logger.warning(
"The wrapper we provide for flash attention 2 may not behave as expected for"
+ " some models. Please make sure you test the model with packed sequences"
)

@classmethod
def fromdict(cls, data):
@@ -181,6 +190,14 @@ def _set_defaults(self):
# Data config
self.dataset = None
self.custom_tasks_splits = None
self.subsample_train = None
self.subsample_dev = None
self.subsample_test = None
self.subsample_per_task = False
self.pack_sequences = False
self.pad_to_multiple_of = 8
self.padding_side = "right"
self.max_seq_per_pack = 4

self.data_dir = os.getenv("TRAIN_DIR", "/tmp/")
self.output_dir = os.getenv("OUTPUT_DIR", "./output")
@@ -253,11 +270,6 @@ def _set_defaults(self):
self.seed = 42
self.eval_before_training = True

self.subsample_train = None
self.subsample_dev = None
self.subsample_test = None
self.subsample_per_task = False

self.ni_online_eval = False # zero-shot online eval for ni
self.t0_online_eval = False # zero-shot eval for t0
self.early_stop_on_zero_shot = False # zero-shot early stopping
@@ -281,6 +293,7 @@

self.model = None
self.model_family = None # model family, either "gpt" or "encdec"
self.attn_implementation = None

self.precision = "32"
self.monitor_grad_alignment_on = None
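For context on how the new config options fit together, below is a minimal, self-contained Python sketch of greedy sequence packing. It is an illustration only, not the repository's implementation: the names PackingConfig, pack_examples, max_seq_len, and pad_token_id, and all token values, are hypothetical; only the max_seq_per_pack, pad_to_multiple_of, and right-padding choices echo the new defaults added to mttl/config.py above.

# Hypothetical sketch of greedy sequence packing; names and values are
# illustrative and do not mirror mttl's actual API.
from dataclasses import dataclass
from typing import List


@dataclass
class PackingConfig:
    max_seq_len: int = 32        # token budget per packed example (illustrative value)
    max_seq_per_pack: int = 4    # echoes the new `max_seq_per_pack` default
    pad_to_multiple_of: int = 8  # echoes the new `pad_to_multiple_of` default
    pad_token_id: int = 0


def pack_examples(sequences: List[List[int]], cfg: PackingConfig):
    """Greedily concatenate token sequences into packs, tracking per-sequence boundaries."""
    packs, boundaries = [], []
    current, current_bounds = [], []
    for seq in sequences:
        too_long = len(current) + len(seq) > cfg.max_seq_len
        too_many = len(current_bounds) >= cfg.max_seq_per_pack
        if current and (too_long or too_many):
            packs.append(current)
            boundaries.append(current_bounds)
            current, current_bounds = [], []
        current = current + seq
        current_bounds.append(len(current))  # cumulative end offsets, cu_seqlens-style

    if current:
        packs.append(current)
        boundaries.append(current_bounds)

    # Right-pad each pack up to the next multiple of `pad_to_multiple_of`
    # (matching the new padding_side="right" default).
    padded = []
    for pack in packs:
        target = -(-len(pack) // cfg.pad_to_multiple_of) * cfg.pad_to_multiple_of
        padded.append(pack + [cfg.pad_token_id] * (target - len(pack)))
    return padded, boundaries


if __name__ == "__main__":
    cfg = PackingConfig()
    examples = [[1] * 10, [2] * 12, [3] * 7, [4] * 20]
    packed, bounds = pack_examples(examples, cfg)
    for pack, ends in zip(packed, bounds):
        print(len(pack), ends)  # e.g. 32 [10, 22, 29] then 24 [20]

With plain eager attention and no extra masking, every token in a pack can attend to tokens from the other sequences packed alongside it, which is what the new post_init warning flags. Variable-length flash-attention kernels instead take cumulative boundary offsets like the ones tracked here so that attention stays within each original sequence.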
