add bunch of cleanups and design principle section (#71)
wanchaol authored Feb 23, 2024
1 parent 78878f5 commit b4cda94
Showing 4 changed files with 8 additions and 84 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -4,6 +4,14 @@ Note: This repository is currently under heavy development.

torchtrain contains PyTorch native parallelisms, tools and utilities to train large models.

## Design Principles

TorchTrain is a native PyTorch library that provides various techniques for training large models. While it leverages the PyTorch ecosystem for components such as data loading (e.g. HuggingFace datasets), the core functionality is written in PyTorch.

* Designed to be easy to understand, use, and extend for different training purposes.
* Minimal changes to the model code when applying 1D, 2D, or 3D parallelisms (a sketch follows this list).
* Modular components instead of a monolithic codebase.
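
The "minimal changes" principle can be illustrated with a small hypothetical sketch (not TorchTrain's actual code): the model stays an ordinary `nn.Module`, and parallelism is applied from the outside, here with PyTorch's FSDP wrapper.

```python
# Hypothetical sketch: the model definition contains no parallelism logic;
# data parallelism is applied by wrapping the module after construction.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class ToyModel(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))


def apply_data_parallel(model: nn.Module) -> nn.Module:
    # Assumes torch.distributed has been initialized (e.g. launched via torchrun).
    return FSDP(model)
```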

# Installation

Install PyTorch from source or install the latest PyTorch nightly, then install the requirements by
1 change: 0 additions & 1 deletion torchtrain/datasets/__init__.py
@@ -2,7 +2,6 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from torchtrain.datasets.alpaca import build_alpaca_data_loader
from torchtrain.datasets.pad_batch_sequence import pad_batch_to_longest_seq
from torchtrain.datasets.tokenizer import create_tokenizer

__all__ = ["build_alpaca_data_loader", "create_tokenizer", "pad_batch_to_longest_seq"]
77 changes: 0 additions & 77 deletions torchtrain/datasets/pad_batch_sequence.py

This file was deleted.
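
For context, the deleted file defined the `pad_batch_to_longest_seq` helper exported above. Below is a generic sketch of what such a pad-to-longest helper typically looks like; it is not the deleted file's actual contents.

```python
# Generic sketch of a pad-to-longest helper; not the deleted file's actual code.
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_batch_to_longest_seq(
    batch: List[torch.Tensor], padding_value: int = 0
) -> torch.Tensor:
    # Stacks 1-D token tensors into a (batch, max_len) tensor, filling the
    # tail of shorter sequences with `padding_value`.
    return pad_sequence(batch, batch_first=True, padding_value=padding_value)
```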

6 changes: 0 additions & 6 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -130,12 +130,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward.w3": ColwiseParallel(),
}
# if layer_id == 0:
# # in first transformer block we need to shard the input
# layer_plan[""] = PrepareModuleInput(
# input_layouts=(Replicate(), None),
# desired_input_layouts=(Shard(0), None),
# )

# adjust num_heads in attention layer to local heads
attn_layer = transformer_block.attention
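
For readers unfamiliar with the tensor-parallel API used above: a `layer_plan` like this is applied to an otherwise unmodified transformer block with `parallelize_module`, and the deleted commented-out `PrepareModuleInput` lines were one way to reshard the first block's input. The snippet below is a hedged sketch of that pattern under an assumed mesh size and toy module shapes; it is not the exact code of `parallelize_llama.py`.

```python
# Hedged sketch of applying a tensor-parallel layer plan with
# torch.distributed.tensor.parallel; not the exact parallelize_llama.py code.
# Assumes torch.distributed is initialized (e.g. via torchrun) and that the
# toy block's submodule names match the keys in the plan.
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class FeedForward(nn.Module):
    def __init__(self, dim: int = 256, hidden_dim: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.feed_forward = FeedForward()

    def forward(self, x):
        return x + self.feed_forward(x)


tp_mesh = init_device_mesh("cuda", (8,))  # hypothetical 8-way tensor parallelism

layer_plan = {
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
    "feed_forward.w3": ColwiseParallel(),
}

# Shards the listed nn.Linear submodules across tp_mesh; the block's forward()
# is untouched, which is the "minimal model changes" principle from the README.
block = parallelize_module(ToyBlock(), tp_mesh, layer_plan)
```

The plan keys are ordinary submodule paths, which is why the model code itself needs no changes.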
