Skip to content

Commit

Permalink
Support v7. (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan authored Dec 11, 2024
1 parent 5d5960e commit e59211a
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 25 deletions.
33 changes: 28 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ keywords = ["LLM", "deep-learning", "model", "rwkv"]
license = "MIT OR Apache-2.0"
repository = "https://github.com/cgisky1980/ai00_rwkv_server"
rust-version = "1.76"
version = "0.5.10"
version = "0.5.11"

[workspace.dependencies]
anyhow = "1"
Expand All @@ -35,7 +35,7 @@ path = "crates/ai00-core"
# path = "../web-rwkv"
default-features = false
features = ["native"]
version = "0.9.4"
version = "0.9.5"

[patch.crates-io]
hf-hub = { git = "https://github.com/cgisky1980/hf-hub.git", branch = "main" }
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ It only supports Safetensors models with the `.st` extension now. Models saved w

1. [Download the `.pth` model](https://huggingface.co/BlinkDL)

2. (Recommended) Run the python script `convert2ai00.py` or `convert_safetensors.py`:
2. (Recommended) Run the python script `convert_ai00.py` or `convert_safetensors.py`:

```bash
$ python ./convert2ai00.py --input /path/to/model.pth --output /path/to/model.st
$ python assets/scripts/convert_ai00.py --input /path/to/model.pth --output /path/to/model.st
```

Requirements: Python, with `torch` and `safetensors` installed.
Expand Down
4 changes: 2 additions & 2 deletions README.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@

1. [下载pth模型](https://huggingface.co/BlinkDL)

2. 克隆或下载本仓库下[convert2ai00.py](./convert2ai00.py)或[convert_safetensors.py](./convert_safetensors.py)程序,并安装相应的依赖库(`torch``safetensors`
2. 克隆或下载本仓库下[convert_ai00.py](./assets/scripts/convert_ai00.py)或[convert_safetensors.py](./assets/scripts/convert_safetensors.py)程序，并安装相应的依赖库（`torch`、`safetensors`）

3. 运行上述程序,并指定输入输出路径

```bash
$ python convert_safetensors.py --input ./filename.pth --output ./filename.st
$ python assets/scripts/convert_safetensors.py --input ./filename.pth --output ./filename.st
```

4. 如果你不想安装 Python 或 Torch,可以前往[`web-rwkv`](https://github.com/cryscan/web-rwkv/releases)并下载不依赖于 Python 或 Torch 的转换器`web-rwkv-converter`
Expand Down
16 changes: 7 additions & 9 deletions convert2ai00.py → assets/scripts/convert_ai00.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,13 @@ def get_sha(file_path):

print(f"正在转换模型: {model_info}")

convert_file(
args.input,
args.output,
rename={"time_faaaa": "time_first", "time_maa": "time_mix",
"lora_A": "lora.0", "lora_B": "lora.1"},
transpose_names=["time_mix_w1", "time_mix_w2",
"time_decay_w1", "time_decay_w2", "time_state", "lora.0"],
model_info=model_info
)
convert_file(args.input, args.output,
rename={"time_faaaa": "time_first", "time_maa": "time_mix",
"lora_A": "lora.0", "lora_B": "lora.1"},
transpose_names=[
"time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2",
"w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
"time_state", "lora.0"])
print(f"Saved to {args.output}")

print(f"{args.output} __metadata__ :\n")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/python

import collections
import numpy
import os
Expand Down Expand Up @@ -65,6 +67,7 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
if transpose_name in new_k:
dims = len(v.shape)
v = v.transpose(dims - 2, dims - 1)
break
print(f"{new_k}\t{v.shape}\t{v.dtype}")
loaded[new_k] = {
"dtype": str(v.dtype).split(".")[-1],
Expand Down Expand Up @@ -92,5 +95,8 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
convert_file(args.input, args.output,
rename={"time_faaaa": "time_first", "time_maa": "time_mix",
"lora_A": "lora.0", "lora_B": "lora.1"},
transpose_names=["time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2", "time_state", "lora.0"])
transpose_names=[
"time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2",
"w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
"time_state", "lora.0"])
print(f"Saved to {args.output}")
32 changes: 32 additions & 0 deletions assets/scripts/convert_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/python

import json

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
"--input",
type=str,
default="rwkv_vocab_v20230424.txt",
help="Path to input txt")
parser.add_argument(
"--output",
type=str,
default="rwkv_vocab_v20230424.json",
help="Path to output JSON",
)
args = parser.parse_args()


I_TO_TOKEN = {}
lines = open(args.input, "r", encoding="utf-8").readlines()
for l in lines:
idx = int(l[:l.index(' ')])
x = eval(l[l.index(' '):l.rindex(' ')])
if not isinstance(x, str):
x = list(x)
I_TO_TOKEN[idx] = x

out = open(args.output, "w")
out.write(json.dumps(I_TO_TOKEN, indent=4))
9 changes: 5 additions & 4 deletions crates/ai00-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use web_rwkv::{
runtime::{
loader::{Loader, Lora, LoraBlend, Reader},
model::{ContextAutoLimits, EmbedDevice, ModelBuilder, ModelInfo, ModelVersion, Quant},
v4, v5, v6,
v4, v5, v6, v7,
},
tensor::{serialization::Seed, TensorCpu},
tokenizer::Tokenizer,
Expand Down Expand Up @@ -273,7 +273,7 @@ async fn load_init_state<R: Reader>(
ModelVersion::V4 => bail!("v4 does not support init state yet"),
ModelVersion::V5 => v5::read_state(context, info, model).await,
ModelVersion::V6 => v6::read_state(context, info, model).await,
ModelVersion::V7 => bail!("v7 is not supported yet"),
ModelVersion::V7 => v7::read_state(context, info, model).await,
};
state.map_err(Into::into)
}
Expand Down Expand Up @@ -388,7 +388,6 @@ async fn load_runtime(
Runtime::new(context, bundle, reload, states, tokenizer).await
}
)+
(version, _) => bail!("unsupported version: {:?}", version)
}
}
}
Expand All @@ -398,9 +397,11 @@ async fn load_runtime(
(ModelVersion::V4, Precision::Fp16, v4::Model, build_v4, v4::Bundle::<f16>),
(ModelVersion::V5, Precision::Fp16, v5::Model, build_v5, v5::Bundle::<f16>),
(ModelVersion::V6, Precision::Fp16, v6::Model, build_v6, v6::Bundle::<f16>),
(ModelVersion::V7, Precision::Fp16, v7::Model, build_v7, v7::Bundle::<f16>),
(ModelVersion::V4, Precision::Fp32, v4::Model, build_v4, v4::Bundle::<f32>),
(ModelVersion::V5, Precision::Fp32, v5::Model, build_v5, v5::Bundle::<f32>),
(ModelVersion::V6, Precision::Fp32, v6::Model, build_v6, v6::Bundle::<f32>)
(ModelVersion::V6, Precision::Fp32, v6::Model, build_v6, v6::Bundle::<f32>),
(ModelVersion::V7, Precision::Fp32, v7::Model, build_v7, v7::Bundle::<f32>)
}
)
}
Expand Down

0 comments on commit e59211a

Please sign in to comment.