-
-
Notifications
You must be signed in to change notification settings - Fork 908
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
wip add new proposed message structure (#1904)
* wip add new proposed message structure * tokenization * wip * wip transform builder * wip make the chat dataset loadable * wip chatml + llama 3 new chat objects * chore: lint * chore: lint * fix tokenization * remove dacite dependency since we're using pydantic now * fix handling when already correctly split in messages * make sure to remove chat features from tokenized ds * move chat to be a input transform for messages * make sure llama3 has the bos token * remove non-working special token code * fix messages strat loader
- Loading branch information
Showing
23 changed files
with
1,285 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,315 @@ | ||
accelerate==0.34.1 | ||
addict==2.4.0 | ||
aiofiles==23.2.1 | ||
aiohttp==3.9.0 | ||
aiosignal==1.3.1 | ||
aiostream==0.5.2 | ||
alembic==1.13.1 | ||
annotated-types==0.6.0 | ||
annoy==1.17.3 | ||
ansible==6.7.0 | ||
ansible-core==2.13.13 | ||
ansible-vault==2.1.0 | ||
anyio==3.7.1 | ||
appdirs==1.4.4 | ||
art==6.0 | ||
asgiref==3.7.2 | ||
async-timeout==4.0.2 | ||
attrdict==2.0.1 | ||
attrs==22.2.0 | ||
awscli==1.32.75 | ||
-e git+ssh://[email protected]/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl | ||
backoff==2.2.1 | ||
base58==2.1.1 | ||
beartype==0.17.2 | ||
bitnet==0.2.1 | ||
bitsandbytes==0.42.0 | ||
bittensor==6.7.0 | ||
black==23.7.0 | ||
blinker==1.7.0 | ||
boto3==1.34.75 | ||
botocore==1.34.75 | ||
cachetools==5.3.3 | ||
cachy==0.1.1 | ||
certifi==2023.7.22 | ||
cffi==1.16.0 | ||
cfgv==3.3.1 | ||
chai-guanaco==1.2.4 | ||
charset-normalizer==3.2.0 | ||
cleo==0.6.8 | ||
click==8.1.7 | ||
cloudpickle==2.0.0 | ||
cohere==4.11.2 | ||
colorama==0.4.4 | ||
coloredlogs==15.0.1 | ||
CoLT5-attention==0.10.20 | ||
contextlib2==21.6.0 | ||
contourpy==1.2.0 | ||
cryptography==41.0.3 | ||
cycler==0.12.1 | ||
cytoolz==0.12.3 | ||
databricks-cli==0.18.0 | ||
dataclasses-json==0.5.7 | ||
datasets==2.11.0 | ||
ddt==1.6.0 | ||
decorator==5.1.1 | ||
deepspeed==0.15.0 | ||
# Editable Git install with no remote (dialogpt==0.1) | ||
-e /Users/wing/Projects/ml/dialogpt/src | ||
dill==0.3.6 | ||
distlib==0.3.6 | ||
docker==7.0.0 | ||
docker-pycreds==0.4.0 | ||
docstring-parser==0.15 | ||
docutils==0.16 | ||
ecdsa==0.18.0 | ||
einops==0.7.0 | ||
einops-exts==0.0.4 | ||
einx==0.1.3 | ||
entrypoints==0.4 | ||
eth-hash==0.6.0 | ||
eth-keys==0.5.0 | ||
eth-typing==4.0.0 | ||
eth-utils==2.3.1 | ||
evaluate==0.4.0 | ||
exceptiongroup==1.1.1 | ||
fastapi==0.109.2 | ||
fastcore==1.5.29 | ||
ffmpy==0.4.0 | ||
filelock==3.12.2 | ||
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet | ||
fire==0.5.0 | ||
first==2.0.2 | ||
flake8==7.0.0 | ||
Flask==3.0.1 | ||
fonttools==4.47.2 | ||
frozendict==2.4.1 | ||
frozenlist==1.3.3 | ||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe | ||
fsspec==2023.6.0 | ||
fuzzywuzzy==0.18.0 | ||
gitdb==4.0.10 | ||
GitPython==3.1.31 | ||
google-pasta==0.2.0 | ||
gradio==4.42.0 | ||
gradio_client==1.3.0 | ||
greenlet==2.0.2 | ||
grpclib==0.4.7 | ||
gunicorn==21.2.0 | ||
h11==0.14.0 | ||
h2==4.1.0 | ||
hpack==4.0.0 | ||
httpcore==0.17.3 | ||
httpx==0.24.1 | ||
huggingface-hub==0.23.4 | ||
humanfriendly==10.0 | ||
hyperframe==6.0.1 | ||
identify==2.5.24 | ||
idna==3.4 | ||
immutables==0.20 | ||
importlib-metadata==6.7.0 | ||
importlib-resources==6.1.1 | ||
inflection==0.5.1 | ||
iniconfig==2.0.0 | ||
itsdangerous==2.1.2 | ||
Jinja2==3.1.2 | ||
jmespath==1.0.1 | ||
joblib==1.3.2 | ||
jsonlines==3.1.0 | ||
jsonschema==2.6.0 | ||
kiwisolver==1.4.5 | ||
langchain==0.0.144 | ||
Levenshtein==0.24.0 | ||
libcst==1.1.0 | ||
liger-kernel==0.0.0 | ||
lion-pytorch==0.1.2 | ||
llama-cpp-python==0.1.36 | ||
llvmlite==0.40.1 | ||
local-attention==1.9.0 | ||
loguru==0.7.0 | ||
Mako==1.3.2 | ||
Markdown==3.5.2 | ||
markdown-it-py==3.0.0 | ||
markdown2==2.4.10 | ||
MarkupSafe==2.1.2 | ||
marshmallow==3.19.0 | ||
marshmallow-enum==1.5.1 | ||
matplotlib==3.8.2 | ||
mccabe==0.7.0 | ||
mdurl==0.1.2 | ||
MEGABYTE-pytorch==0.0.7 | ||
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit | ||
mlflow==2.10.0 | ||
modal==0.62.77 | ||
more-itertools==10.2.0 | ||
mpmath==1.2.1 | ||
msgpack==1.0.7 | ||
msgpack-numpy-opentensor==0.5.0 | ||
multidict==6.0.4 | ||
multiprocess==0.70.14 | ||
munch==2.5.0 | ||
mypy==1.3.0 | ||
mypy-extensions==1.0.0 | ||
nest-asyncio==1.6.0 | ||
netaddr==0.10.1 | ||
networkx==3.0rc1 | ||
nh3==0.2.14 | ||
nodeenv==1.8.0 | ||
nomic==2.0.2 | ||
numba==0.57.1 | ||
numexpr==2.8.4 | ||
numpy==1.24.4 | ||
oauthlib==3.2.2 | ||
openai==0.27.4 | ||
openapi==1.1.0 | ||
openapi-schema-pydantic==1.2.4 | ||
optimum==1.8.6 | ||
orjson==3.10.7 | ||
packaging==23.1 | ||
pandas==2.0.0 | ||
parameterized==0.9.0 | ||
password-strength==0.0.3.post2 | ||
pastel==0.1.1 | ||
pathos==0.3.0 | ||
pathspec==0.11.1 | ||
pathtools==0.1.2 | ||
peft==0.11.1 | ||
pendulum==3.0.0 | ||
Pillow==9.5.0 | ||
pip-tools==1.11.0 | ||
platformdirs==3.2.0 | ||
pluggy==1.4.0 | ||
poetry==0.7.1 | ||
pox==0.3.2 | ||
ppft==1.7.6.6 | ||
pre-commit==3.3.2 | ||
prettytable==3.10.0 | ||
prompt-toolkit==3.0.39 | ||
protobuf==3.20.2 | ||
protobuf3-to-dict==0.1.5 | ||
psutil==5.9.5 | ||
psycopg==3.1.18 | ||
PuLP==2.8.0 | ||
py==1.11.0 | ||
py-bip39-bindings==0.1.11 | ||
py-cpuinfo==9.0.0 | ||
py-ed25519-zebra-bindings==1.0.1 | ||
py-sr25519-bindings==0.2.0 | ||
pyarrow==11.0.0 | ||
pyasn1==0.6.0 | ||
pycodestyle==2.11.1 | ||
pycparser==2.21 | ||
pycryptodome==3.20.0 | ||
pydantic==2.5.3 | ||
pydantic_core==2.14.6 | ||
pydub==0.25.1 | ||
pyfiglet==0.8.post1 | ||
pyflakes==3.2.0 | ||
Pygments==2.15.1 | ||
PyJWT==2.8.0 | ||
pylev==1.4.0 | ||
PyNaCl==1.5.0 | ||
pynvml==11.5.0 | ||
pyparsing==2.4.7 | ||
pyrsistent==0.14.11 | ||
pytest==8.0.2 | ||
pytest-asyncio==0.23.4 | ||
python-dateutil==2.8.2 | ||
python-dotenv==1.0.1 | ||
python-Levenshtein==0.24.0 | ||
python-multipart==0.0.9 | ||
pytz==2023.3 | ||
PyYAML==6.0.1 | ||
querystring-parser==1.2.4 | ||
rapidfuzz==3.6.1 | ||
regex==2023.6.3 | ||
requests==2.31.0 | ||
requests-toolbelt==0.8.0 | ||
resolvelib==0.8.1 | ||
responses==0.18.0 | ||
retry==0.9.2 | ||
rich==13.7.0 | ||
rsa==4.7.2 | ||
ruff==0.6.3 | ||
s3transfer==0.10.1 | ||
safetensors==0.4.5 | ||
sagemaker==2.148.0 | ||
scalecodec==1.2.7 | ||
schedulefree==1.2.1 | ||
schema==0.7.5 | ||
scikit-learn==1.4.0 | ||
scipy==1.9.3 | ||
seaborn==0.13.2 | ||
semantic-version==2.10.0 | ||
sentencepiece==0.2.0 | ||
sentry-sdk==1.19.1 | ||
setproctitle==1.3.2 | ||
shellingham==1.5.4 | ||
shortuuid==1.0.11 | ||
shtab==1.6.5 | ||
sigtools==4.0.1 | ||
six==1.16.0 | ||
skypilot==0.4.1 | ||
smdebug-rulesconfig==1.0.1 | ||
smmap==5.0.0 | ||
sniffio==1.3.0 | ||
SQLAlchemy==1.4.47 | ||
sqlparse==0.4.4 | ||
starlette==0.36.3 | ||
substrate-interface==1.5.2 | ||
svgwrite==1.4.3 | ||
sympy==1.11.1 | ||
synchronicity==0.6.7 | ||
tabulate==0.9.0 | ||
tblib==1.7.0 | ||
tenacity==8.2.2 | ||
tensor-parallel==2.0.0 | ||
termcolor==2.2.0 | ||
text2art==0.2.0 | ||
threadpoolctl==3.2.0 | ||
tiktoken==0.6.0 | ||
time-machine==2.14.1 | ||
timm==0.9.16 | ||
tokenizers==0.19.1 | ||
tokenmonster==1.1.12 | ||
toml==0.9.6 | ||
tomli==2.0.1 | ||
tomlkit==0.12.0 | ||
toolz==0.12.1 | ||
torch==2.2.0 | ||
torchdata==0.6.1 | ||
torchdiffeq==0.2.3 | ||
TorchFix==0.4.0 | ||
torchtext==0.15.2 | ||
torchvision==0.17.0 | ||
tqdm==4.66.2 | ||
transformers==4.44.2 | ||
trl==0.9.6 | ||
typer==0.12.5 | ||
types-certifi==2021.10.8.3 | ||
types-requests==2.31.0.20240125 | ||
types-setuptools==69.0.0.20240125 | ||
types-toml==0.10.8.7 | ||
typing==3.7.4.3 | ||
typing-inspect==0.8.0 | ||
typing_extensions==4.9.0 | ||
tyro==0.5.18 | ||
tzdata==2023.3 | ||
unique-names-generator==1.0.2 | ||
urllib3==2.2.2 | ||
uvicorn==0.22.0 | ||
vector_quantize_pytorch==1.14.1 | ||
virtualenv==20.23.0 | ||
voyager==2.0.2 | ||
wandb==0.16.2 | ||
watchfiles==0.21.0 | ||
wavedrom==2.0.3.post3 | ||
wcwidth==0.2.6 | ||
websocket-client==1.7.0 | ||
websockets==12.0 | ||
Werkzeug==3.0.1 | ||
wonderwords==2.2.0 | ||
xxhash==3.2.0 | ||
yarl==1.8.2 | ||
zetascale==2.2.7 | ||
zipp==3.15.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
ChatML transformation functions for MessageContents | ||
""" | ||
from typing import Optional | ||
|
||
from ..messages import MessageContents, Messages | ||
from .shared import wrap_tools | ||
|
||
|
||
def format_message(
    message: Messages,
    message_index: Optional[int] = None,  # pylint: disable=unused-argument
) -> Messages:
    """
    Wrap a message's content in ChatML role markers.

    A message whose ``is_chat_formatted`` flag is already set is returned
    untouched. Otherwise ``message.content`` is modified in place: a
    ``<|im_start|>{role}\\n`` header is placed in front, and an ``<|im_end|>``
    marker followed by a newline is added at the end. The header and the
    trailing newline get weight 0; the end marker takes the message's own
    weight. The message is then handed to ``wrap_tools`` and the returned
    object is flagged as chat formatted.

    Args:
        message: the message to format; its content list is mutated.
        message_index: accepted but not used by this function.

    Returns:
        The object returned by ``wrap_tools`` with ``is_chat_formatted`` set,
        or the input message itself when it was already formatted.
    """
    if message.is_chat_formatted:
        return message

    role_header = MessageContents(
        type="text",
        value=f"<|im_start|>{message.role}\n",
        weight=0,
    )
    end_marker = MessageContents(
        type="text", value="<|im_end|>", weight=message.weight
    )
    trailing_newline = MessageContents(type="text", value="\n", weight=0)

    # the role header goes first, the closing pieces go last, in this order
    message.content.insert(0, role_header)
    for closing_piece in (end_marker, trailing_newline):
        message.content.append(closing_piece)

    formatted = wrap_tools(message)
    formatted.is_chat_formatted = True
    return formatted
Oops, something went wrong.