Commit 4d0950b

continue

CarloLucibello committed Jan 24, 2025
1 parent 5f2105b

Showing 3 changed files with 33 additions and 1 deletion.
3 changes: 3 additions & 0 deletions examples/Transformer_WikiText/Project.toml
@@ -1,8 +1,11 @@
 [deps]
 AutoStructs = "2e0df379-9877-4907-ab94-cd881f8d985b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+HuggingFaceDatasets = "d94b9a45-fdf5-4270-b024-5cbb9ef7117d"
+HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tsunami = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"
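
To try the example, the environment has to be set up first; a minimal sketch from the Julia Pkg REPL, assuming the example directory above and noting that HuggingFaceTokenizers.jl is unregistered (see the comment at the top of main.jl):

    pkg> activate examples/Transformer_WikiText
    pkg> add https://github.com/MurrellGroup/HuggingFaceTokenizers.jl
    pkg> instantiate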
29 changes: 29 additions & 0 deletions examples/Transformer_WikiText/main.jl
@@ -1,3 +1,32 @@
 # Needs
 # pkg> add https://github.com/MurrellGroup/HuggingFaceTokenizers.jl
 using Tsunami, Flux
+import HuggingFaceTokenizers as HFT
+import HuggingFaceDatasets as HFD
+using PythonCall: PyList
+
+include("model.jl") # Transformer
+
+function get_dataset()
+    dataset = HFD.load_dataset("wikitext", "wikitext-2-raw-v1")
+    train_dataset = dataset["train"]
+    val_dataset = dataset["validation"]
+    test_dataset = dataset["test"]
+    return train_dataset, val_dataset, test_dataset
+end
+
+function get_tokenizer(train_dataset)
+    # we follow https://huggingface.co/docs/tokenizers/quicktour
+    bpe = HFT.tokenizers.models.BPE(unk_token="[UNK]")
+    tokenizer = HFT.tokenizers.Tokenizer(bpe) # python tokenizer
+    pre_tok = HFT.tokenizers.pre_tokenizers.Whitespace()
+    tokenizer.pre_tokenizer = pre_tok
+    trainer = HFT.tokenizers.trainers.BpeTrainer(special_tokens=PyList(["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]))
+    tokenizer.train_from_iterator(train_dataset["text"], trainer=trainer)
+    return tokenizer
+    # return HFT.Tokenizer(tokenizer) # return the julia wrapper
+end
+
+model = Transformer()
+train_dataset, val_dataset, test_dataset = get_dataset()
+tokenizer = get_tokenizer(train_dataset)
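
Since get_tokenizer returns the raw Python object, the trained tokenizer can be smoke-tested directly through PythonCall; a minimal sketch, following the HuggingFace quicktour linked above (the input sentence is just an illustration):

    output = tokenizer.encode("Hello, y'all! How are you?")
    println(output.tokens) # the BPE tokens, as a Python list
    println(output.ids)    # the corresponding vocabulary ids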
2 changes: 1 addition & 1 deletion examples/Transformer_WikiText/model.jl
@@ -50,7 +50,7 @@ Flux.@layer TransformerBlock

 function (block::TransformerBlock)(x, mask)
     x = x .+ block.mha(block.mha_norm(x); mask)[1]
-    x = x .+ block.ffwd(block.ffwd_norm(x))[1]
+    x = x .+ block.ffwd(block.ffwd_norm(x))
     return x
 end
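
The one-line fix above removes a stray [1]: Flux's MultiHeadAttention returns an (output, attention_scores) tuple, so indexing the first element is needed for the attention sub-layer, but the feed-forward sub-layer returns a plain array, where [1] grabs a single scalar that then broadcasts silently into x. A minimal sketch of the distinction, assuming a Chain-based feed-forward (the actual definitions sit earlier in model.jl and are not shown here):

    using Flux
    mha = MultiHeadAttention(64; nheads = 4)
    ffwd = Chain(Dense(64 => 256, gelu), Dense(256 => 64))
    x = rand(Float32, 64, 10, 1)  # (features, sequence length, batch)
    y = mha(x)[1]                 # self-attention: keep the output, drop the attention scores
    z = ffwd(x)                   # plain array; ffwd(x)[1] would be a lone Float32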

