diff --git a/Cargo.lock b/Cargo.lock index a5d9ea898..6e0630435 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -58,6 +58,15 @@ dependencies = [ "libc", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anstream" version = "0.6.12" @@ -233,13 +242,46 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.3.4", "bitflags 1.3.2", "bytes", "futures-util", - "http", - "http-body", - "hyper", + "http 0.2.9", + "http-body 0.4.5", + "hyper 0.14.27", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 0.1.2", + "tokio", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +dependencies = [ + "async-trait", + "axum-core 0.4.5", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.5.0", + "hyper-util", "itoa", "matchit", "memchr", @@ -251,11 +293,12 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 1.0.1", "tokio", "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -267,28 +310,51 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 0.2.9", + "http-body 0.4.5", "mime", "rustversion", "tower-layer", "tower-service", ] +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.1", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-tracing-opentelemetry" -version = "0.10.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "164b95427e83b79583c7699a72b4a6b485a12bbdef5b5c054ee5ff2296d82f52" +checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08" dependencies = [ - "axum", - "futures", - "http", - "opentelemetry 0.18.0", + "axum 0.7.5", + "futures-core", + "futures-util", + "http 1.1.0", + "opentelemetry 0.21.0", + "pin-project-lite", "tower", - "tower-http 0.3.5", "tracing", - "tracing-opentelemetry 0.18.0", + "tracing-opentelemetry 0.22.0", + "tracing-opentelemetry-instrumentation-sdk", ] [[package]] @@ -719,33 +785,13 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys 0.3.7", -] - [[package]] name = "dirs" version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" dependencies = [ - "dirs-sys 0.4.1", -] - -[[package]] -name = "dirs-sys" 
-version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", + "dirs-sys", ] [[package]] @@ -824,7 +870,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" dependencies = [ "bit_field", - "flume 0.11.0", + "flume", "half", "lebe", "miniz_oxide", @@ -875,22 +921,9 @@ checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" [[package]] name = "flume" -version = "0.10.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" -dependencies = [ - "futures-core", - "futures-sink", - "nanorand", - "pin-project", - "spin 0.9.8", -] - -[[package]] -name = "flume" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ "spin 0.9.8", ] @@ -1062,6 +1095,12 @@ version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "grpc-metadata" version = "0.1.0" @@ -1083,7 +1122,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 2.0.0", "slab", "tokio", @@ -1146,7 +1185,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" dependencies = [ - "dirs 5.0.1", + "dirs", "futures", "indicatif", "log", @@ -1183,6 +1222,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -1190,15 +1240,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] [[package]] -name = "http-range-header" -version = "0.3.0" +name = "http-body" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.1.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "pin-project-lite", +] [[package]] name = "httparse" @@ -1223,8 +1290,8 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", + "http 0.2.9", + "http-body 0.4.5", "httparse", "httpdate", "itoa", @@ -1236,13 +1303,32 @@ 
dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + [[package]] name = "hyper-timeout" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper", + "hyper 0.14.27", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1255,12 +1341,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.27", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-util" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "hyper 1.5.0", + "pin-project-lite", + "tokio", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -1373,6 +1474,19 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "init-tracing-opentelemetry" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17" +dependencies = [ + "opentelemetry 0.20.0", + "opentelemetry-otlp 0.13.0", + "thiserror", + "tracing", + "tracing-opentelemetry 0.21.0", +] + [[package]] name = "instant" version = "0.1.12" @@ -1550,6 +1664,7 @@ dependencies = [ "ctrlc", "float_eq", "h2", + "hf-hub", "nix", "openssl", "reqwest", @@ -1557,7 +1672,7 @@ dependencies = [ "serde", "serde_json", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.17", "vergen", ] @@ -1567,15 +1682,15 @@ version = "0.1.0" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", "clap", - "flume 0.10.14", "futures", "h2", "hf-hub", "image", + "init-tracing-opentelemetry", "itertools 0.12.1", "lorax-client", "metrics", @@ -1587,7 +1702,7 @@ dependencies = [ "once_cell", "openssl", "opentelemetry 0.19.0", - "opentelemetry-otlp", + "opentelemetry-otlp 0.12.0", "rand", "regex", "reqwest", @@ -1601,10 +1716,12 @@ dependencies = [ "thiserror", "tokenizers", "tokio", - "tower-http 0.4.1", + "tokio-stream", + "tower-http", "tracing", "tracing-opentelemetry 0.19.0", - "tracing-subscriber", + "tracing-subscriber 0.3.17", + "tracing-test", "utoipa", "utoipa-swagger-ui", "vergen", @@ -1642,6 +1759,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" +[[package]] +name = "matchers" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1708,7 +1834,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a4964177ddfdab1e3a2b37aec7cf320e14169abb0ed73999f558136409178d5" dependencies = [ "base64 0.21.2", - "hyper", + "hyper 0.14.27", 
"indexmap 1.9.3", "ipnet", "metrics", @@ -1854,15 +1980,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom", -] - [[package]] name = "native-tls" version = "0.2.11" @@ -1897,12 +2014,12 @@ dependencies = [ "async-rustls", "async-trait", "awaitdrop", - "axum", + "axum 0.6.18", "base64 0.13.1", "bytes", "futures", "hostname", - "hyper", + "hyper 0.14.27", "muxado", "once_cell", "parking_lot 0.12.1", @@ -2131,22 +2248,38 @@ dependencies = [ [[package]] name = "opentelemetry" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e" +checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f" dependencies = [ - "opentelemetry_api 0.18.0", - "opentelemetry_sdk 0.18.0", + "opentelemetry_api 0.19.0", + "opentelemetry_sdk 0.19.0", ] [[package]] name = "opentelemetry" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f" +checksum = "9591d937bc0e6d2feb6f71a559540ab300ea49955229c347a517a28d27784c54" dependencies = [ - "opentelemetry_api 0.19.0", - "opentelemetry_sdk 0.19.0", + "opentelemetry_api 0.20.0", + "opentelemetry_sdk 0.20.0", +] + +[[package]] +name = "opentelemetry" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" +dependencies = [ + "futures-core", + "futures-sink", + "indexmap 2.0.0", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", + "urlencoding", ] [[package]] @@ -2158,15 +2291,34 @@ dependencies = [ "async-trait", "futures", "futures-util", - "http", + "http 0.2.9", "opentelemetry 0.19.0", - "opentelemetry-proto", + "opentelemetry-proto 0.2.0", "prost", "thiserror", "tokio", "tonic 0.8.3", ] +[[package]] +name = "opentelemetry-otlp" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275" +dependencies = [ + "async-trait", + "futures-core", + "http 0.2.9", + "opentelemetry-proto 0.3.0", + "opentelemetry-semantic-conventions", + "opentelemetry_api 0.20.0", + "opentelemetry_sdk 0.20.0", + "prost", + "thiserror", + "tokio", + "tonic 0.9.2", +] + [[package]] name = "opentelemetry-proto" version = "0.2.0" @@ -2180,32 +2332,53 @@ dependencies = [ "tonic 0.8.3", ] +[[package]] +name = "opentelemetry-proto" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb" +dependencies = [ + "opentelemetry_api 0.20.0", + "opentelemetry_sdk 0.20.0", + "prost", + "tonic 0.9.2", +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73c9f9340ad135068800e7f1b24e9e09ed9e7143f5bf8518ded3d3ec69789269" +dependencies = [ + "opentelemetry 0.20.0", +] + [[package]] name = "opentelemetry_api" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22" +checksum = 
"ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2" dependencies = [ "fnv", "futures-channel", "futures-util", "indexmap 1.9.3", - "js-sys", "once_cell", "pin-project-lite", "thiserror", + "urlencoding", ] [[package]] name = "opentelemetry_api" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2" +checksum = "8a81f725323db1b1206ca3da8bb19874bbd3f57c3bcd59471bfb04525b265b9b" dependencies = [ - "fnv", "futures-channel", "futures-util", "indexmap 1.9.3", + "js-sys", "once_cell", "pin-project-lite", "thiserror", @@ -2214,9 +2387,9 @@ dependencies = [ [[package]] name = "opentelemetry_sdk" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113" +checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1" dependencies = [ "async-trait", "crossbeam-channel", @@ -2226,7 +2399,7 @@ dependencies = [ "futures-executor", "futures-util", "once_cell", - "opentelemetry_api 0.18.0", + "opentelemetry_api 0.19.0", "percent-encoding", "rand", "thiserror", @@ -2236,32 +2409,71 @@ dependencies = [ [[package]] name = "opentelemetry_sdk" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1" +checksum = "fa8e705a0612d48139799fcbaba0d4a90f06277153e43dd2bdc16c6f0edd8026" dependencies = [ "async-trait", "crossbeam-channel", - "dashmap", - "fnv", "futures-channel", "futures-executor", "futures-util", "once_cell", - "opentelemetry_api 0.19.0", + "opentelemetry_api 0.20.0", + "ordered-float 3.9.2", "percent-encoding", "rand", + "regex", + "serde_json", "thiserror", "tokio", "tokio-stream", ] +[[package]] +name = "opentelemetry_sdk" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f16aec8a98a457a52664d69e0091bac3a0abd18ead9b641cb00202ba4e0efe4" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry 0.21.0", + "ordered-float 4.4.0", + "percent-encoding", + "rand", + "thiserror", +] + [[package]] name = "option-ext" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "3.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1e1c390732d15f1d48471625cd92d154e66db2c56645e29a9cd26f4699f72dc" +dependencies = [ + "num-traits", +] + +[[package]] +name = "ordered-float" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e7ccb95e240b7c9506a3d544f10d935e142cc90b0a1d56954fb44d89ad6b97" +dependencies = [ + "num-traits", +] + [[package]] name = "overload" version = "0.1.1" @@ -2761,9 +2973,9 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.9", + "http-body 0.4.5", + "hyper 0.14.27", "hyper-tls", "ipnet", "js-sys", @@ -2795,7 +3007,7 @@ checksum = "88a3e86aa6053e59030e7ce2d2a3b258dd08fc2d337d52f73f6cb480f5858690" dependencies = [ "anyhow", "async-trait", - "http", + "http 0.2.9", "reqwest", "serde", "task-local-extensions", @@ -2813,8 +3025,8 @@ dependencies = [ "chrono", 
"futures", "getrandom", - "http", - "hyper", + "http 0.2.9", + "hyper 0.14.27", "parking_lot 0.11.2", "reqwest", "reqwest-middleware", @@ -2877,9 +3089,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "6.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a36224c3276f8c4ebc8c20f158eca7ca4359c8db89991c4925132aaaf6702661" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -2888,23 +3100,22 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "6.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49b94b81e5b2c284684141a2fb9e2a31be90638caf040bf9afbc5a0416afe1ac" +checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "shellexpand", "syn 2.0.60", "walkdir", ] [[package]] name = "rust-embed-utils" -version = "7.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d38ff6bf570dc3bb7100fce9f7b60c33fa71d80e88da3f2580df4ff2bdded74" +checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d" dependencies = [ "sha2", "walkdir", @@ -3152,15 +3363,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shellexpand" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4" -dependencies = [ - "dirs 4.0.0", -] - [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -3211,9 +3413,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" @@ -3314,6 +3516,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" + [[package]] name = "sysinfo" version = "0.30.13" @@ -3617,15 +3825,15 @@ checksum = "8f219fad3b929bef19b1f86fbc0358d35daed8f2cac972037ac0dc10bbb8d5fb" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.6.18", "base64 0.13.1", "bytes", "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.9", + "http-body 0.4.5", + "hyper 0.14.27", "hyper-timeout", "percent-encoding", "pin-project", @@ -3648,15 +3856,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" dependencies = [ "async-trait", - "axum", + "axum 0.6.18", "base64 0.21.2", "bytes", "futures-core", "futures-util", "h2", - "http", - "http-body", - "hyper", + "http 0.2.9", + "http-body 0.4.5", + "hyper 0.14.27", "hyper-timeout", "percent-encoding", "pin-project", @@ -3704,36 +3912,13 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.3.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858" -dependencies = [ - 
"bitflags 1.3.2", - "bytes", - "futures-core", - "futures-util", - "http", - "http-body", - "http-range-header", - "pin-project-lite", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tower-http" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8bd22a874a2d0b70452d5597b12c537331d49060824a95f49f108994f94aa4c" +checksum = "8437150ab6bbc8c5f0f519e3d5ed4aa883a83dd4cdd3d1b21f9482936046cb97" dependencies = [ "bitflags 2.3.3", "bytes", - "futures-core", - "futures-util", - "http", - "http-body", - "http-range-header", + "http 1.1.0", "pin-project-lite", "tower-layer", "tower-service", @@ -3741,9 +3926,9 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" @@ -3753,11 +3938,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", "log", "pin-project-lite", "tracing-attributes", @@ -3766,9 +3950,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", @@ -3777,9 +3961,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", "valuable", @@ -3807,17 +3991,14 @@ dependencies = [ ] [[package]] -name = "tracing-opentelemetry" -version = "0.18.0" +name = "tracing-log" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ + "log", "once_cell", - "opentelemetry 0.18.0", - "tracing", "tracing-core", - "tracing-log", - "tracing-subscriber", ] [[package]] @@ -3830,8 +4011,54 @@ dependencies = [ "opentelemetry 0.19.0", "tracing", "tracing-core", - "tracing-log", - "tracing-subscriber", + "tracing-log 0.1.3", + "tracing-subscriber 0.3.17", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75327c6b667828ddc28f5e3f169036cb793c3f588d83bf0f262a7f062ffed3c8" +dependencies = [ + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry_sdk 0.20.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.1.3", + "tracing-subscriber 0.3.17", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c67ac25c5407e7b961fafc6f7e9aa5958fd297aada2d20fa2ae1737357e55596" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.21.0", + "opentelemetry_sdk 0.21.2", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber 0.3.17", + "web-time", +] + +[[package]] +name = "tracing-opentelemetry-instrumentation-sdk" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9920abb6a3ee3a2af7d30c9ff02900f8481935d36723c3da95cf807468218e8c" +dependencies = [ + "http 1.1.0", + "opentelemetry 0.21.0", + "tracing", + "tracing-opentelemetry 0.22.0", ] [[package]] @@ -3844,13 +4071,35 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "ansi_term", + "chrono", + "lazy_static", + "matchers 0.0.1", + "regex", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log 0.1.3", + "tracing-serde", +] + [[package]] name = "tracing-subscriber" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" dependencies = [ - "matchers", + "matchers 0.1.0", "nu-ansi-term", "once_cell", "regex", @@ -3861,10 +4110,33 @@ dependencies = [ "thread_local", "tracing", "tracing-core", - "tracing-log", + "tracing-log 0.1.3", "tracing-serde", ] +[[package]] +name = "tracing-test" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3b48778c2d401c6a7fcf38a0e3c55dc8e8e753cbd381044a8cdb6fd69a29f53" +dependencies = [ + "lazy_static", + "tracing-core", + "tracing-subscriber 0.2.25", + "tracing-test-macro", +] + +[[package]] +name = "tracing-test-macro" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c49adbab879d2e0dd7f75edace5f0ac2156939ecb7e6a1e8fa14e53728328c48" +dependencies = [ + "lazy_static", + "quote", + "syn 1.0.109", +] + [[package]] name = "try-lock" version = "0.2.4" @@ -3991,9 +4263,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "utoipa" -version = "3.4.0" +version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "520434cac5c98120177d5cc15be032703f6dca7d5ef82e725c798113b375000a" +checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ "indexmap 2.0.0", "serde", @@ -4003,9 +4275,9 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "3.4.1" +version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e22e88a487b6e0374533871b79b1f5ded05671bd0936bd547eb42f82fb9060d" +checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392" dependencies = [ "proc-macro-error", "proc-macro2", @@ -4016,11 +4288,11 @@ dependencies = [ [[package]] name = "utoipa-swagger-ui" -version = "3.1.4" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4602d7100d3cfd8a086f30494e68532402ab662fa366c9d201d677e33cee138d" +checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da" dependencies = [ - "axum", + "axum 0.7.5", "mime_guess", "regex", "rust-embed", @@ -4198,6 +4470,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" 
+version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa30049b1c872b72c89866d458eae9f20380ab280ffd1b1e18df2d3e2d98cfe0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki" version = "0.22.4" diff --git a/Dockerfile b/Dockerfile index e7aa5a3dd..eccefae58 100644 --- a/Dockerfile +++ b/Dockerfile @@ -216,7 +216,7 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31 RUN pip install einops --no-cache-dir # Install flashinfer -RUN pip install --no-cache-dir flashinfer==0.1.5+cu124torch2.4 -i https://flashinfer.ai/whl/cu124/torch2.4 +RUN pip install --no-cache-dir flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4 # Install server COPY proto proto diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 905df9417..a8028834c 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -1,4 +1,5 @@ import json +import logging import requests from requests.adapters import HTTPAdapter, Retry @@ -20,7 +21,22 @@ from lorax.errors import parse_error import os -LORAX_DEBUG_MODE = os.getenv("LORAD_DEBUG_MODE", None) is not None +LORAX_DEBUG_MODE = os.getenv("LORAX_DEBUG_MODE", None) is not None +if LORAX_DEBUG_MODE: + # https://stackoverflow.com/a/16630836/1869739 + # These two lines enable debugging at httplib level (requests->urllib3->http.client) + # You will see the REQUEST, including HEADERS and DATA, and RESPONSE with HEADERS but without DATA. + # The only thing missing will be the response.body which is not logged. + import http.client as http_client + http_client.HTTPConnection.debuglevel = 1 + + # You must initialize logging, otherwise you'll not see debug output. + logging.basicConfig() + logging.getLogger().setLevel(logging.DEBUG) + requests_log = logging.getLogger("requests.packages.urllib3") + requests_log.setLevel(logging.DEBUG) + requests_log.propagate = True + class Client: """Client to make calls to a LoRAX instance diff --git a/docs/guides/contributing/development_env.md b/docs/guides/contributing/development_env.md index f33c4c2d6..5f35eebb1 100644 --- a/docs/guides/contributing/development_env.md +++ b/docs/guides/contributing/development_env.md @@ -47,12 +47,12 @@ We'll be working out of three different terminals during development, each servi Install development dependencies: ```shell -DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y +DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \ PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ - rm -f $PROTOC_ZIP + rm -f $PROTOC_ZIP && \ hash -r ``` @@ -71,8 +71,7 @@ tmux new -s server From within the `tmux` session, move into the LoRAX `server` directory within the repo (assumed to be in `/data/lorax`) and install dependencies: ```shell -cd /data/lorax/server -pip install -e . +cd /data/lorax/server && pip install -e . 
make gen-server ``` @@ -95,9 +94,9 @@ tmux new -s router Now move into the `router` directory within the repo and install dependencies: ```shell -cd /data/lorax/router -curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y -export PATH=$PATH:$HOME/.cargo/bin +cd /data/lorax/router && \ +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \ +export PATH=$PATH:$HOME/.cargo/bin && \ touch ../proto/generate.proto ``` diff --git a/docs/guides/contributing/index.md b/docs/guides/contributing/index.md index 5c78b97f4..5d9eeef27 100644 --- a/docs/guides/contributing/index.md +++ b/docs/guides/contributing/index.md @@ -23,3 +23,22 @@ make export-requirements ``` Never modify `requirements.txt` directly, as it may introduce dependency conflicts. + +## Profiling + +LoRAX supports the [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) to measure performance of LoRAX. + +You can enable profiling when launching LoRAX by setting the `LORAX_PROFILER_DIR` environment variable to the directory +you wish to output the Tensorboard traces to. + +Once initialized, LoRAX will begin recording traces for every request to the server. Because traces can get very large, +we record only the first 10 prefill requests (plus any decode requests between them), then stop recording and write +out the results. A summary will be printed to stdout when this occurs. + +Once you have your traces written to the profiler directory, you can visualize them in Tensorboard using the +[PyTorch Profiler Tensorboard Plugin](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). + +```bash +pip install torch_tb_profiler +tensorboard --logdir=$LORAX_PROFILER_DIR +``` diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index cd1f3ef2c..efaf4e0dd 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -11,6 +11,7 @@ clap = { version = "4.1.4", features = ["derive", "env"] } ctrlc = { version = "3.2.5", features = ["termination"] } nix = "0.26.2" openssl = "0.10.66" +hf-hub = { version = "0.3.0", features = ["tokio"] } h2 = "0.3.26" rustix = "0.37.25" serde = { version = "1.0.152", features = ["derive"] } diff --git a/launcher/src/main.rs b/launcher/src/main.rs index bac9f04a8..26a2e4aa9 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,5 +1,8 @@ use clap::{Parser, ValueEnum}; -use nix::libc::ip_mreq_source; +use hf_hub::{ + api::sync::{Api, ApiBuilder}, + Repo, RepoType, +}; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; use serde::Deserialize; @@ -20,34 +23,179 @@ use tracing_subscriber::EnvFilter; mod env_runtime; -#[derive(Clone, Copy, Debug, ValueEnum)] +fn get_config( + model_id: &str, + revision: &Option, +) -> Result> { + let mut path = std::path::Path::new(model_id).to_path_buf(); + let model_id = model_id.to_string(); + let filename = if !path.exists() { + // Assume it's a hub id + + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; + let repo = if let Some(ref revision) = revision { + api.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? 
+ } else { + path.push("config.json"); + path + }; + + let content = std::fs::read_to_string(filename)?; + let config: RawConfig = serde_json::from_str(&content)?; + + let config: Config = config.into(); + Ok(config) +} + +#[derive(Deserialize)] +struct RawConfig { + max_position_embeddings: Option, + n_positions: Option, + model_type: Option, + max_seq_len: Option, + quantization_config: Option, + n_embd: Option, + hidden_size: Option, + num_attention_heads: Option, + head_dim: Option, + vision_config: Option, + is_encoder_decoder: Option, +} + +#[derive(Deserialize)] +struct QuantizationConfig { + quant_method: Option, +} + +#[derive(Deserialize)] +struct VisionConfig {} + +#[derive(Deserialize)] +struct Config { + max_position_embeddings: Option, + quantize: Option, + head_dim: Option, + model_type: Option, + vision_config: Option, + is_encoder_decoder: bool, +} + +impl From for Config { + fn from(other: RawConfig) -> Self { + let max_position_embeddings = other + .max_position_embeddings + .or(other.max_seq_len) + .or(other.n_positions); + let quantize = other.quantization_config.and_then(|q| q.quant_method); + let head_dim = other.head_dim.or_else(|| { + match (other.hidden_size, other.n_embd, other.num_attention_heads) { + (Some(hidden_size), _, Some(num_attention_heads)) + if hidden_size % num_attention_heads == 0 => + { + Some(hidden_size / num_attention_heads) + } + // Legacy + (_, Some(hidden_size), Some(num_attention_heads)) + if hidden_size % num_attention_heads == 0 => + { + Some(hidden_size / num_attention_heads) + } + _ => None, + } + }); + let model_type = other.model_type; + let vision_config = other.vision_config; + let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false); + Config { + max_position_embeddings, + quantize, + head_dim, + model_type, + vision_config, + is_encoder_decoder, + } + } +} + +#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)] +#[serde(rename_all = "kebab-case")] enum Quantization { - Bitsandbytes, - BitsandbytesNF4, - BitsandbytesFP4, - Gptq, + /// 4 bit quantization. Requires a specific AWQ quantized model: + /// . + /// Should replace GPTQ models wherever possible because of the better latency Awq, + /// 8 bit quantization, doesn't require a specific model. + /// Should be a drop-in replacement to bitsandbytes with much better performance. + /// Kernels are from Eetq, + /// Variable bit quantization. Requires a specific EXL2 quantized model: + /// . Requires exllama2 kernels and does + /// not support tensor parallelism (num_shard > 1). + Exl2, + /// 4 bit quantization. Requires a specific GPTQ quantized model: . + /// LoRAX will use exllama (faster) kernels wherever possible, and use + /// triton kernel (wider support) when it's not. + /// AWQ has faster kernels. + Gptq, + /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, + /// but it is known that the model will be much slower to run than the native f16. + // #[deprecated( + // since = "1.1.0", + // note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" + // )] + Bitsandbytes, + /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, + /// but it is known that the model will be much slower to run than the native f16. + BitsandbytesNf4, + /// Bitsandbytes 4bit. 
nf4 should be preferred in most cases but maybe this one has better + /// perplexity performance for you model + BitsandbytesFp4, + /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above + /// This dtype has native ops should be the fastest if available. + /// This is currently not the fastest because of local unpacking + padding to satisfy matrix + /// multiplication limitations. + Fp8, + /// FP8 with statically quantized KV cache + Fp8_KV, + /// 4 bit quantization. Requires a specific HQQ quantized model. Hqq_4bit, + /// 3 bit quantization. Requires a specific HQQ quantized model. Hqq_3bit, + /// 2 bit quantization. Requires a specific HQQ quantized model. Hqq_2bit, - Fp8, - Fp8_KV, } impl std::fmt::Display for Quantization { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // To keep in track with `server`. match self { + #[allow(deprecated)] + // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases Quantization::Bitsandbytes => { write!(f, "bitsandbytes") } - Quantization::BitsandbytesNF4 => { + Quantization::BitsandbytesNf4 => { write!(f, "bitsandbytes-nf4") } - Quantization::BitsandbytesFP4 => { + Quantization::BitsandbytesFp4 => { write!(f, "bitsandbytes-fp4") } + Quantization::Exl2 => { + write!(f, "exl2") + } Quantization::Gptq => { write!(f, "gptq") } @@ -57,6 +205,12 @@ impl std::fmt::Display for Quantization { Quantization::Eetq => { write!(f, "eetq") } + Quantization::Fp8 => { + write!(f, "fp8") + } + Quantization::Fp8_KV => { + write!(f, "fp8-kv") + } Quantization::Hqq_4bit => { write!(f, "hqq-4bit") } @@ -66,12 +220,6 @@ impl std::fmt::Display for Quantization { Quantization::Hqq_2bit => { write!(f, "hqq-2bit") } - Quantization::Fp8 => { - write!(f, "fp8") - } - Quantization::Fp8_KV => { - write!(f, "fp8-kv") - } } } } @@ -254,8 +402,9 @@ struct Args { /// for users. The larger this value, the longer prompt users can send which /// can impact the overall memory required to handle the load. /// Please note that some models have a finite range of sequence they can handle. - #[clap(default_value = "1024", long, env)] - max_input_length: usize, + /// Default to min(max_position_embeddings - 1, 4095) + #[clap(long, env)] + max_input_length: Option, /// This is the most important value to set as it defines the "memory budget" /// of running clients requests. @@ -265,8 +414,9 @@ struct Args { /// `1511` max_new_tokens. /// The larger this value, the larger amount each request will be in your RAM /// and the less effective batching can be. - #[clap(default_value = "2048", long, env)] - max_total_tokens: usize, + /// Default to min(max_position_embeddings, 4096) + #[clap(long, env)] + max_total_tokens: Option, /// This represents the ratio of waiting queries vs running queries where /// you want to start considering pausing the running queries to include the waiting @@ -284,8 +434,9 @@ struct Args { /// Limits the number of tokens for the prefill operation. /// Since this operation take the most memory and is compute bound, it is interesting /// to limit the number of requests that can be sent. - #[clap(default_value = "4096", long, env)] - max_batch_prefill_tokens: u32, + /// Default to `max_input_tokens + 50` to give a bit of room. + #[clap(long, env)] + max_batch_prefill_tokens: Option, /// **IMPORTANT** This is one critical control to allow maximum usage /// of the available hardware. 
@@ -1182,6 +1333,9 @@ fn spawn_shards( fn spawn_webserver( args: Args, + max_input_tokens: usize, + max_total_tokens: usize, + max_batch_prefill_tokens: u32, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, ) -> Result { @@ -1196,11 +1350,11 @@ fn spawn_webserver( "--max-stop-sequences".to_string(), args.max_stop_sequences.to_string(), "--max-input-length".to_string(), - args.max_input_length.to_string(), + max_input_tokens.to_string(), "--max-total-tokens".to_string(), - args.max_total_tokens.to_string(), + max_total_tokens.to_string(), "--max-batch-prefill-tokens".to_string(), - args.max_batch_prefill_tokens.to_string(), + max_batch_prefill_tokens.to_string(), "--max-active-adapters".to_string(), args.max_active_adapters.to_string(), "--adapter-cycle-time-s".to_string(), @@ -1407,18 +1561,69 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:?}", args); + let config: Option = get_config(&args.model_id, &args.revision).ok(); + let max_default = 4096; + let max_position_embeddings = if let Some(config) = &config { + if let Some(max_position_embeddings) = config.max_position_embeddings { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_length.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + } + max_default + } else { + max_position_embeddings + } + } else { + max_default + } + } else { + max_default + }; + + // Defaults + let max_input_tokens = { + match args.max_input_length { + Some(max_input_tokens) => max_input_tokens, + None => { + let value = max_position_embeddings - 1; + tracing::info!("Default `max_input_tokens` to {value}"); + value + } + } + }; + let max_total_tokens = { + match args.max_total_tokens { + Some(max_total_tokens) => max_total_tokens, + None => { + let value = max_position_embeddings; + tracing::info!("Default `max_total_tokens` to {value}"); + value + } + } + }; + let max_batch_prefill_tokens = { + match args.max_batch_prefill_tokens { + Some(max_batch_prefill_tokens) => max_batch_prefill_tokens, + None => { + // Adding some edge in order to account for potential block_size alignement + // issue. + let value: u32 = (max_input_tokens + 50) as u32; + tracing::info!("Default `max_batch_prefill_tokens` to {value}"); + value + } + } + }; + // Validate args - if args.max_input_length >= args.max_total_tokens { + if max_input_tokens >= max_total_tokens { return Err(LauncherError::ArgumentValidation( "`max_input_length` must be < `max_total_tokens`".to_string(), )); } - if args.max_input_length as u32 > args.max_batch_prefill_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be >= `max_input_length`. 
Given: {} and {}", - args.max_batch_prefill_tokens, args.max_input_length - ))); - } if args.validation_workers == 0 { return Err(LauncherError::ArgumentValidation( @@ -1438,16 +1643,16 @@ fn main() -> Result<(), LauncherError> { } if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { - if args.max_batch_prefill_tokens > *max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { return Err(LauncherError::ArgumentValidation(format!( "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_batch_prefill_tokens, max_batch_total_tokens + max_batch_prefill_tokens, max_batch_total_tokens ))); } - if args.max_total_tokens as u32 > *max_batch_total_tokens { + if max_total_tokens as u32 > *max_batch_total_tokens { return Err(LauncherError::ArgumentValidation(format!( "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_total_tokens, max_batch_total_tokens + max_total_tokens, max_batch_total_tokens ))); } } @@ -1513,11 +1718,18 @@ fn main() -> Result<(), LauncherError> { return Ok(()); } - let mut webserver = - spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { - shutdown_shards(shutdown.clone(), &shutdown_receiver); - err - })?; + let mut webserver = spawn_webserver( + args, + max_input_tokens, + max_total_tokens, + max_batch_prefill_tokens, + shutdown.clone(), + &shutdown_receiver, + ) + .map_err(|err| { + shutdown_shards(shutdown.clone(), &shutdown_receiver); + err + })?; // Default exit code let mut exit_code = Ok(()); diff --git a/router/Cargo.toml b/router/Cargo.toml index 27373864f..c325ba9dc 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -16,10 +16,9 @@ path = "src/main.rs" [dependencies] async-stream = "0.3.3" -axum = { version = "0.6.4", features = ["json"] } -axum-tracing-opentelemetry = "0.10.0" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" clap = { version = "4.1.4", features = ["derive", "env"] } -flume = "0.10.14" futures = "0.3.26" hf-hub = { version = "0.3.0", features = ["tokio"] } h2 = "0.3.26" @@ -48,13 +47,17 @@ tokio = { version = "1.32.0", features = [ "signal", "sync", ] } -tower-http = { version = "0.4.0", features = ["cors"] } +tokio-stream = "0.1.14" +tower-http = { version = "0.6.1", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.19.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } -utoipa = { version = "3.0.1", features = ["axum_extras"] } -utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.12.3", features = ["axum"], optional = true } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } once_cell = "1.19.0" itertools = "0.12.1" async-trait = "0.1.80" @@ -68,6 +71,9 @@ base64 = "0.22.0" [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } +[dev-dependencies] +tracing-test = "0.1" + [features] default = ["ngrok"] ngrok = ["dep:ngrok"] diff --git a/router/src/batch.rs b/router/src/batch.rs index 2be4bcdbe..39f8d32c7 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -12,7 +12,7 @@ use lorax_client::{ StoppingCriteriaParameters, TokenizedInputs, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use tokio::time::Instant; +use tokio::{sync::mpsc, time::Instant}; use tracing::{Instrument, 
Span}; use crate::{ @@ -22,6 +22,7 @@ use crate::{ }; pub(crate) trait ValidRequest: Sync + Send + Debug + Any { + fn decoder_input_details(&self) -> bool; fn input_length(&self) -> u32; fn input_ids(&self) -> Option>>; fn max_new_tokens(&self) -> u32; @@ -31,6 +32,10 @@ pub(crate) trait ValidRequest: Sync + Send + Debug + Any { } impl ValidRequest for ValidGenerateRequest { + fn decoder_input_details(&self) -> bool { + self.decoder_input_details + } + fn input_length(&self) -> u32 { self.input_length } @@ -69,6 +74,10 @@ pub(crate) struct ValidEmbedRequest { } impl ValidRequest for ValidEmbedRequest { + fn decoder_input_details(&self) -> bool { + false + } + fn input_length(&self) -> u32 { self.input_length } @@ -107,6 +116,10 @@ pub(crate) struct ValidClassifyRequest { } impl ValidRequest for ValidClassifyRequest { + fn decoder_input_details(&self) -> bool { + false + } + fn input_length(&self) -> u32 { self.input_length } @@ -154,7 +167,7 @@ pub(crate) struct Entry { /// Request pub request: Arc, /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: flume::Sender>, + pub response_tx: mpsc::UnboundedSender>, /// Span that will live as long as entry pub span: Span, /// Temporary span used as a guard when logging inference, wait times... diff --git a/router/src/infer.rs b/router/src/infer.rs index 5626302c3..0dbec8557 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -9,8 +9,6 @@ use crate::{ MessageChunk, TextMessage, Token, TokenizerConfigToken, Tool, }; use crate::{GenerateRequest, PrefillToken}; -use flume::r#async::RecvStream; -use flume::SendTimeoutError; use futures::future::try_join_all; use futures::stream::StreamExt; /// Batching and inference logic @@ -29,11 +27,12 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use std::time::Duration; use thiserror::Error; use tokenizers::Tokenizer; -use tokio::sync::{Mutex, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Mutex, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{info_span, instrument, Span}; #[derive(Clone, Serialize, Deserialize, Default)] @@ -276,7 +275,7 @@ impl Infer { ) -> Result< ( OwnedSemaphorePermit, - RecvStream>, + UnboundedReceiverStream>, ), InferError, > { @@ -330,7 +329,7 @@ impl Infer { })?; // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); // Process the request by sending it to the queue associated with `adapter` self.adapter_scheduler.process( @@ -348,7 +347,7 @@ impl Infer { ); // Return stream - Ok((permit, response_rx.into_stream())) + Ok((permit, UnboundedReceiverStream::new(response_rx))) } /// Tokenizer the input @@ -542,7 +541,7 @@ impl Infer { }; // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); // Process the request by sending it to the queue associated with `adapter` self.adapter_scheduler.process( @@ -562,7 +561,7 @@ impl Infer { // Return values let mut return_embeddings = None; - let mut stream = response_rx.into_stream(); + let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? 
{ // Add prefill tokens @@ -643,7 +642,7 @@ impl Infer { }; // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); // Process the request by sending it to the queue associated with `adapter` self.adapter_scheduler.process( @@ -665,7 +664,7 @@ impl Infer { let mut result_start = None; let mut result_queued = None; - let mut stream = response_rx.into_stream(); + let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? { // Add prefill tokens @@ -743,7 +742,7 @@ impl Infer { ); // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = flume::unbounded(); + let (response_tx, response_rx) = mpsc::unbounded_channel(); let request_id_map: HashMap = request .inputs @@ -793,7 +792,7 @@ impl Infer { // Return values let mut all_entities = HashMap::new(); - let mut stream = response_rx.into_stream(); + let mut stream = UnboundedReceiverStream::new(response_rx); while let Some(response) = stream.next().await { match response? { // Add prefill tokens @@ -1104,10 +1103,10 @@ pub(crate) async fn prefill( // Update health generation_health.store(true, Ordering::SeqCst); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + let removed = filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; + let next_batch = filter_batch(client, next_batch, entries, removed).await; // TODO(travis) // if let Some(concat_duration) = timings.concat { @@ -1147,10 +1146,10 @@ pub(crate) async fn decode( // Update health generation_health.store(true, Ordering::SeqCst); // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + let removed = filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; + let next_batch = filter_batch(client, next_batch, entries, removed).await; metrics::histogram!("lorax_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); metrics::increment_counter!("lorax_batch_inference_success", "method" => "decode"); @@ -1198,10 +1197,7 @@ pub(crate) async fn embed( // request and we need to stop generating hence why we unwrap_or(true) let stopped = send_embeddings(embedding, entry) .map_err(|err| { - if let SendTimeoutError::Timeout(_) = *err { - tracing::error!("Entry response channel timed out.") - } - + tracing::error!("Entry response channel error."); metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); err }) @@ -1258,10 +1254,7 @@ pub(crate) async fn classify( // request and we need to stop generating hence why we unwrap_or(true) let stopped = send_classifications(predictions, entry) .map_err(|err| { - if let SendTimeoutError::Timeout(_) = *err { - tracing::error!("Entry response channel timed out.") - } - + tracing::error!("Entry response channel error."); metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); err }) @@ -1294,11 +1287,12 @@ async fn filter_batch( client: &mut ShardedClient, next_batch: Option, entries: &IntMap, + removed: bool, ) -> Option { let mut batch = next_batch?; - // No need to filter - if batch.size as usize == entries.len() { + // No need to filter 
is we haven't removed any entries + if !removed { return Some(batch); } @@ -1324,7 +1318,8 @@ async fn filter_batch( /// Send one or multiple `InferStreamResponse` to Infer for all `entries` /// and filter entries #[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { +fn filter_send_generations(generations: Vec, entries: &mut IntMap) -> bool { + let mut removed = false; generations.into_iter().for_each(|generation| { let id = generation.request_id; // Get entry @@ -1338,27 +1333,28 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); - err }).unwrap_or(true); if stopped { entries.remove(&id).expect("ID not found in entries. This is a bug."); + removed = true; } }); + removed } /// Send responses through the `entry` response channel fn send_responses( generation: Generation, entry: &Entry, -) -> Result>>> { - // Return directly if the channel is disconnected - if entry.response_tx.is_disconnected() { +) -> Result>>> { + // Return directly if the channel is closed + let request_id = generation.request_id; + if entry.response_tx.is_closed() { + tracing::error!("Entry id={request_id:?} response channel closed."); + metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); return Ok(true); } @@ -1366,13 +1362,10 @@ fn send_responses( if generation.prefill_tokens_length > 0 { // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Prefill { - tokens: generation.prefill_tokens, - tokens_length: generation.prefill_tokens_length, - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::Prefill { + tokens: generation.prefill_tokens, + tokens_length: generation.prefill_tokens_length, + }))?; } // Create last Token @@ -1423,22 +1416,18 @@ fn send_responses( // Generation has ended stopped = true; // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::End { - token, - generated_text: generated_text.clone(), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + generated_text: generated_text.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; } _ => { // Send message - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Token(token)), - Duration::from_millis(10), - )?; + entry + .response_tx + .send(Ok(InferStreamResponse::Token(token)))?; } } } @@ -1450,20 +1439,17 @@ fn send_responses( fn send_embeddings( embedding: Embedding, entry: &Entry, -) -> Result>>> { +) -> Result>>> { // Return directly if the channel is disconnected - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { return Ok(true); } - entry.response_tx.send_timeout( - Ok(InferStreamResponse::Embed { - embedding: embedding.clone(), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::Embed { + embedding: embedding.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; // TODO(travis): redundant as we always return true, just make it return nothing Ok(true) @@ -1473,21 +1459,18 @@ fn send_embeddings( fn send_classifications( predictions: ClassifyPredictionList, entry: &Entry, -) -> Result>>> { +) -> Result>>> { // Return directly if the channel is disconnected - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { return Ok(true); } - 
entry.response_tx.send_timeout( - Ok(InferStreamResponse::Classify { - predictions: predictions.clone(), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - id: entry.id, - }), - Duration::from_millis(10), - )?; + entry.response_tx.send(Ok(InferStreamResponse::Classify { + predictions: predictions.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + id: entry.id, + }))?; // TODO(travis): redundant as we always return true, just make it return nothing Ok(true) @@ -1506,7 +1489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // unwrap_or is valid here as we don't care if the receiver is gone. entry .response_tx - .send_timeout(Err(err), Duration::from_millis(10)) + .send(Err(err)) .unwrap_or(()); }); } diff --git a/router/src/loader.rs b/router/src/loader.rs index 1233c664c..274e636de 100644 --- a/router/src/loader.rs +++ b/router/src/loader.rs @@ -3,20 +3,20 @@ use crate::infer::InferError; use crate::queue::{AdapterQueuesState, AdapterStatus}; use lorax_client::ShardedClient; use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{oneshot, Mutex}; +use tokio::sync::{mpsc, oneshot, Mutex}; use tracing::Span; /// Request AdapterLoader #[derive(Debug, Clone)] pub(crate) struct AdapterLoader { /// Channel to communicate with the background task - sender: flume::Sender, + sender: mpsc::UnboundedSender, } impl AdapterLoader { pub(crate) fn new(client: ShardedClient) -> Self { // Create channel - let (sender, receiver) = flume::unbounded(); + let (sender, receiver) = mpsc::unbounded_channel(); // Launch background queue task tokio::spawn(loader_task(client, receiver)); @@ -115,10 +115,13 @@ impl AdapterLoader { } // Background task responsible of the loader state -async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver) { +async fn loader_task( + mut client: ShardedClient, + mut receiver: mpsc::UnboundedReceiver, +) { let mut err_msgs: HashMap = HashMap::new(); - while let Ok(cmd) = receiver.recv_async().await { + while let Some(cmd) = receiver.recv().await { match cmd { AdapterLoaderCommand::DownloadAdapter { adapter, diff --git a/router/src/main.rs b/router/src/main.rs index 5d47d93ea..249031258 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -141,9 +141,6 @@ async fn main() -> Result<(), RouterError> { "`max_input_length` must be < `max_total_tokens`".to_string(), )); } - if max_input_length as u32 > max_batch_prefill_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); - } if validation_workers == 0 { return Err(RouterError::ArgumentValidation( @@ -151,15 +148,6 @@ async fn main() -> Result<(), RouterError> { )); } - if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } - } - // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue // Finally, convert to AllowOrigin @@ -445,6 +433,18 @@ async fn main() -> Result<(), RouterError> { tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); tracing::info!("Connected"); + let supports_chunking = shard_info.chunked_prefill; + let max_batch_total_tokens = max_supported_batch_total_tokens; + if max_input_length as u32 > max_batch_prefill_tokens && !supports_chunking { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); + } + if max_batch_prefill_tokens > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + let addr = match hostname.parse() { Ok(ip) => SocketAddr::new(ip, port), Err(_) => { @@ -533,7 +533,7 @@ fn init_logging(otlp_endpoint: Option, json_output: bool) { if let Ok(tracer) = tracer { layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); - axum_tracing_opentelemetry::init_propagator().unwrap(); + init_tracing_opentelemetry::init_propagator().unwrap(); }; } diff --git a/router/src/radix.rs b/router/src/radix.rs index 243df370a..75430f37f 100644 --- a/router/src/radix.rs +++ b/router/src/radix.rs @@ -6,13 +6,12 @@ use std::{ sync::Arc, }; -fn hash(adapter_index: u32, slice: &[u32]) -> u64 { +fn hash(slice: &[u32]) -> u64 { assert!(!slice.is_empty()); - if slice.len() == 1 && adapter_index == 0 { + if slice.len() == 1 { slice[0] as u64 } else { let mut s = std::hash::DefaultHasher::new(); - adapter_index.hash(&mut s); slice.hash(&mut s); s.finish() } @@ -50,7 +49,7 @@ impl RadixAllocator { } } - fn alloc_or_reclaim(&mut self, adapter_index: u32, n_blocks_needed: usize) -> Option> { + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { if self.free_blocks.len() < n_blocks_needed { // This is a bit annoying, we first extend the free list and then // split it off again below. This is because we need to put it on @@ -63,7 +62,7 @@ impl RadixAllocator { ); self.free_blocks.extend( self.cache_blocks - .evict(adapter_index, n_blocks_needed - self.free_blocks.len()), + .evict(n_blocks_needed - self.free_blocks.len()), ); } @@ -86,6 +85,17 @@ impl Allocator for RadixAllocator { tokens: u32, prefill_tokens: Option>>, ) -> Option { + // print out blocks for allocation + tracing::debug!( + "!!! 
Allocate blocks {:?} {:?} {:?}", + adapter_index, + tokens, + prefill_tokens.as_ref().as_slice() + ); + + // ensure root node exists + self.cache_blocks.get_or_create_root(adapter_index); + let mut blocks = vec![]; let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { let node_id = @@ -93,7 +103,7 @@ impl Allocator for RadixAllocator { .find(adapter_index, prefill_tokens.as_slice(), &mut blocks); node_id } else { - self.cache_blocks.root_id() + self.cache_blocks.root_id(adapter_index) }; // Even if this allocation fails below, we need to increase he @@ -108,8 +118,9 @@ impl Allocator for RadixAllocator { let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; tracing::debug!("Prefix {prefix_len} - Suffix {suffix_len}"); + tracing::debug!("Cached blocks: {blocks:?}"); - match self.alloc_or_reclaim(adapter_index, suffix_blocks as usize) { + match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { tracing::debug!("Cannot allocate {:?}", self.cache_blocks); @@ -148,6 +159,15 @@ impl Allocator for RadixAllocator { self.allocation_id += 1; self.allocations.insert(self.allocation_id, allocation); + // log final blocks and slots + tracing::debug!( + "!!! BlockAllocation {:?} {:?} {:?} {:?}", + adapter_index, + blocks, + slots, + prefix_len + ); + Some(BlockAllocation { allocation_id: self.allocation_id, block_allocator: None, @@ -163,6 +183,13 @@ impl Allocator for RadixAllocator { None => unreachable!("Tried to free an unknown allocation."), }; + tracing::debug!( + "!!! Free blocks {:?} {:?} {:?}", + allocation.adapter_index, + allocation.cached_prefix_len, + allocation.prefill_tokens.as_ref().as_slice() + ); + self.cache_blocks .decref(allocation.prefix_node) .expect("Failed to decrement refcount"); @@ -241,8 +268,8 @@ pub type NodeId = DefaultKey; #[derive(Debug)] pub struct RadixTrie { - /// Identifier of the root nod. - root: DefaultKey, + /// Adapter index --> Identifier of the root node. + roots: HashMap, /// Leave node identifiers ordered by increasing recency. leaves: BTreeSet<(u64, NodeId)>, @@ -261,13 +288,13 @@ pub struct RadixTrie { impl RadixTrie { /// Construct a new radix trie. pub fn new(block_size: usize) -> Self { - let root = TrieNode::new(vec![], vec![], 0, None); - let mut nodes = SlotMap::new(); - let root = nodes.insert(root); + let nodes = SlotMap::new(); + let roots = HashMap::new(); + RadixTrie { leaves: BTreeSet::new(), nodes, - root, + roots, time: 0, block_size, } @@ -284,21 +311,15 @@ impl RadixTrie { /// Using this method will update the access time of the traversed nodes. pub fn find(&mut self, adapter_index: u32, key: &[u32], blocks: &mut Vec) -> NodeId { self.time += 1; - self.find_(adapter_index, self.root, key, blocks) + self.find_(self.root_id(adapter_index), key, blocks) } /// Find worker. 
- fn find_( - &mut self, - adapter_index: u32, - mut node_id: NodeId, - key: &[u32], - blocks: &mut Vec, - ) -> NodeId { + fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { let node = &self.nodes[node_id]; if key.len() >= self.block_size { - let node_key = hash(adapter_index, &key[..self.block_size]); + let node_key = hash(&key[..self.block_size]); if let Some(&child_id) = node.children.get(&node_key) { self.update_access_time(child_id); let child = self.nodes.get(child_id).expect("Invalid child identifier"); @@ -308,7 +329,7 @@ impl RadixTrie { let key = &key[shared_prefix_len..]; if !key.is_empty() { - node_id = self.find_(adapter_index, child_id, key, blocks); + node_id = self.find_(child_id, key, blocks); } } } @@ -320,7 +341,7 @@ impl RadixTrie { pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { // We don't care about refcounting for root, since it will never // be evicted. - if node_id == self.root { + if self.is_root(node_id) { return Ok(()); } @@ -347,7 +368,7 @@ impl RadixTrie { /// Increase the reference count of a node. pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { - if node_id == self.root { + if self.is_root(node_id) { return Ok(()); } @@ -367,7 +388,7 @@ impl RadixTrie { /// /// Returns the evicted blocks. When the length is less than `n_blocks`, /// not enough blocks could be evicted. - pub fn evict(&mut self, adapter_index: u32, n_blocks: usize) -> Vec { + pub fn evict(&mut self, n_blocks: usize) -> Vec { // NOTE: we don't return Result here. If any of the unwrapping fails, // it's a programming error in the trie implementation, not a user // error caused by e.g. an invalid argument. @@ -382,7 +403,7 @@ impl RadixTrie { let blocks_needed = n_blocks.saturating_sub(evicted.len()); tracing::debug!("Evicting node {node_id:?} "); - let node = self.nodes.get(node_id).expect("Leave does not exist"); + let node = self.nodes.get(node_id).expect("Leaf does not exist"); assert_eq!( node.ref_count, 0, "Leaf must have refcount of 0, got {}", @@ -391,7 +412,7 @@ impl RadixTrie { if blocks_needed >= node.blocks.len() { // We need to evict the whole node if we need more blocks than it has. - let node = self.remove_node(adapter_index, node_id); + let node = self.remove_node(node_id); evicted.extend(node.blocks); if evicted.len() >= n_blocks { @@ -401,7 +422,7 @@ impl RadixTrie { // The node has more blocks than needed, so we'll just remove // the required number of blocks and leave the remaining blocks // untouched. - let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + let node = self.nodes.get_mut(node_id).expect("Leaf does not exist"); let truncate_blocks = node.blocks.len() - blocks_needed; let truncate_tokens = truncate_blocks * self.block_size; @@ -427,14 +448,14 @@ impl RadixTrie { blocks: &[u32], ) -> Result { self.time += 1; - let common = self.insert_(adapter_index, self.root, tokens, blocks)?; + let node_id = self.get_or_create_root(adapter_index); + let common = self.insert_(node_id, tokens, blocks)?; Ok(common) } /// Insertion worker. 
fn insert_( &mut self, - adapter_index: u32, node_id: NodeId, tokens: &[u32], blocks: &[u32], @@ -445,7 +466,7 @@ impl RadixTrie { assert_eq!(tokens.len(), blocks.len() * self.block_size); - let node_key = hash(adapter_index, &tokens[..self.block_size]); + let node_key = hash(&tokens[..self.block_size]); if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) { self.update_access_time(child_id); let child = self @@ -464,7 +485,6 @@ impl RadixTrie { if shared_prefix_len == child.key.len() { return Ok(shared_prefix_len + self.insert_( - adapter_index, child_id, &tokens[shared_prefix_len..], &blocks[shared_prefix_len / self.block_size..], @@ -474,17 +494,17 @@ impl RadixTrie { // The node's prefix and the insertion prefix only match partially, // split the node to just contain the matching part. Then insert the // remainder of the prefix into the node again - let child_id = self.split_node(adapter_index, child_id, shared_prefix_len); + let child_id = self.split_node(child_id, shared_prefix_len); let key = &tokens[shared_prefix_len..]; let blocks = &blocks[shared_prefix_len / self.block_size..]; - Ok(shared_prefix_len + self.insert_(adapter_index, child_id, key, blocks)?) + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) } else { - self.add_node(adapter_index, node_id, tokens, blocks); + self.add_node(node_id, tokens, blocks); Ok(0) } } - fn split_node(&mut self, adapter_index: u32, node_id: NodeId, prefix_len: usize) -> NodeId { + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { // We have to make the current node a child to ensure that its // properties and node id stay the same. @@ -503,10 +523,10 @@ impl RadixTrie { std::mem::swap(&mut node.key, &mut parent_key); std::mem::swap(&mut node.blocks, &mut parent_blocks); - let node_key = hash(adapter_index, &node.key[..self.block_size]); + let node_key = hash(&node.key[..self.block_size]); let grandparent_id = node.parent.expect("Node does not have a parent"); - let parent_id = self.add_node(adapter_index, grandparent_id, parent_key, parent_blocks); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); self.add_node_to_parent(parent_id, node_key, node_id); // Reborrow to make the borrow checker happy. @@ -522,14 +542,13 @@ impl RadixTrie { /// Create a node and add it to the parent. fn add_node( &mut self, - adapter_index: u32, parent_id: NodeId, key: impl Into>, blocks: impl Into>, ) -> NodeId { let key = key.into(); let blocks = blocks.into(); - let first = hash(adapter_index, &key[..self.block_size]); + let first = hash(&key[..self.block_size]); let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); let child_id = self.nodes.insert(child); @@ -552,7 +571,7 @@ impl RadixTrie { } /// Remove a node from the trie. - fn remove_node(&mut self, adapter_index: u32, node_id: NodeId) -> TrieNode { + fn remove_node(&mut self, node_id: NodeId) -> TrieNode { // Unwrap here, passing in an unknown id is a programming error. let node = self.nodes.remove(node_id).expect("Unknown node"); assert!( @@ -563,7 +582,7 @@ impl RadixTrie { let parent_id = node.parent.expect("Attempted to remove root node"); let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); - let node_key = hash(adapter_index, &node.key[..self.block_size]); + let node_key = hash(&node.key[..self.block_size]); parent.children.remove(&node_key); self.decref(parent_id) .expect("Failed to decrease parent refcount"); @@ -587,8 +606,8 @@ impl RadixTrie { /// Print debugging output for the trie. 
/// /// In contrast to `Debug` nicely formatted. - pub fn print_debug(&self) { - self.print_debug_(self.root, 0); + pub fn print_debug(&self, adapter_index: u32) { + self.print_debug_(self.root_id(adapter_index), 0); } fn print_debug_(&self, node_id: NodeId, indent: usize) { @@ -609,8 +628,20 @@ impl RadixTrie { } } - pub(crate) fn root_id(&self) -> DefaultKey { - self.root + fn get_or_create_root(&mut self, adapter_index: u32) -> DefaultKey { + *self.roots.entry(adapter_index).or_insert_with(|| { + let root = TrieNode::new(vec![], vec![], 0, None); + self.nodes.insert(root) + }) + } + + pub(crate) fn root_id(&self, adapter_index: u32) -> DefaultKey { + self.roots[&adapter_index] + } + + pub(crate) fn is_root(&self, node_id: NodeId) -> bool { + let node = self.nodes.get(node_id).expect("Unknown node"); + node.parent.is_none() } } @@ -649,6 +680,7 @@ fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { #[cfg(test)] mod tests { use std::sync::Arc; + use tracing_test::traced_test; use super::*; @@ -704,6 +736,36 @@ mod tests { assert_eq!(allocation.prefix_len, 4); } + #[traced_test] + #[test] + fn allocator_reuses_prefixes_multi_adapter() { + let mut cache = RadixAllocator::new(1, 20, None); + + // Allocate 8 tokens: 4 tokens in prefill + 4 slots for generation + let allocation = cache + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation.blocks, vec![12, 13, 14, 15, 16, 17, 18, 19]); + assert_eq!(allocation.blocks, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + // 4 new blocks, 4 reused blocks from unused slots that were freed above. + let allocation = cache + .allocate(1, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11, 16, 17, 18, 19]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + // Same blocks as the first allocation, as cache was never evicted. + let allocation = cache + .allocate(0, 8, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation.blocks, vec![12, 13, 14, 15, 16, 17, 18, 19]); + assert_eq!(allocation.prefix_len, 4); + } + #[test] fn allocator_collects_older_prefixes_first() { let mut cache = RadixAllocator::new(1, 7, None); @@ -750,6 +812,29 @@ mod tests { assert_eq!(cache.free_blocks.len(), 5); } + #[traced_test] + #[test] + fn allocator_frees_fully_overlapping_prefills_multi_adapter() { + let mut cache = RadixAllocator::new(1, 5, None); + let allocation1 = cache + .allocate(0, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(1, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + let allocation3 = cache + .allocate(0, 4, Some(Arc::new(vec![0, 1, 2, 3]))) + .unwrap(); + assert_eq!(allocation3.prefix_len, 0); + + // 5 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 0); + } + #[test] fn allocator_frees_partially_overlapping_prefills() { let mut cache = RadixAllocator::new(1, 20, None); @@ -888,7 +973,7 @@ mod tests { let mut blocks = Vec::new(); // Remove less than the leave blocks. 
- assert_eq!(trie.evict(0, 1), vec![7]); + assert_eq!(trie.evict(1), vec![7]); trie.find(0, &[0, 1, 2, 3, 5, 6, 7], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); @@ -897,7 +982,7 @@ mod tests { trie.find(0, &[1, 2, 3], &mut blocks); // Remove the leave blocks exactly. - assert_eq!(trie.evict(0, 2), vec![5, 6]); + assert_eq!(trie.evict(2), vec![5, 6]); blocks.clear(); trie.find(0, &[0, 1, 2, 3, 5, 6, 7], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3]); @@ -905,12 +990,12 @@ mod tests { trie.find(0, &[1, 2, 3], &mut blocks); // Remove more than the leave blocks. - assert_eq!(trie.evict(0, 3), vec![4, 3, 2]); + assert_eq!(trie.evict(3), vec![4, 3, 2]); blocks.clear(); trie.find(0, &[0, 1, 2, 3, 4], &mut blocks); assert_eq!(blocks, vec![0, 1]); // Clear out the whole trie. - assert_eq!(trie.evict(0, 10), vec![1, 2, 3, 0, 1]); + assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); } } diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 4231b3894..d557faa89 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -7,7 +7,7 @@ use crate::{ }; use lorax_client::{Batch, ShardedClient}; use std::{cmp::max, collections::HashSet, sync::Arc}; -use tokio::sync::{oneshot, Mutex}; +use tokio::sync::{mpsc, oneshot, Mutex}; use tracing::{info_span, instrument, Instrument, Span}; enum AdapterSchedulerCommand { @@ -25,7 +25,7 @@ enum AdapterSchedulerCommand { #[derive(Clone)] pub(crate) struct AdapterScheduler { - sender: flume::Sender, + sender: mpsc::UnboundedSender, } impl AdapterScheduler { @@ -43,7 +43,7 @@ impl AdapterScheduler { chunked_prefill: bool, is_causal_lm: bool, ) -> Self { - let (sender, receiver) = flume::unbounded(); + let (sender, receiver) = mpsc::unbounded_channel(); // receives requests from the infer struct and sends them to the appropriate adapter queue tokio::spawn(adapter_scheduler_task( @@ -118,7 +118,7 @@ async fn adapter_scheduler_task( requires_padding: bool, block_size: u32, window_size: Option, - receiver: flume::Receiver, + mut receiver: mpsc::UnboundedReceiver, max_active_adapters: usize, adapter_cycle_time_s: u64, speculate: u32, @@ -141,7 +141,7 @@ async fn adapter_scheduler_task( is_causal_lm, ); - while let Ok(cmd) = receiver.recv_async().await { + while let Some(cmd) = receiver.recv().await { match cmd { AdapterSchedulerCommand::Append(adapter, entry) => { state.append(adapter, adapter_event.clone(), entry).await; @@ -330,7 +330,8 @@ impl AdapterSchedulerState { 'entry_loop: while let Some((id, mut entry, adapter)) = self.next_entry().await { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) - if entry.response_tx.is_disconnected() { + if entry.response_tx.is_closed() { + tracing::error!("Entry response channel closed."); metrics::increment_counter!("lorax_request_failure", "err" => "dropped"); continue; } @@ -365,6 +366,14 @@ impl AdapterSchedulerState { None } Some(block_allocator) => { + // If users wants the prefill logprobs, we cannot reuse the cache. + // So no input_ids for the radix tree. 
+ let input_ids = if entry.request.decoder_input_details() { + None + } else { + entry.request.input_ids().clone() + }; + let tokens = entry.request.input_length() + entry.request.max_new_tokens() + self.speculate @@ -379,7 +388,7 @@ impl AdapterSchedulerState { ); let block_allocation = match block_allocator - .allocate(adapter.index(), tokens, entry.request.input_ids()) + .allocate(adapter.index(), tokens, input_ids) .await { None => { diff --git a/router/src/server.rs b/router/src/server.rs index 3370114e9..011b509e1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -25,7 +25,7 @@ use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; -use axum_tracing_opentelemetry::opentelemetry_tracing_layer; +use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use futures::stream::StreamExt; use futures::Stream; use lorax_client::{ShardInfo, ShardedClient}; @@ -39,6 +39,7 @@ use std::net::SocketAddr; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::sync::Mutex; +use thiserror::Error; use tokenizers::Tokenizer; use tokio::signal; use tokio::sync::mpsc; @@ -480,6 +481,12 @@ async fn chat_completions_v1( } } +#[derive(Debug, Error)] +pub enum WebServerError { + #[error("Axum error: {0}")] + Axum(#[from] axum::BoxError), +} + type PreparedInput = (String, Option, bool); pub(crate) fn prepare_chat_input( @@ -1284,8 +1291,8 @@ pub async fn run( cors_expose_headers: Option, tokenizer_config: HubTokenizerConfig, ngrok: bool, - ngrok_authtoken: Option, - ngrok_edge: Option, + _ngrok_authtoken: Option, + _ngrok_edge: Option, adapter_source: String, eager_prefill: bool, prefix_caching: bool, @@ -1523,12 +1530,16 @@ pub async fn run( tracing::info!("REQUEST_LOGGER_URL not set, request logging is disabled"); } + #[allow(unused_mut)] // mut is needed for conditional compilation + let mut doc = ApiDoc::openapi(); + + // Configure Swagger UI + let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); + // Create router - let app = Router::new() - .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) + let base_routes = Router::new() // Base routes .route("/", post(compat_generate)) - .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/embed", post(embed)) .route("/classify", post(classify)) @@ -1537,16 +1548,27 @@ pub async fn run( .route("/v1/completions", post(completions_v1)) .route("/v1/chat/completions", post(chat_completions_v1)) // AWS Sagemaker route - .route("/invocations", post(compat_generate)) + .route("/invocations", post(compat_generate)); + + let info_routes = Router::new() + .route("/", get(health)) // Base Health route .route("/health", get(health)) - // Inference API health route - .route("/", get(health)) + .route("/info", get(get_model_info)) // AWS Sagemaker health route .route("/ping", get(health)) // Prometheus metrics route .route("/metrics", get(metrics)) - .route("/tokenize", post(tokenize)) + .route("/tokenize", post(tokenize)); + + // Combine routes and layers + let mut app = Router::new() + .merge(swagger_ui) + .merge(base_routes) + .merge(info_routes); + + // add layers after routes + app = app .layer(Extension(info)) .layer(Extension(client.clone())) .layer(Extension(request_logger_sender.clone())) @@ -1554,53 +1576,16 @@ pub async fn run( .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(prom_handle.clone())) - 
.layer(opentelemetry_tracing_layer()) + .layer(OtelAxumLayer::default()) .layer(cors_layer) .layer(Extension(cloned_tokenizer)); if ngrok { #[cfg(feature = "ngrok")] { - use ngrok::config::TunnelBuilder; - - let _ = addr; - - let authtoken = - ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling"); - - let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling"); - - let tunnel = ngrok::Session::builder() - .authtoken(authtoken) - .connect() - .await - .unwrap() - .labeled_tunnel() - .label("edge", edge); - - let listener = tunnel.listen().await.unwrap(); - - // Run prom metrics and health locally too - tokio::spawn( - axum::Server::bind(&addr) - .serve( - Router::new() - .route("/health", get(health)) - .route("/metrics", get(metrics)) - .layer(Extension(health_ext)) - .layer(Extension(prom_handle)) - .into_make_service(), - ) - //Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()), - ); + panic!("ngrok feature is not functional with axum=0.7 and hyper=1, waiting on https://github.com/ngrok/ngrok-rust/pull/137/files to re-enable."); // Run server - axum::Server::builder(listener) - .serve(app.into_make_service()) - //Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()) - .await?; } #[cfg(not(feature = "ngrok"))] { @@ -1613,11 +1598,12 @@ pub async fn run( } } else { // Run server - axum::Server::bind(&addr) - .serve(app.into_make_service()) - // Wait until all requests are finished to shut down + + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) - .await?; + .await + .map_err(|err| WebServerError::Axum(Box::new(err)))?; } Ok(()) } diff --git a/router/src/validation.rs b/router/src/validation.rs index 76ec6d7ca..4a7a91edc 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -12,7 +12,7 @@ use std::io::Cursor; use std::iter; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; use tracing::{instrument, Span}; use {once_cell::sync::Lazy, regex::Regex}; @@ -25,7 +25,7 @@ pub struct Validation { max_input_length: usize, max_total_tokens: usize, /// Channel to communicate with the background tokenization task - sender: Option>, + sender: Option>, } impl Validation { @@ -41,15 +41,17 @@ impl Validation { ) -> Self { // If we have a fast tokenizer let sender = if let Some(tokenizer) = tokenizer { - // Create channel - let (validation_sender, validation_receiver) = flume::unbounded(); + // Create round robin channel + let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel(); + let mut senders = Vec::with_capacity(workers); // Create workers for _ in 0..workers { let tokenizer_clone = tokenizer.clone(); let config_clone = config.clone(); let preprocessor_config_clone = preprocessor_config.clone(); - let receiver_clone = validation_receiver.clone(); + let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel(); + senders.push(tokenizer_sender); // Spawn worker tokio::task::spawn_blocking(move || { @@ -57,10 +59,14 @@ impl Validation { tokenizer_clone, config_clone, preprocessor_config_clone, - receiver_clone, + tokenizer_receiver, ) }); } + + // Create tokenization round robin task + tokio::spawn(round_robin_task(validation_round_robin_receiver, senders)); + Some(validation_sender) } else { None @@ -390,15 +396,30 @@ impl Validation { } } +/// Round 
robin tokenization task +async fn round_robin_task( + mut receiver: mpsc::UnboundedReceiver, + senders: Vec>, +) { + loop { + for sender in &senders { + match receiver.recv().await { + None => return, + Some(request) => sender.send(request).unwrap(), + }; + } + } +} + /// Start tokenization workers fn tokenizer_worker( tokenizer: Tokenizer, config: Option, preprocessor_config: Option, - receiver: flume::Receiver, + mut receiver: mpsc::UnboundedReceiver, ) { // Loop over requests - while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() { + while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { parent_span.in_scope(|| { response_tx .send(prepare_input( diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index fac7f5b8b..b2aec8a0d 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -10,7 +10,7 @@ from lorax_server.adapters.types import LORA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights from lorax_server.utils.lora import LM_HEAD -from lorax_server.utils.sgmv import ( +from lorax_server.utils.punica import ( BGMV_MAX_RANK, MAX_RANK_CUSTOM, get_tmp_tensors, diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 68838e760..276764509 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -3,14 +3,15 @@ import torch import torch.distributed +from loguru import logger from lorax_server.adapters.config import AdapterConfig, ModuleMap from lorax_server.adapters.types import MEDUSA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights from lorax_server.layers import FastLinear, TensorParallelColumnLinear +from lorax_server.utils.punica import segmented_matmul from lorax_server.utils.segments import find_segments -from lorax_server.utils.sgmv import segmented_matmul -from lorax_server.utils.state import get_speculative_tokens +from lorax_server.utils.state import LORAX_SPECULATION_MAX_BATCH_SIZE, get_speculative_tokens from lorax_server.utils.weights import AbstractWeights, InMemoryWeights if TYPE_CHECKING: @@ -159,7 +160,8 @@ def __init__(self, config: MedusaConfig, weights: AbstractWeights): def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None): # If we have too many tokens, we skip speculative logits - if x.shape[0] > 128: + if x.shape[0] > LORAX_SPECULATION_MAX_BATCH_SIZE: + logger.info(f"Skipping speculation at batch size = {x.shape[0]}") logits = lm_head(x) return logits, None @@ -311,11 +313,19 @@ def load( default_medusa=default_medusa, segments=MedusaSegments( w=[ - (adapter_weights[idx].model.medusa.linear.linear.weight if idx in adapter_weights else EMPTY_TENSOR) + ( + adapter_weights[idx].model.medusa.linear.linear.weight.data + if idx in adapter_weights + else EMPTY_TENSOR + ) for idx in segment_indices ], b=[ - (adapter_weights[idx].model.medusa.linear.linear.bias if idx in adapter_weights else EMPTY_TENSOR) + ( + adapter_weights[idx].model.medusa.linear.linear.bias.data + if idx in adapter_weights + else EMPTY_TENSOR + ) for idx in segment_indices ], s_start=segments[indices], diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index bfa4a0bf2..1ca77f001 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -1,19 +1,25 @@ from abc import ABC, 
abstractclassmethod from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type import torch from lorax_server.adapters.types import LORA from lorax_server.utils.lora import LM_HEAD +if TYPE_CHECKING: + from lorax_server.utils.punica import PunicaWrapper + @dataclass class AdapterBatchMetadata: - # [batch_size] + # [num_tokens] adapter_indices: torch.Tensor + # [batch_size] + adapter_list: List[int] + # [num_adapters] adapter_set: Set[int] @@ -106,12 +112,19 @@ class AdapterBatchData: # layer type -> adapter type -> batch weight data data: Dict[str, Dict[str, BatchAdapterWeights]] + # layer type -> fused lora weights + layer_to_lora_weights: Dict[Tuple[str, int], Tuple[torch.Tensor, torch.Tensor]] + + punica_wrapper: "PunicaWrapper" + prefill: bool @staticmethod def from_meta( meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights], + layer_to_lora_weights: Dict[Tuple[str, int], Tuple[torch.Tensor, torch.Tensor]], + punica_wrapper: "PunicaWrapper", prefill: bool, prefill_head_indices: Optional[torch.Tensor], ) -> "AdapterBatchData": @@ -122,7 +135,13 @@ def from_meta( layer_weights = v.get_data(meta, k, prefill, prefill_head_indices if k == LM_HEAD else None) if layer_weights: data[k] = layer_weights - return AdapterBatchData(meta=meta, data=data, prefill=prefill) + return AdapterBatchData( + meta=meta, + data=data, + layer_to_lora_weights=layer_to_lora_weights, + punica_wrapper=punica_wrapper, + prefill=prefill, + ) def ranks(self) -> Set[int]: # TODO(travis): refactor to be less coupled to lora implementation diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 79aa029cf..7eeaa9f36 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -90,7 +90,7 @@ def from_pb( padding_right_offset = 0 max_decode_tokens = 0 adapter_indices_list = [] - adapter_set = set() + adapter_list = [] for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i req_inputs = tokenizers.get_inputs(r, tokenizer) @@ -102,7 +102,7 @@ def from_pb( max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens) adapter_indices_list.append(r.adapter_index) - adapter_set.add(r.adapter_index) + adapter_list.append(r.adapter_index) adapter_indices = torch.tensor(adapter_indices_list, dtype=torch.int64, device=device) @@ -156,7 +156,8 @@ def from_pb( max_tokens=max_tokens, adapter_meta=AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), @@ -180,7 +181,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: all_input_ids = [] max_input_length = 0 - adapter_set = set() + adapter_list = [] next_token_choosers = [] stopping_criterias = [] @@ -209,7 +210,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: total_remaining_decode_tokens += remaining_decode_tokens new_padding_right_offset = max(new_padding_right_offset, remaining_decode_tokens) - adapter_set.add(self.requests[idx].adapter_index) + adapter_list.append(self.requests[idx].adapter_index) # Apply indices to input_ids, attention mask, past key values and other items that need to be cached input_ids = self.input_ids[keep_indices] @@ -262,7 +263,8 
@@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: self.max_tokens = max_tokens self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) @@ -301,7 +303,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) - adapter_set = set() + adapter_list = [] adapter_segment_builder = SegmentConcatBuilder() cumulative_adapter_indices_size = 0 @@ -344,7 +346,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) + adapter_list.extend(batch.adapter_meta.adapter_list) # Update adapter segments adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices) @@ -476,7 +478,8 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": max_tokens=max_tokens, adapter_meta=AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), @@ -593,8 +596,10 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option # TODO(travis): don't update this if indices haven't changed # Use prefill=True in all cases to force use of SGMV, as the batch is heterogenous adapter_data = AdapterBatchData.from_meta( - batch.adapter_meta, - self.layer_to_adapter_weights, + meta=batch.adapter_meta, + weights=self.layer_to_adapter_weights, + layer_to_lora_weights={}, + punica_wrapper=None, prefill=True, prefill_head_indices=None, ) diff --git a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py index 64dddfc36..1c6c23320 100644 --- a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py @@ -524,6 +524,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, @@ -538,6 +539,11 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + # FIXME: simply running the LM head is not sufficient since we also need to scale the logits + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) logits *= self.logit_scale if speculative_logits is not None: diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index 5c16c81df..252bdd514 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -1009,6 +1009,7 @@ def forward( adapter_data: 
AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, @@ -1023,5 +1024,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py index eec7fddf8..2aca92c7d 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py @@ -539,6 +539,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: input_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -554,5 +555,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index 2e8b6cba9..09be5d766 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -538,6 +538,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index eeb6e8d38..3b65e969d 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -367,6 +367,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.transformer( input_ids, diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index b308edc5d..87abc494a 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -598,6 +598,7 @@ def forward( prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -615,5 +616,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = 
self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index b74ad7e73..5e244fa93 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -610,6 +610,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor @@ -635,5 +636,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index 58e939932..f9eac6fb3 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -963,6 +963,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor @@ -987,5 +988,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index cc6df3382..e798b6f6b 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -357,6 +357,7 @@ def forward( max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.gpt_neox( input_ids, @@ -370,5 +371,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits = self.embed_out(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index b0b48688d..151d11641 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -506,6 +506,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, @@ -520,5 +521,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = 
hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index d600928b9..e6b22fe71 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -388,6 +388,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, @@ -402,5 +403,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 068d2e0b6..b04d61698 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -326,8 +326,8 @@ def __init__(self, prefix, config, weights, layer_id): layer_id, [MLP_GATE_PROJ, MLP_UP_PROJ], sizes=[ - config.intermediate_size // 2, - config.intermediate_size // 2, + config.intermediate_size, + config.intermediate_size, ], process_group=weights.process_group, ) @@ -511,6 +511,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor @@ -536,6 +537,10 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 606248af1..c4a46db68 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -507,6 +507,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.transformer( input_ids, @@ -521,5 +522,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index 4f3b36765..4fe821039 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -592,6 +592,7 @@ def forward( 
max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.transformer( input_ids, @@ -605,5 +606,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits = self.lm_head(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index 4e98b97a2..a3dc31da6 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -423,6 +423,7 @@ def forward( max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.transformer( input_ids, @@ -436,5 +437,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits = self.lm_head(hidden_states) return logits, None diff --git a/server/lorax_server/models/custom_modeling/llava_next.py b/server/lorax_server/models/custom_modeling/llava_next.py index bede2691e..cb0797834 100644 --- a/server/lorax_server/models/custom_modeling/llava_next.py +++ b/server/lorax_server/models/custom_modeling/llava_next.py @@ -178,6 +178,7 @@ def forward( pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional["AdapterBatchData"] = None, + skip_lm_head: bool = False, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: @@ -264,5 +265,9 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + logits, speculative_logits = self.text_model.lm_head(hidden_states, adapter_data) return logits, speculative_logits diff --git a/server/lorax_server/models/custom_modeling/mllama.py b/server/lorax_server/models/custom_modeling/mllama.py index c48448b36..7aa2e01e0 100644 --- a/server/lorax_server/models/custom_modeling/mllama.py +++ b/server/lorax_server/models/custom_modeling/mllama.py @@ -884,6 +884,7 @@ def forward( # XXX: Putting these as optional so that the cuda warmup calls can go through. 
cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, + skip_lm_head: bool = False, ): if cross_attention_states is not None: seqlen_q = len(image_indices) @@ -954,6 +955,7 @@ def forward( prefill_cache_indices=prefill_cache_indices, lm_head_indices=lm_head_indices, cross_attention_states=cross_attention_states, + skip_lm_head=skip_lm_head, ) return outputs diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 1c74cbcac..627512330 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -7,12 +7,21 @@ import numpy as np import torch import torch.distributed +import torch.profiler from loguru import logger from opentelemetry import trace from tqdm import tqdm from transformers import AutoConfig, AutoTokenizer, GenerationConfig, PretrainedConfig, PreTrainedTokenizerBase from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata +from lorax_server.models.metadata_kernels import ( + block_tables_to_padded, + block_tables_to_ragged, + copy_next_input_ids_inplace, + has_triton, + prepare_position_slot_ids, + slots_filtering, +) from lorax_server.models.model import Model from lorax_server.models.types import ( Batch, @@ -24,10 +33,10 @@ from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, create_merged_weight_files from lorax_server.utils.attention.common import Seqlen -from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed from lorax_server.utils.graph import GraphCache from lorax_server.utils.import_utils import get_cuda_free_memory +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, PunicaWrapper from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.sources import HUB from lorax_server.utils.sources.hub import weight_files @@ -82,6 +91,10 @@ class FlashCausalLMBatch(Batch): # Will be set by `generate_token` and reset after each prefill forward before staying set in decode slots: Optional[torch.Tensor] + # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch + # used for filtering + cu_slots: torch.Tensor + max_input_length: int max_current_length: int @@ -89,9 +102,10 @@ class FlashCausalLMBatch(Batch): prefilling: bool # Whether each request is prefilling prefilling_mask: List[bool] + prefilling_mask_tensor: Optional[torch.Tensor] # Prefill metadata tensors to efficiently compute logprobs - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + # tensor of length b+1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill cu_seqlen_prefill: Optional[torch.Tensor] # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers # as we only keep SLIDING_WINDOW values instead of the whole tensor @@ -193,6 +207,8 @@ def from_pb( all_input_ids = [] all_postfix_ids = [] requests_idx_mapping = {} + slots = [] + cu_slots = [0] next_token_chooser_parameters = [] stopping_criterias = [] @@ -203,7 +219,9 @@ def from_pb( max_length = 0 max_blocks = 0 + cu_blocks = [0] block_tables = [] + block_tables_ragged = [] # Parse batch for i, (r, tokenized_input) in enumerate(zip(pb.requests, 
batch_tokenized_inputs)): @@ -263,10 +281,18 @@ def from_pb( if not r.blocks: needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [b for b in range(num_blocks, num_blocks + needed_blocks)] + request_slots = [s for b in request_blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] else: request_blocks = r.blocks + request_slots = r.slots block_tables.append(request_blocks) + block_tables_ragged.extend(request_blocks) + cu_blocks.append(len(block_tables_ragged)) + + slots.extend(request_slots) + cu_slots.append(len(slots)) + cache_lengths.append(cache_length) num_blocks += len(request_blocks) @@ -294,12 +320,29 @@ def from_pb( # Create tensors on device all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64, device=device) - block_tables_tensor = torch.zeros((len(block_tables), max_blocks), dtype=torch.int32, device="cpu") - for i, request_blocks in enumerate(block_tables): - block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - block_tables_tensor = block_tables_tensor.to(device) + block_tables_ragged = torch.tensor(block_tables_ragged, device=device, dtype=torch.int32) + cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64) + block_tables_tensor = torch.empty( + (len(block_tables), max_blocks), + device=device, + dtype=torch.int32, + ) + + # If the device supports Triton, we can use a fused kernel + if has_triton(): + block_tables_to_padded(max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged) + else: + for i, request_blocks in enumerate(block_tables): + block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) + prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32, device=device) + slots = torch.tensor(slots, dtype=torch.int64, device=device) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) + + prefilling_mask = [True] * len(pb.requests) + prefilling_mask_tensor = torch.tensor(prefilling_mask, dtype=torch.bool, device=device) + return cls( batch_id=pb.id, requests=pb.requests, @@ -311,7 +354,8 @@ def from_pb( max_input_length=max_input_length, max_current_length=max_current_length, prefilling=True, - prefilling_mask=[True] * len(pb.requests), + prefilling_mask=prefilling_mask, + prefilling_mask_tensor=prefilling_mask_tensor, prefill_logprob_tokens=[None] * len(pb.requests), input_lengths=input_lengths, prompt_lengths=prompt_lengths, @@ -330,7 +374,8 @@ def from_pb( cu_seqlen_prefill=None, prefill_cache_indices=None, slot_indices=None, - slots=None, + slots=slots, + cu_slots=cu_slots, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -343,9 +388,6 @@ def from_pb( def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if len(request_ids) == 0: raise ValueError("Batch must have at least one request") - # We assume that if len(requests) == len(self) then the requests are the same - if len(request_ids) == len(self): - return self device = self.block_tables_tensor.device @@ -356,7 +398,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": indices = [] # slots to keep after filtering - slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool, device=device) + if not has_triton(): + # slots to keep after filtering + slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool, device=device) # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) @@ -373,17 +417,18 @@ def filter(self, 
request_ids: List[int]) -> "FlashCausalLMBatch": cache_lengths = [] prefix_offsets = [] read_offsets = [] + cu_slots = [0] prefilling_mask = [] prefill_logprob_tokens = [] stopping_criterias = [] - adapter_set = set() + adapter_list = [] num_blocks = 0 max_blocks = 0 - # Cumulative length - cumulative_max_length = 0 + max_slots = 0 + cumulative_slot_tokens = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -415,30 +460,33 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) - adapter_set.add(self.requests[idx].adapter_index) + adapter_list.append(self.requests[idx].adapter_index) request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) + start_slot = self.cu_slots[idx] + end_slot = self.cu_slots[idx + 1] + slot_length = end_slot - start_slot + + if not has_triton(): + # Set slice + slot_filtering_indices[start_slot:end_slot] = True + + cu_slots.append(cumulative_slot_tokens + slot_length) + # Input ids if the request was part of a prefilling batch # If the batch was decoding we can index into the tensor directly later if self.prefilling: input_ids.append(self.input_ids[idx]) else: # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length - - remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - - # Set slice - slot_filtering_indices[ - self.slot_indices[idx] : self.slot_indices[idx] + request_input_length + remaining_tokens - 1 - ] = True - - cumulative_max_length += request_input_length + remaining_tokens - 1 + slot_indices[i] = cumulative_slot_tokens + request_cache_length + cumulative_slot_tokens += slot_length max_blocks = max(max_blocks, len(request_block_table)) + max_slots = max(max_slots, slot_length) all_input_ids_tensor = self.all_input_ids_tensor[indices] @@ -446,23 +494,32 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": next_token_chooser = self.next_token_chooser.filter(indices) speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None prompt_lengths_tensor = self.prompt_lengths_tensor[indices] + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) + + if not has_triton(): + slots = self.slots[slot_filtering_indices] + else: + slots = self.slots.new_empty(cumulative_slot_tokens) + gpu_cu_slots = cu_slots.to(device) + slots_indexing_start = self.cu_slots.to(device)[indices] + slots_filtering(max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start) if self.prefilling: # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids = None slot_indices = None - slots = None cache_lengths_tensor = None input_lengths_tensor = None adapter_meta = None + prefilling_mask_tensor = self.prefilling_mask_tensor[indices] else: # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] adapter_indices = self.adapter_meta.adapter_indices[indices] input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] cache_lengths_tensor = self.cache_lengths_tensor[indices] + prefilling_mask_tensor = None # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) @@ -471,7 +528,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) adapter_meta = AdapterBatchMetadata( 
adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) @@ -489,10 +547,12 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + cu_slots=cu_slots, max_input_length=max_input_length, max_current_length=max_current_length, prefilling=self.prefilling, prefilling_mask=prefilling_mask, + prefilling_mask_tensor=prefilling_mask_tensor, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -548,26 +608,29 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) prefilling = prefilling or b.prefilling + slots = batches[0].slots.new_empty(total_slots) + cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64) if prefilling: input_ids = [] # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids = None - slots = None slot_indices = None cache_lengths_tensor = None input_lengths_tensor = None + prefilling_mask_tensor = batches[0].prefilling_mask_tensor.new_empty(total_batch_size) adapter_meta = None adapter_segment_builder = None else: input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(total_batch_size) cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(total_batch_size) + prefilling_mask_tensor = None total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) adapter_segment_builder = SegmentConcatBuilder() + adapter_list = [] adapter_set = set() prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(total_batch_size) @@ -619,13 +682,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - if not prefilling: - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + slots[slots_start_index:slots_end_index] = batch.slots + cu_slots[start_index + 1 : end_index + 1] = batch.cu_slots[1:] + cumulative_slots + if not prefilling: input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids - slots[slots_start_index:slots_end_index] = batch.slots slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor @@ -635,18 +699,17 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices cumulative_adapter_indices_size = adapter_end_index + adapter_list.extend(batch.adapter_meta.adapter_list) adapter_set.update(batch.adapter_meta.adapter_set) adapter_segment_builder.concat( batch.adapter_meta.adapter_segments, 
batch.adapter_meta.segment_indices, ) - - # Update - cumulative_slots += len(batch.slots) else: if isinstance(batch.input_ids, torch.Tensor): batch.input_ids = batch.input_ids.view(-1, 1).tolist() input_ids.extend(batch.input_ids) + prefilling_mask_tensor[start_index:end_index] = batch.prefilling_mask_tensor prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) @@ -669,6 +732,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch stopping_criterias.extend(batch.stopping_criterias) # Update + cumulative_slots += len(batch.slots) cumulative_batch_size += len(batch) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( @@ -679,14 +743,20 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch sequence_processors=sequence_processors, ) - speculative_ids = ( - torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None - ) + # We skip computing the speculative_ids when the batch size is too large, so + # we must check that all batches have them, otherwise they must be discarded + speculative_ids = None + if get_speculative_tokens() > 0: + if all(b.speculative_ids is not None for b in batches): + speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) + else: + logger.info("Discarding speculative IDs, not every batch has them") if adapter_segment_builder is not None: adapter_segments, adapter_segment_indices = adapter_segment_builder.build() adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, + adapter_list=adapter_list, adapter_set=adapter_set, adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, @@ -707,10 +777,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch cache_lengths=cache_lengths, cache_lengths_tensor=cache_lengths_tensor, slots=slots, + cu_slots=cu_slots, max_input_length=max_input_length, max_current_length=max_current_length, prefilling=prefilling, prefilling_mask=prefilling_mask, + prefilling_mask_tensor=prefilling_mask_tensor, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -739,14 +811,41 @@ def prepare_for_prefill(self): # it simplifies everything assert self.speculative_ids is None + device = self.block_tables_tensor.device + + if isinstance(self.input_ids, list): + if len(self) > 1: + input_ids = np.concatenate(self.input_ids, dtype=np.int64) + else: + input_ids = self.input_ids[0] + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32, device=device) + self.cu_seqlen_prefill = torch.nn.functional.pad(torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)).to( + torch.int32 + ) + self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32, device=device) + + # If the device supports Triton, we can use a fused kernel + if has_triton(): + self.position_ids = torch.empty(len(self.input_ids), dtype=torch.int32, device=device) + self.slot_indices = torch.empty(len(self.input_ids), dtype=torch.int64, device=device) + cu_slots_gpu = self.cu_slots.to(device) + + prepare_position_slot_ids( + self.max_input_length, + self.cache_lengths_tensor, + self.cu_seqlen_prefill, + cu_slots_gpu, + self.position_ids, + self.slot_indices, + ) + position_ids = [] - cu_seqlen_prefill = [0] slot_indices = [] prefill_cache_indices = [] all_prefill_logprobs = True no_prefill_logprobs = True - 
prefill_head_indices = [] - prefill_next_token_indices = [] prefill_cu_outlens = [0] # Cumulative length @@ -754,9 +853,8 @@ def prepare_for_prefill(self): cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 - slots = [] adapter_indices_list = [] - adapter_set = set() + adapter_list = [] for i, ( r, @@ -776,24 +874,27 @@ def prepare_for_prefill(self): ) ): next_chunk_length = input_length - # Position ids - request_position_ids = torch.arange(cache_length, cache_length + input_length, dtype=torch.int32) - position_ids.append(request_position_ids) - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) + if not has_triton(): + # Position ids + request_position_ids = torch.arange(cache_length, cache_length + input_length, dtype=torch.int32) + position_ids.append(request_position_ids) - if not r.slots: - request_slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] - else: - request_slots = r.slots + if not r.slots: + request_slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] + else: + request_slots = r.slots - request_slots = request_slots[cache_length:] - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + input_length, - dtype=torch.int64, - ) + request_slot_indices = torch.arange( + cache_length + cumulative_slot_tokens, + cache_length + cumulative_slot_tokens + input_length, + dtype=torch.int64, + ) + + slot_indices.append(request_slot_indices) + + # Update + cumulative_slot_tokens += len(request_slots) # Create tensor to slice into the kv tensor in prefill if SLIDING_WINDOW is not None: @@ -810,73 +911,93 @@ def prepare_for_prefill(self): no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs if prefill_logprobs: - prefill_head_indices.append( - torch.arange( - cumulative_length, - cumulative_length + input_length, - dtype=torch.int64, - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1) prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) prefill_out_cumulative_length += input_length else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], - dtype=torch.int64, - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - slots.extend(request_slots) - slot_indices.append(request_slot_indices) - if SLIDING_WINDOW is not None: prefill_cache_indices.append(request_prefill_cache_indices) adapter_indices_list.append(torch.full((next_chunk_length,), r.adapter_index)) - adapter_set.add(r.adapter_index) + adapter_list.append(r.adapter_index) # Update cumulative_length += next_chunk_length - cumulative_slot_tokens += len(request_slots) - device = self.block_tables_tensor.device + if not all_prefill_logprobs and not no_prefill_logprobs: + prefill_head_indices = [] + prefill_next_token_indices = [] - if isinstance(self.input_ids, list): - if len(self) > 1: - input_ids = np.concatenate(self.input_ids, dtype=np.int64) - else: - input_ids = self.input_ids[0] - self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + # Cumulative length + cumulative_length = 0 + prefill_out_cumulative_length = 0 + + for i, ( + r, + input_length, + request_prefilling, + ) in enumerate( + zip( + self.requests, + self.input_lengths, + self.prefilling_mask, + ) + ): + # Prefill logprobs is ignored if the request is done 
prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + if prefill_logprobs: + prefill_head_indices.append( + torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_out_cumulative_length += 1 + + # Update + cumulative_length += input_length if len(self) > 1: - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) + if position_ids: + position_ids = torch.cat(position_ids) + if slot_indices: + slot_indices = torch.cat(slot_indices) if SLIDING_WINDOW is not None: prefill_cache_indices = torch.cat(prefill_cache_indices) else: - position_ids = position_ids[0] - slot_indices = slot_indices[0] + if position_ids: + position_ids = position_ids[0] + if slot_indices: + slot_indices = slot_indices[0] if SLIDING_WINDOW is not None: prefill_cache_indices = prefill_cache_indices[0] + if not has_triton(): + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + self.prefill_cu_outlens = prefill_cu_outlens - cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, device=device, dtype=torch.int32) - self.cu_seqlen_prefill = cu_seqlen_prefill - self.position_ids = position_ids.to(device) - self.slot_indices = slot_indices.to(device) self.prefill_cache_indices = prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None - self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32, device=device) if all_prefill_logprobs: prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1 elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_head_indices = self.cu_seqlen_prefill[1:] - 1 prefill_next_token_indices = None else: prefill_head_indices = torch.cat(prefill_head_indices).to(device) @@ -884,14 +1005,13 @@ def prepare_for_prefill(self): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices - self.slots = torch.tensor(slots, dtype=torch.int64, device=device) - self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32, device=device) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ) @@ -958,13 +1078,13 @@ def __init__( config.quantize = quantize if is_fp8(config.quantize) and not is_fp8_supported(): - raise ValueError('FP8 quantization is only supported on hardware that supports FP8') + raise ValueError("FP8 quantization is only supported on hardware that supports FP8") if is_fp8_kv(config.quantize): if not FLASH_INFER: - raise ValueError('FP8 KV cache requires FLASH_INFER backend') + raise ValueError("FP8 KV cache requires FLASH_INFER backend") self.kv_dtype = torch.float8_e4m3fn - 
logger.info('Enabling FP8 KV cache. Prefix caching will not work.') + logger.info("Enabling FP8 KV cache. Prefix caching will not work.") else: self.kv_dtype = dtype @@ -1088,6 +1208,8 @@ def __init__( num_kv_heads=self.num_kv_heads, ) + self.punica_wrapper = None + @property def block_size(self) -> int: return BLOCK_SIZE @@ -1159,6 +1281,16 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model # The warmup batch is the biggest batch we could ever receive max_total_tokens = batch.max_input_length + max_new_tokens + get_speculative_tokens() + self.punica_wrapper = PunicaWrapper( + max_num_batched_tokens=get_max_prefill_tokens(), + max_batches=256, # TODO(travis): find a better way to set this programmatically + device=self.device, + enabled=( + not self.dynamic_adapter_loading_enabled # only supported for now with statically loaded adapters + and not LORAX_PUNICA_TRITON_DISABLED + ), + ) + torch.cuda.empty_cache() try: self.init_kv_cache( @@ -1198,9 +1330,6 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model if self.world_size > 1: raise ValueError("Cannot enable `--compile` when sharding across multiple GPUs") - # This will be recalculated in the graph step - self.decode_state = None - # Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache. # Needs to be estimated here rather than fully initialized as the graph cache relies on the # cache manager being set. @@ -1209,12 +1338,14 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model self.device, self.kv_cache, self.adapter_layers, - self.default_traced_adapter_layers, + self.traced_adapter_layers, self._forward_context, max_total_tokens, self.num_heads, self.num_kv_heads, self.sliding_window_blocks, + self.layer_to_lora_weights, + self.punica_wrapper, ) graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory() logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024) @@ -1258,6 +1389,9 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model self.model_graph_wrapper.warmup() torch.cuda.synchronize(self.device) + if self.profiler is not None: + self.profiler.start() + return int(num_blocks * BLOCK_SIZE) def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: @@ -1349,10 +1483,16 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) arange_int = arange.to(dtype=torch.int32) new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) + # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices, + # then update the slots with the additional indices to ensure we're grabbing the ones that have been + # allocated + slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + slots = batch.slots[slot_indices] + block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() max_s = max_s + speculative_length @@ -1381,6 +1521,9 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> 
block_tables=block_tables, input_lengths=batch.input_lengths, cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=max_s, ) with self._forward_context( @@ -1405,6 +1548,8 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> lm_head_indices=batch.prefill_head_indices, ) else: + skip_lm_head = get_speculative_tokens() > 0 + # CUDA graph mode out = model.forward( input_ids=input_ids, @@ -1422,6 +1567,10 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> lm_head_indices=batch.prefill_head_indices, ) + if skip_lm_head and hasattr(self.model, "lm_head"): + # re-run through the LM head as the graph did not capture it + out = self.model.lm_head(out[0], adapter_data) + if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None @@ -1446,6 +1595,7 @@ def generate_token( adapter_segments = adapter_meta.adapter_segments * new_length adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, + adapter_list=adapter_meta.adapter_list, adapter_set=adapter_meta.adapter_set, adapter_segments=adapter_segments, segment_indices=adapter_meta.segment_indices, @@ -1453,8 +1603,14 @@ def generate_token( # Assign pointers to adapter weights # TODO(travis): don't update this if indices haven't changed + self.punica_wrapper.update_metadata(adapter_meta, prefill) adapter_data = AdapterBatchData.from_meta( - adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices + adapter_meta, + self.layer_to_adapter_weights, + self.layer_to_lora_weights, + self.punica_wrapper, + prefill, + batch.prefill_head_indices, ) out, speculative_logits = self.forward(batch, adapter_data) @@ -1472,7 +1628,6 @@ def generate_token( else: prefill_logprobs = None next_token_logits = out - next_adapter_indices = batch.adapter_meta.adapter_indices finished_prefilling = True next_chunk_lengths = [] @@ -1539,11 +1694,10 @@ def generate_token( # Since we are done prefilling, all the tensors that were concatenating values for all the requests # instantly become of shape [BATCH_SIZE] if prefill and finished_prefilling: - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) - elif not prefill: - next_position_ids = batch.position_ids + indices = batch.cu_seqlen_prefill[1:] - 1 + batch.position_ids = batch.position_ids[indices] + batch.slot_indices = batch.slot_indices[indices] + batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[indices] # Zipped iterator iterator = zip( @@ -1562,8 +1716,8 @@ def generate_token( # It is faster if we delay this sync for the maximum amount of time # For each member of the batch - index = 0 # Cumulative length + cu_accepted_ids = torch.nn.functional.pad(torch.cumsum(accepted_ids, dim=0), (1, 0)) cumulative_length = 0 for i, ( request, @@ -1575,55 +1729,55 @@ def generate_token( request_was_prefilling, request_is_prefilling, ) in enumerate(iterator): - if prefill and finished_prefilling: + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: # Indexing metadata - _start_index = cumulative_length - end_index = cumulative_length + input_length - - # Initialize position_ids - # In decode, we do not need this as we can just 
increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] - - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[end_index - 1] - - # Used to gather prefill logprobs - # Copy batch.all_input_ids_tensor to prefill_token_indices - if request.prefill_logprobs and request_was_prefilling: - # Indexing metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] - if len(batch) > 1: - prefill_tokens_indices[out_start_index:out_end_index] = ids - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = ids - - if not request_is_prefilling: - # Only save tokens if we are done prefilling for this request - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] - - batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] + if len(batch) > 1: + prefill_tokens_indices[out_start_index:out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids + + # If the device does not support triton, we copy one by one + if not request_is_prefilling and not has_triton(): + # Only save tokens if we are done prefilling for this request + batch.all_input_ids_tensor[ + i, + batch.cache_lengths_tensor[i] + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] - index += n_accepted_ids cumulative_length += input_length + # If the device support triton, we can use a fused kernel + if has_triton(): + copy_next_input_ids_inplace( + speculative_tokens + 1, + batch.all_input_ids_tensor, + batch.cache_lengths_tensor, + batch.input_lengths_tensor, + batch.prompt_lengths_tensor, + next_input_ids, + cu_accepted_ids, + ) + # Update values # These values can be updated without a GPU -> CPU sync if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor - batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices if prefill and prefill_logprobs: # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) @@ -1815,8 +1969,10 @@ def generate_token( # processing stopped = False new_input_length = next_chunk_lengths[i] + new_cache_length = cache_length + input_length else: - 
new_input_length = n_accepted_ids + new_input_length = 1 + new_cache_length = cache_length + input_length + n_accepted_ids - 1 # Append next token to all tokens next_token_texts = [] left = 0 @@ -1920,12 +2076,10 @@ def generate_token( # Update values index += n_accepted_ids - current_cache_length = cache_length + input_length - batch.cache_lengths[i] = current_cache_length - current_input_length = new_input_length - batch.max_input_length = max(batch.max_input_length, current_input_length) - batch.input_lengths[i] = current_input_length - current_length = current_cache_length + current_input_length + batch.cache_lengths[i] = new_cache_length + batch.max_input_length = max(batch.max_input_length, new_input_length) + batch.input_lengths[i] = new_input_length + current_length = new_cache_length + new_input_length batch.max_current_length = max(batch.max_current_length, current_length) batch.prefix_offsets[i] = prefix_offset diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py index b52b2ef8d..f2c70687d 100644 --- a/server/lorax_server/models/flash_qwen2.py +++ b/server/lorax_server/models/flash_qwen2.py @@ -122,7 +122,12 @@ def embed(self, batch) -> torch.Tensor: adapter_meta = batch.adapter_meta prefill = False adapter_data = AdapterBatchData.from_meta( - adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices + meta=adapter_meta, + weights=self.layer_to_adapter_weights, + layer_to_lora_weights={}, + punica_wrapper=None, + prefill=prefill, + prefill_head_indices=batch.prefill_head_indices, ) embedding, _ = self.forward(batch, adapter_data=adapter_data) return embedding.cpu().tolist() diff --git a/server/lorax_server/models/flash_roberta.py b/server/lorax_server/models/flash_roberta.py index 8e6d41d7e..74768336e 100644 --- a/server/lorax_server/models/flash_roberta.py +++ b/server/lorax_server/models/flash_roberta.py @@ -209,7 +209,14 @@ def forward(self, batch: FlashEmbeddingClassificationBatch): @tracer.start_as_current_span("embed") def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding: - adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.layer_to_adapter_weights, False, None) + adapter_data = AdapterBatchData.from_meta( + meta=batch.adapter_meta, + weights=self.layer_to_adapter_weights, + layer_to_lora_weights={}, + punica_wrapper=None, + prefill=False, + prefill_head_indices=None, + ) with self._forward_context(cu_seqlens=batch.cu_seqlens): embedding: torch.Tensor = self.model.forward( diff --git a/server/lorax_server/models/metadata_kernels.py b/server/lorax_server/models/metadata_kernels.py new file mode 100644 index 000000000..830cbdca2 --- /dev/null +++ b/server/lorax_server/models/metadata_kernels.py @@ -0,0 +1,329 @@ +# From: https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/metadata_kernels.py + +from typing import List, Optional + +import torch +import triton +import triton.language as tl +from loguru import logger +from torch.utils._triton import has_triton as has_triton_torch + +from lorax_server.utils.import_utils import ( + SYSTEM, +) + +_HAS_TRITON: Optional[bool] = None + + +def has_triton(): + global _HAS_TRITON + if _HAS_TRITON is None: + # FIXME: it seems that has_triton_torch is bugged on RocM + # For now, only accept cuda + _HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False + if _HAS_TRITON: + logger.info("Using optimized Triton indexing kernels.") + + return _HAS_TRITON + + +def 
block_tables_to_padded( + max_blocks: int, + cu_seqlen: torch.Tensor, + block_tables: torch.Tensor, + block_tables_ragged: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_blocks, meta["BLOCK_SIZE"]), + len(block_tables), + ) + + triton_block_tables_to_padded[grid]( + cu_seqlen, + block_tables, + block_tables_ragged, + block_tables.shape[1], + BLOCK_SIZE=256, + ) + + +def block_tables_to_ragged( + *, + block_tables: torch.Tensor, + input_lengths: List[int], + cache_lengths: List[int], + input_lengths_tensor: torch.Tensor, + cache_lengths_tensor: torch.Tensor, + max_current_length: int, +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(cache_lengths) + + total_len = sum(input_lengths) + sum(cache_lengths) + block_tables_ragged = torch.empty(total_len, dtype=torch.int32, device=block_tables.device) + + if has_triton(): + cu_seqlen = torch.nn.functional.pad(torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)) + + def grid(meta): + return ( + triton.cdiv(max_current_length, meta["BLOCK_SIZE"]), + len(cache_lengths), + ) + + triton_block_tables_to_ragged[grid]( + cu_seqlen, + block_tables, + block_tables_ragged, + block_tables.shape[1], + BLOCK_SIZE=256, + ) + else: + offset = 0 + for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): + seq_len = cache_length + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged + + +def copy_next_input_ids_inplace( + max_next_input_ids: int, + all_input_ids: torch.Tensor, + cache_lengths: torch.Tensor, + input_lengths: torch.Tensor, + prompt_lengths: torch.Tensor, + next_input_ids: torch.Tensor, + cu_accepted_ids: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]), + len(all_input_ids), + ) + + triton_copy_next_input_ids_inplace[grid]( + all_input_ids, + cache_lengths, + input_lengths, + prompt_lengths, + next_input_ids, + cu_accepted_ids, + all_input_ids.shape[1], + BLOCK_SIZE=16, + ) + + +def prepare_position_slot_ids( + max_input_length: int, + cache_lengths: torch.Tensor, + cu_seqlen: torch.Tensor, + cu_slots: torch.Tensor, + position_ids: torch.Tensor, + slot_indices: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_input_length, meta["BLOCK_SIZE"]), + len(cache_lengths), + ) + + triton_prepare_position_slot_ids[grid]( + cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256 + ) + + +def slots_filtering( + max_slots: int, + slots: torch.Tensor, + filtered_slots: torch.Tensor, + cu_slots: torch.Tensor, + slots_start: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_slots, meta["BLOCK_SIZE"]), + len(slots_start), + ) + + triton_slots_filtering[grid](slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256) + + +@triton.jit +def triton_slots_filtering( + # Inputs + slots_ptr, + filtered_slots_ptr, + slots_start_ptr, + cu_slots_ptr, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + filter_start = tl.load(slots_start_ptr + bid) + + slot_start = tl.load(cu_slots_ptr + bid) + slot_end = tl.load(cu_slots_ptr + bid + 1) + + mask = (slot_start + block_arange) < slot_end + + slots = 
tl.load(slots_ptr + filter_start + block_arange, mask=mask) + tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask) + + +@triton.jit +def triton_block_tables_to_padded( + # Inputs + cu_seqlen_ptr, + # Outputs + block_tables_ptr, + block_tables_ragged_ptr, + # Stride + stride_block_tables, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + mask = (seq_start + block_arange) < seq_end + + blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask) + tl.store(block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask) + + +@triton.jit +def triton_block_tables_to_ragged( + # Inputs + cu_seqlen_ptr, + # Outputs + block_tables_ptr, + block_tables_ragged_ptr, + # Stride + stride_block_tables, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + mask = (seq_start + block_arange) < seq_end + + blocks = tl.load(block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask) + tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask) + + +@triton.jit +def triton_copy_next_input_ids_inplace( + # Inputs + all_input_ids_ptr, + cache_lengths_ptr, + input_lengths_ptr, + prompt_lengths_ptr, + next_input_ids_ptr, + cu_accepted_ids_ptr, + # Stride + stride_all_input_ids, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in max_accepted_ids / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + # Used for correctly indexing in all_input_ids + cache_length = tl.load(cache_lengths_ptr + bid) + input_length = tl.load(input_lengths_ptr + bid) + prompt_length = tl.load(prompt_lengths_ptr + bid) + + # Start/End of next_input_ids for this request + next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid) + next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1) + + # Mask values out of range + mask = (next_input_ids_start + block_arange) < next_input_ids_end + + # Mask values for request still prefilling + decode_mask = (cache_length + input_length + block_arange) >= prompt_length + + mask = mask & decode_mask + + # Load this request next input ids + next_input_ids = tl.load(next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask) + + # Store in all_input_ids, since it is a 2D tensor, apply stride * bid + tl.store( + all_input_ids_ptr + stride_all_input_ids * bid + cache_length + input_length + block_arange, + next_input_ids, + mask=mask, + ) + + +@triton.jit +def triton_prepare_position_slot_ids( + # Inputs + cache_lengths_ptr, + cu_seqlen_ptr, + cu_slots_ptr, + # Outputs + position_ids_ptr, + slot_indices_ptr, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in max_input_length / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + 
block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + cache_length = tl.load(cache_lengths_ptr + bid) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + slot_start = tl.load(cu_slots_ptr + bid) + + mask = (seq_start + block_arange) < seq_end + + tl.store( + position_ids_ptr + seq_start + block_arange, + cache_length + block_arange, + mask=mask, + ) + tl.store( + slot_indices_ptr + seq_start + block_arange, + slot_start + cache_length + block_arange, + mask=mask, + ) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index f5e9dbe99..be80b3bdd 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -7,6 +7,8 @@ from loguru import logger from transformers import PreTrainedTokenizerBase +from lorax_server.adapters.lora import LoraWeights +from lorax_server.adapters.medusa_lora import MedusaLoraWeights from lorax_server.adapters.utils import download_adapter_weights from lorax_server.adapters.weights import LayerAdapterWeights from lorax_server.models.types import Batch, GeneratedText @@ -16,11 +18,13 @@ BASE_MODEL_ADAPTER_ID, load_and_merge_adapters, ) +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, pad_to_min_rank, use_cutlass_shrink from lorax_server.utils.sources import HUB from lorax_server.utils.state import ( BLOCK_SIZE, CHUNKED_PREFILL, FLASH_INFER, + LORAX_PROFILER_DIR, get_speculative_tokens, set_supports_chunking, ) @@ -77,6 +81,7 @@ def __init__( self.preloaded_adapter_indices = set() self.preloaded_adapter_memory_fractions = {} self.preloaded_adapters = [] + self.layer_to_lora_weights = {} self.trust_remote_code = trust_remote_code @@ -113,6 +118,18 @@ def __init__( self.check_initialized() + self.profiler = None + if LORAX_PROFILER_DIR is not None: + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler(LORAX_PROFILER_DIR, use_gzip=True), + ) + self.profiler_steps = 0 + @property def info(self) -> InfoResponse: if self.requires_padding and self.sliding_window is not None: @@ -217,6 +234,12 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return [] + @property + def traced_adapter_layers(self) -> List[str]: + if self.layer_to_adapter_weights: + return list(self.layer_to_adapter_weights.keys()) + return self.default_traced_adapter_layers + @property def default_traced_adapter_layers(self) -> List[str]: return [] @@ -237,6 +260,10 @@ def max_speculative_tokens(self) -> int: def register_preloaded_adapters( self, preloaded_adapters: List[generate_pb2.PreloadedAdapter], adapter_memory_fractions: List[float] ): + if preloaded_adapters is None: + return + + self.dynamic_adapter_loading_enabled = False self.preloaded_adapter_indices.update({adapter.adapter_index for adapter in preloaded_adapters}) self.preloaded_adapter_memory_fractions.update( { @@ -246,6 +273,57 @@ def register_preloaded_adapters( ) self.preloaded_adapters.extend(preloaded_adapters) + if LORAX_PUNICA_TRITON_DISABLED: + # Following code is only applicable to Triton kernels + return + + # For Triton kernels: need weights into contiguous tensor + # dict of (layer_name, layer_id) -> (lora_a_weights, lora_b_weights) + # where: + # lora_a_weights = [num_adapters, r, hidden_size] + # lora_b_weights = [num_adapters, hidden_size, r] + for layer_name, 
layer_adapter_weights in self.layer_to_adapter_weights.items(): + layer_id_to_lora_a_weights = defaultdict(list) + layer_id_to_lora_b_weights = defaultdict(list) + for adapter in preloaded_adapters: + adapter_index = adapter.adapter_index + adapter_weights = layer_adapter_weights.adapter_weights.get(adapter_index) + if not isinstance(adapter_weights, LoraWeights): + if isinstance(adapter_weights, MedusaLoraWeights): + # only use lora component + adapter_weights = adapter_weights.lora_weights + else: + # only applicable to lora for now + continue + + if adapter_weights is None: + # no weights for this layer + continue + + # transpose into col major + lora_b = adapter_weights.weights_b.transpose(1, 2) + lora_a = adapter_weights.weights_a + if use_cutlass_shrink(lora_b.size(2)): + lora_a = lora_a.transpose(1, 2) + + nlayers = lora_a.size(0) + for layer_id in range(nlayers): + layer_id_to_lora_a_weights[layer_id].append(lora_a[layer_id]) + layer_id_to_lora_b_weights[layer_id].append(lora_b[layer_id]) + + for layer_id, lora_a_weights in layer_id_to_lora_a_weights.items(): + lora_b_weights = layer_id_to_lora_b_weights[layer_id] + + # right pad every adapter to the max rank + r = max([w.size(-1) for w in lora_b_weights]) + lora_a_weights = [pad_to_min_rank(w, 0, r) for w in lora_a_weights] + lora_b_weights = [pad_to_min_rank(w, 1, r) for w in lora_b_weights] + + # stack into [num_adapters, r, hidden_size] and [num_adapters, hidden_size, r] + lora_a_weights = torch.stack(lora_a_weights).to(self.device).contiguous() + lora_b_weights = torch.stack(lora_b_weights).to(self.device).contiguous() + self.layer_to_lora_weights[(layer_name, layer_id)] = (lora_a_weights, lora_b_weights) + def load_adapter( self, adapter_parameters: AdapterParameters, @@ -269,10 +347,9 @@ def load_adapter( if dynamic and not self.dynamic_adapter_loading_enabled: raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." + "This model does not support dynamic adapter loading. " + "Please initialize a new model instance from the base model and remove preloaded adapters " + "to use the dynamic adapter loading feature." ) logger.info(f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}") @@ -351,10 +428,9 @@ def offload_adapter( if not self.dynamic_adapter_loading_enabled: raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." + "This model does not support dynamic adapter loading. " + "Please initialize a new model instance from the base model and remove preloaded adapters " + "to use the dynamic adapter loading feature." 
) for layer_name in self.adapter_layers: diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index 62ad44b02..b5e7654dd 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -181,7 +181,7 @@ def from_pb( max_s = 0 cumulative_length = 0 - adapter_set = set() + adapter_list = [] for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): tokenized_input = tokenized_input[-r.truncate :] @@ -199,7 +199,7 @@ def from_pb( position_ids.append(request_position_ids) adapter_indices_list.append(torch.full((input_length,), r.adapter_index)) - adapter_set.add(r.adapter_index) + adapter_list.append(r.adapter_index) cumulative_length += input_length @@ -232,7 +232,8 @@ def from_pb( size=len(batch_tokenized_inputs), adapter_meta=AdapterBatchMetadata( adapter_indices=adapter_indices, - adapter_set=adapter_set, + adapter_list=adapter_list, + adapter_set=set(adapter_list), adapter_segments=adapter_segments, segment_indices=adapter_segment_indices, ), diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index a75906ea6..d8a2595cf 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -23,7 +23,7 @@ enum_string_to_adapter_source, is_base_model, ) -from lorax_server.utils.sgmv import has_sgmv +from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, has_sgmv from lorax_server.utils.state import set_max_prefill_tokens, set_speculative_tokens @@ -62,6 +62,7 @@ async def ClearCache(self, request, context): self.cache.delete(request.id) else: self.cache.clear() + return generate_pb2.ClearCacheResponse() async def FilterBatch(self, request, context): @@ -110,6 +111,12 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) + if self.model.profiler: + self.model.profiler_steps += 1 + if self.model.profiler_steps == 10: + self.model.profiler.stop() + print(self.model.profiler.key_averages()) + return generate_pb2.PrefillResponse( generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, @@ -319,6 +326,13 @@ async def serve_inner( except ImportError: pass + # set speculative decoding tokens + speculative_tokens = max(model.max_speculative_tokens, speculative_tokens) + if speculative_tokens > 0: + # Only use ngram speculation if the model does not support speculative tokens itself + use_ngram = model.max_speculative_tokens == 0 + set_speculative_tokens(speculative_tokens, use_ngram=use_ngram) + if preloaded_adapter_ids: logger.info(f"Preloading {len(preloaded_adapter_ids)} adapters") @@ -390,11 +404,6 @@ def load_adapter(adapter_info: generate_pb2.PreloadedAdapter) -> bool: adapter_memory_fractions = [r.memory_fraction for r in download_responses] model.register_preloaded_adapters(preloaded_adapters, adapter_memory_fractions) - # set speculative decoding tokens - speculative_tokens = max(model.max_speculative_tokens, speculative_tokens) - if speculative_tokens > 0: - set_speculative_tokens(speculative_tokens) - server = aio.server( interceptors=[ ExceptionInterceptor(), @@ -412,10 +421,12 @@ def load_adapter(adapter_info: generate_pb2.PreloadedAdapter) -> bool: await server.start() # Log SGMV kernel status + if not LORAX_PUNICA_TRITON_DISABLED and not model.dynamic_adapter_loading_enabled: + logger.info("Triton kernel is enabled, multi-LoRA inference will be fast!") if has_sgmv(): logger.info("SGMV kernel 
is enabled, multi-LoRA inference will be fast!") else: - logger.info("SGMV kernel is disabled, multi-LoRA inference may be slow") + logger.info("Punica kernels are disabled, multi-LoRA inference may be slow") logger.info("Server started at {}".format(local_url)) diff --git a/server/lorax_server/utils/attention/utils.py b/server/lorax_server/utils/attention/utils.py deleted file mode 100644 index 8292be916..000000000 --- a/server/lorax_server/utils/attention/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import List - -import torch - - -def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int] -) -> torch.Tensor: - """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(cache_lengths) - - total_len = sum(input_lengths) + sum(cache_lengths) - block_tables_ragged = torch.empty(total_len, dtype=torch.int32, device=block_tables.device) - - offset = 0 - for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): - seq_len = cache_length + input_length - block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] - offset += seq_len - - return block_tables_ragged diff --git a/server/lorax_server/utils/flashinfer_attention.py b/server/lorax_server/utils/flashinfer_attention.py index 87bbbd218..cc99b7a51 100644 --- a/server/lorax_server/utils/flashinfer_attention.py +++ b/server/lorax_server/utils/flashinfer_attention.py @@ -81,7 +81,7 @@ def use_prefill_with_paged_kv_state( head_dim=head_size, q_data_type=dtype, page_size=page_size, - # window_left=window_left, # TODO + window_left=window_left, ) yield finally: diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 387ddf86c..8baaed757 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -16,10 +16,10 @@ from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.adapters.lora import BatchLoraWeights, RankSegments from lorax_server.adapters.types import LORA +from lorax_server.models.metadata_kernels import block_tables_to_ragged from lorax_server.utils.attention.common import Seqlen -from lorax_server.utils.attention.utils import block_tables_to_ragged -from lorax_server.utils.sgmv import BGMV_MAX_RANK -from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER +from lorax_server.utils.punica import BGMV_MAX_RANK, PunicaWrapper +from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, get_speculative_tokens if TYPE_CHECKING: from lorax_server.models.flash_causal_lm import FlashCausalLMBatch @@ -155,10 +155,13 @@ def get_max_graph_state( adapter_data=AdapterBatchData( meta=AdapterBatchMetadata( adapter_indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), + adapter_list=[], adapter_set=set(), adapter_segments=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), segment_indices=[], ), + layer_to_lora_weights={}, + punica_wrapper=None, data=adapter_weight_data, prefill=False, ), @@ -198,6 +201,8 @@ def trace( num_kv_heads: int, sliding_window_blocks: Optional[int] = None, traced_adapter_layer_names: Optional[Set[str]] = None, + layer_to_lora_weights: Dict[str, Dict[str, Any]] = {}, + punica_wrapper: Optional[PunicaWrapper] = None, ) -> "GraphWrapper": max_input_state = get_max_graph_state(device, adapter_layers, max_total_tokens, sliding_window_blocks) @@ -258,10 +263,14 @@ def trace( block_tables=block_tables, input_lengths=input_lengths, 
cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_total_tokens, ) block_tables_ptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) + state = create_decode_state_cuda_graphs( device=max_input_state.input_ids.device, block_tables=block_tables, @@ -271,6 +280,15 @@ def trace( num_kv_heads=num_kv_heads, ) + meta = AdapterBatchMetadata( + adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], + adapter_list=max_input_state.adapter_data.meta.adapter_list, + adapter_set=max_input_state.adapter_data.meta.adapter_set, + adapter_segments=max_input_state.adapter_data.meta.adapter_segments[:batch_size], + segment_indices=max_input_state.adapter_data.meta.segment_indices, + ) + punica_wrapper.update_metadata(meta=meta, prefill=False) + input_state = GraphState( input_ids=max_input_state.input_ids[:batch_size], position_ids=max_input_state.position_ids[:batch_size], @@ -287,12 +305,9 @@ def trace( cache_lengths=cache_lengths, cache_lengths_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( - meta=AdapterBatchMetadata( - adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], - adapter_set=max_input_state.adapter_data.meta.adapter_set, - adapter_segments=max_input_state.adapter_data.meta.adapter_segments[:batch_size], - segment_indices=max_input_state.adapter_data.meta.segment_indices, - ), + meta=meta, + layer_to_lora_weights=layer_to_lora_weights, + punica_wrapper=punica_wrapper, data=adapter_weight_data, prefill=False, ), @@ -324,6 +339,7 @@ def trace( adapter_data=input_state.adapter_data, prefill_cache_indices=None, lm_head_indices=None, + skip_lm_head=get_speculative_tokens() > 0, ) torch.cuda.synchronize() @@ -341,6 +357,7 @@ def trace( adapter_data=input_state.adapter_data, prefill_cache_indices=None, lm_head_indices=None, + skip_lm_head=get_speculative_tokens() > 0, ) torch.cuda.synchronize(device) @@ -375,6 +392,9 @@ def forward( block_tables=block_tables, input_lengths=seqlen.input_lengths, cache_lengths=seqlen.cache_lengths, + input_lengths_tensor=seqlen.input_lengths, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, ) self.input_state.block_tables[: block_tables.shape[0]] = block_tables else: @@ -403,6 +423,8 @@ def forward( pad_and_fill(dest_rank_data.lora_b_ptr, source_rank_data.lora_b_ptr, 0) pad_and_fill(dest_rank_data.indices, source_rank_data.indices, SEGMENT_PAD_VALUE) + self.input_state.adapter_data.punica_wrapper.update_metadata(meta=adapter_data.meta, prefill=False) + with self.forward_context( block_tables=self.input_state.block_tables, cu_seqlen_prefill=None, @@ -433,6 +455,8 @@ def __init__( num_heads: int, num_kv_heads: int, sliding_window_blocks: Optional[int] = None, + layer_to_lora_weights: Dict[str, Dict[str, Any]] = {}, + punica_wrapper: Optional[PunicaWrapper] = None, ): self.model = model self.device = device @@ -446,6 +470,8 @@ def __init__( self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.sliding_window_blocks = sliding_window_blocks + self.layer_to_lora_weights = layer_to_lora_weights + self.punica_wrapper = punica_wrapper def can_use_graph( self, @@ -502,6 +528,8 @@ def get_estimated_cache_memory(self) -> int: self.num_kv_heads, self.sliding_window_blocks, self.adapter_layers, # estimate memory assuming all adapters are traced + self.layer_to_lora_weights, + self.punica_wrapper, ) 
tmp_cache[key] = graph pool = graph.memory_pool @@ -527,6 +555,7 @@ def get_estimated_cache_memory(self) -> int: def warmup(self): ngraphs = len(CACHED_BATCH_SIZES) * len(CACHED_MAX_RANKS) pool = None + logger.info("Tracing CUDA graphs with initial adapter layers: {}", self.default_traced_adapter_layers) with tqdm(total=ngraphs, desc="Trace CUDA graphs") as pbar: for batch_size in reversed(CACHED_BATCH_SIZES): pbar.set_postfix({"batch_size": batch_size}) @@ -546,6 +575,8 @@ def warmup(self): self.num_kv_heads, self.sliding_window_blocks, self.default_traced_adapter_layers, + self.layer_to_lora_weights, + self.punica_wrapper, ) self.cache[key] = graph pool = graph.memory_pool @@ -577,7 +608,8 @@ def forward( graph.input_state.traced_adapter_layer_names if graph is not None else set() ) logger.info( - "Retrace graph with new adapter layers: {} -> {}", + "batch_size={} -- retrace graph with new adapter layers: {} -> {}", + batch_size, current_traced_adapter_layer_names, adapter_data.layer_names(), ) @@ -595,6 +627,8 @@ def forward( self.num_kv_heads, self.sliding_window_blocks, adapter_data.layer_names(), + self.layer_to_lora_weights, + self.punica_wrapper, ) self.cache[key] = graph diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index d7f54a420..0feaae609 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -12,7 +12,7 @@ from lorax_server.layers.linear import FastLinear, get_linear # noqa: F401 from lorax_server.layers.tensor_parallel import SuperLayer, TensorParallelColumnLinear, TensorParallelHead # noqa: F401 from lorax_server.utils.lora import LM_HEAD -from lorax_server.utils.sgmv import ( +from lorax_server.utils.punica import ( add_lora_a_bgmv, add_lora_b_bgmv, has_sgmv, @@ -74,8 +74,37 @@ def forward_layer_type( ) -> torch.Tensor: data = adapter_data.data.get(layer_type) data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None + can_vectorize = data is not None and data.can_vectorize(self.process_group) + + # Triton Punica kernels + key = (layer_type, self.layer_id) + if ( + adapter_data.punica_wrapper is not None and adapter_data.punica_wrapper.enabled + and key in adapter_data.layer_to_lora_weights + and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size + and can_vectorize + ): + if end_idx - start_idx != result.shape[1]: + y_offset = start_idx + y_slice_size = end_idx - start_idx + else: + y_offset = None + y_slice_size = None + + lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[key] + adapter_data.punica_wrapper.add_lora( + result, + input, + lora_a_weights, + lora_b_weights, + 1.0, + y_offset, + y_slice_size, + callback=self.collect_lora_a if self.process_group.size() > 1 else None, + ) - if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + # Legacy Punica kernels + elif has_sgmv() and can_vectorize: if end_idx - start_idx != result.shape[1]: proj = torch.zeros_like(result[:, start_idx:end_idx]) else: @@ -135,6 +164,8 @@ def forward_layer_type( if end_idx - start_idx != result.shape[1]: result[:, start_idx:end_idx] += proj + + # Vanilla PyTorch else: adapter_indices = adapter_data.meta.adapter_indices if data is not None and data.prefill_head_indices is not None and data.layer_name == LM_HEAD: diff --git a/server/lorax_server/utils/ops/__init__.py b/server/lorax_server/utils/ops/__init__.py new file mode 100644 index 000000000..22f53e4d0 --- /dev/null +++ b/server/lorax_server/utils/ops/__init__.py @@ -0,0 +1 @@ +# 
Source: https://github.com/vllm-project/vllm/tree/main/vllm/lora/ops diff --git a/server/lorax_server/utils/ops/bgmv_expand.py b/server/lorax_server/utils/ops/bgmv_expand.py new file mode 100644 index 000000000..59562cee8 --- /dev/null +++ b/server/lorax_server/utils/ops/bgmv_expand.py @@ -0,0 +1,167 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = lora_ptr + l0_stride * lora_index + pid_sn * split_n_length * lora_k_stride + c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + current_n_c = tl.max_contiguous(current_n, BLOCK_N) + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n_c[:, None] * lora_k_stride + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + batches (int): batch size + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
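For reference, the expand op amounts to a per-token matrix-vector product against the selected LoRA B matrix; a minimal eager-mode sketch of that semantics follows. The name bgmv_expand_ref is illustrative only and not part of this diff, and the Triton kernel additionally tiles the hidden dimension across BLOCK_N / SPLIT_N programs.

def bgmv_expand_ref(inputs, lora_b_weights, output_tensor, lora_indices_tensor, add_inputs=True):
    # Reference semantics only, not the Triton path.
    if lora_b_weights.ndim == 4:  # (num_loras, 1, hidden_size, rank) -> (num_loras, hidden_size, rank)
        lora_b_weights = lora_b_weights.squeeze(1)
    for i, idx in enumerate(lora_indices_tensor.tolist()):
        if idx == -1:
            continue  # -1 means no LoRA for this token
        acc = (lora_b_weights[idx].to(inputs.dtype) @ inputs[i]).to(output_tensor.dtype)  # (hidden_size,)
        if add_inputs:
            output_tensor[i, : acc.shape[0]] += acc
        else:
            output_tensor[i, : acc.shape[0]] = acc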
+ Triton grid config + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + batches = lora_indices_tensor.size(0) + if override_config: + config = override_config + else: + config = get_lora_op_configs("expand", batches, N) + grid = lambda META: ( # noqa: E731 + META["SPLIT_N"], + batches, + ) + _bgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return diff --git a/server/lorax_server/utils/ops/bgmv_expand_slice.py b/server/lorax_server/utils/ops/bgmv_expand_slice.py new file mode 100644 index 000000000..a4eb1b425 --- /dev/null +++ b/server/lorax_server/utils/ops/bgmv_expand_slice.py @@ -0,0 +1,179 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = lora_ptr + l0_stride * lora_index + pid_sn * split_n_length * lora_k_stride + c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + slice_offset * cn_stride + + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n[:, None] * lora_k_stride + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + slice_offst (int): output_tensor's offst + slice_size (int): current output_tensor's size + batches (int): batch size + add_inputs (bool, optional): Defaults to False. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
+ Triton grid config + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + batches = lora_indices_tensor.size(0) + + if override_config: + config = override_config + else: + config = get_lora_op_configs("expand", batches, N) + + grid = lambda META: ( # noqa: E731 + META["SPLIT_N"], + batches, + ) + _bgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return diff --git a/server/lorax_server/utils/ops/bgmv_shrink.py b/server/lorax_server/utils/ops/bgmv_shrink.py new file mode 100644 index 000000000..0937f4fa7 --- /dev/null +++ b/server/lorax_server/utils/ops/bgmv_shrink.py @@ -0,0 +1,149 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's + performance + """ + pid_sk = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + + offset_n = tl.arange(0, BLOCK_N) + offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K + a_ptr = input_ptr + cur_batch * xm_stride + b_ptr = lora_ptr + l0_stride * lora_index + accumulator = tl.zeros((BLOCK_N,), dtype=tl.float32) + for k in range(0, K, BLOCK_K * SPLIT_K): + current_k = k + offset_k + current_k_c = tl.max_contiguous(current_k, BLOCK_K) + tiled_a = tl.load( + a_ptr + current_k_c, + mask=current_k < K, + other=0.0, + ) # [BLOCK_K] + b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K) + + tiled_b = tl.load( + b_ptr + offset_n[:, None] * lora_k_stride + current_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + accumulator += tl.sum(tiled_a * tiled_b, 1) + accumulator *= scaling + offset_cn = tl.arange(0, BLOCK_N) + c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride + c_mask = offset_cn < N + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + scaling (float): Scaling factor. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
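The shrink op is the mirror image: it projects each token from hidden_size down to the LoRA rank and applies the scaling factor, with SPLIT_K only changing how partial sums are accumulated. A hedged eager-mode sketch (bgmv_shrink_ref is an illustrative name, not part of this diff):

def bgmv_shrink_ref(inputs, lora_a_weights, output_tensor, lora_indices_tensor, scaling=1.0):
    # Reference semantics: v[i] = scaling * lora_a[idx] @ x[i]; rows with index -1 are left untouched.
    if lora_a_weights.ndim == 4:  # (num_loras, 1, rank, hidden_size) -> (num_loras, rank, hidden_size)
        lora_a_weights = lora_a_weights.squeeze(1)
    for i, idx in enumerate(lora_indices_tensor.tolist()):
        if idx == -1:
            continue
        v = scaling * (lora_a_weights[idx] @ inputs[i])  # (rank,)
        output_tensor[i, : v.shape[0]] = v.to(output_tensor.dtype)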
+ Triton grid config + """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + batches = lora_indices_tensor.size(0) + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_N = triton.next_power_of_2(N) + if override_config: + config = override_config + else: + # First try to load optimal config from the file + config = get_lora_op_configs("bgmv_shrink", batches, K) + + grid = lambda META: ( # noqa: E731 + META["SPLIT_K"], + batches, + ) + _bgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_N=BLOCK_N, + **config, + ) + return diff --git a/server/lorax_server/utils/ops/libentry.py b/server/lorax_server/utils/ops/libentry.py new file mode 100644 index 000000000..4572688b1 --- /dev/null +++ b/server/lorax_server/utils/ops/libentry.py @@ -0,0 +1,168 @@ +# Copied From https://github.com/FlagOpen/FlagGems + +import inspect + +import triton + + +class LibEntry(triton.KernelInterface): + def __init__( + self, + fn, + ): + self.fn = fn + self.arg_names = fn.arg_names + self.divisibility = 16 + self.kernel_cache = dict() + fn = self.fn + while not isinstance(fn, triton.runtime.JITFunction): + fn = fn.fn + self.jit_function: triton.runtime.JITFunction = fn + self.specialize_indices = [ + p.num for p in self.jit_function.params if not p.is_constexpr and not p.do_not_specialize + ] + self.do_not_specialize_indices = [ + p.num for p in self.jit_function.params if not p.is_constexpr and p.do_not_specialize + ] + + def key(self, spec_args, dns_args, const_args): + spec_key = [ + (arg.dtype, arg.data_ptr() % self.divisibility == 0) if hasattr(arg, "data_ptr") else (type(arg), arg) + for arg in spec_args + ] + dns_key = [ + arg.dtype + if hasattr(arg, "data_ptr") + else type(arg) + if not isinstance(arg, int) + else "i32" + if -(2**31) <= arg and arg <= 2**31 - 1 + else "u64" + if 2**63 <= arg and arg <= 2**64 - 1 + else "i64" + for arg in dns_args + ] + # const args passed by position + return tuple(spec_key + dns_key + const_args) + + def run(self, *args, **kwargs): + grid = kwargs["grid"] + # collect all the arguments + spec_args = [] # specialize arguments + dns_args = [] # do not specialize arguments + const_args = [] # constexpr arguments + k_args = [] # kernel arguments + for i, arg in enumerate(args): + if i in self.specialize_indices: + k_args.append(arg) + spec_args.append(arg) + elif i in self.do_not_specialize_indices: + k_args.append(arg) + dns_args.append(arg) + else: + const_args.append(arg) + for p in self.jit_function.params[len(args) :]: + if p.name in kwargs: + val = kwargs[p.name] + elif p.default is inspect._empty: + continue + else: + val = p.default + + if p.is_constexpr: + const_args.append(val) + elif p.do_not_specialize: + dns_args.append(val) + k_args.append(val) + else: + spec_args.append(val) + 
k_args.append(val) + + entry_key = self.key(spec_args, dns_args, const_args) + + if entry_key not in self.kernel_cache: + # compile the kernel also completes the related computations + kernel = self.fn.run(*args, **kwargs) + fn = self.fn + # collect constexpr arguments for grid computation + constexprs = {} + while not isinstance(fn, triton.runtime.JITFunction): + if isinstance(fn, triton.runtime.Autotuner): + config = fn.best_config + constexprs["num_warps"] = config.num_warps + constexprs["num_stages"] = config.num_stages + constexprs["num_ctas"] = config.num_ctas + constexprs = {**constexprs, **config.kwargs} + elif isinstance(fn, triton.runtime.Heuristics): + for v, heur in fn.values.items(): + constexprs[v] = heur( + { + **dict(zip(fn.arg_names, args)), + **kwargs, + **constexprs, + } + ) + else: + raise RuntimeError("Invalid Runtime Function") + fn = fn.fn + # In vLLM, certain kernels like fused_moe_kernel get the + # best_config(as kwargs) from a configuration json file, rather + # than using Autotuner & Heuristics. Therefore, all their constexprs + # (tl.constexpr) are assigned values through the following loop. + for p in self.jit_function.params: + if p.is_constexpr and p.name not in constexprs: + constexprs[p.name] = p.default # default=inspect._empty + self.kernel_cache[entry_key] = (kernel, constexprs) + else: + # load kernel from cache directly + kernel, constexprs = self.kernel_cache[entry_key] + + if callable(grid): + # collect all arguments to the grid fn,ie: + # 1. args, + # 2. kwargs, + # 3. all all other captured arguments in CompiledKernel from + # Autotunner & Heuristics when kwargs & captured args conflict, + # captured args have higher priority + # 4. We must filter out captured args with default value firstly + constexprs = {k: v for k, v in constexprs.items() if v is not inspect._empty} + meta = { + **dict(zip(self.arg_names, args)), + **kwargs, + **constexprs, + } + grid = grid(meta) + if isinstance(grid, tuple): + grid = grid + (1, 1) + elif isinstance(grid, list): + grid = grid + [1, 1] + kernel[grid[0:3]](*k_args) + # maintaining the same return type as the JITFunction.run + return kernel + + +def libentry(): + """ + Decorator for triton library entries. + Motivation: + The runtime overhead of Triton kernels is the reason for the lower + performance of small kernels, particularly evident with smaller models. + Using this decorator can reduce Triton runtime overhead. + How: + The `run` function of JITFunction needs to accomplish: + - Parameter binding using inspect + - KernelArg type wrapping + - Cache key calculation + When dealing with small size, these steps can become bottlenecks in + Triton runtime. Libentry simplifies these steps to reduce runtime + overhead, thereby improving the runtime expenses of small kernels. + NOTE: + When Triton is upgraded to version 3.0.0, libentry can be removed, + see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245 + + """ + + def decorator(fn): + return LibEntry(fn) + + return decorator diff --git a/server/lorax_server/utils/ops/sgmv_expand.py b/server/lorax_server/utils/ops/sgmv_expand.py new file mode 100644 index 000000000..083c03493 --- /dev/null +++ b/server/lorax_server/utils/ops/sgmv_expand.py @@ -0,0 +1,192 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from lorax_server.utils.ops.libentry import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + The sgmv's expand triton kernel is based on GroupGEMM. + """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride,) + b_ptr = lora_ptr + l0_stride * lora_index + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, mask=offset_k[None, :] < K - k * BLOCK_K, other=0) + tiled_b = tl.load(b_ptr, mask=offset_k[:, None] < K - k * BLOCK_K, other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = out_ptr + offset_cm[:, None] * cm_stride + offset_cn[None, :] * cn_stride + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + add_inputs: bool = False, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. 
+ batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output. + """ + # print("!!! inputs", inputs.shape) + # print("!!! lora_b_weights", lora_b_weights.shape) + # print("!!! output_tensor", output_tensor.shape) + # print("!!! b_seq_start_loc", b_seq_start_loc) + # print("!!! seq_len_tensor", seq_len_tensor) + # print("!!! lora_indices_tensor", lora_indices_tensor) + # print("!!! batches", batches) + # print("!!! max_seq_length", max_seq_length) + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return diff --git a/server/lorax_server/utils/ops/sgmv_expand_slice.py b/server/lorax_server/utils/ops/sgmv_expand_slice.py new file mode 100644 index 000000000..2da04b947 --- /dev/null +++ b/server/lorax_server/utils/ops/sgmv_expand_slice.py @@ -0,0 +1,205 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from lorax_server.utils.ops.libentry import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. 
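In eager PyTorch the slice variant reduces to the per-segment matmul sketched below, written into a column window of the output; sgmv_expand_slice_ref is an illustrative name, and the real kernel tiles this over BLOCK_M / BLOCK_N / BLOCK_K.

def sgmv_expand_slice_ref(inputs, lora_b_weights, output_tensor, b_seq_start_loc,
                          seq_len_tensor, lora_indices_tensor, slice_offset, slice_size,
                          add_inputs=False):
    # Each segment of rows shares a single LoRA; an index of -1 means the segment is skipped.
    if lora_b_weights.ndim == 4:
        lora_b_weights = lora_b_weights.squeeze(1)  # (num_loras, slice_size, rank)
    for start, length, idx in zip(b_seq_start_loc.tolist(), seq_len_tensor.tolist(),
                                  lora_indices_tensor.tolist()):
        if idx == -1:
            continue
        x = inputs[start : start + length]                                          # (length, rank)
        delta = (x @ lora_b_weights[idx].to(x.dtype).t()).to(output_tensor.dtype)   # (length, slice_size)
        window = output_tensor[start : start + length, slice_offset : slice_offset + slice_size]
        if add_inputs:
            window += delta
        else:
            window.copy_(delta)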
+ """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride,) + b_ptr = lora_ptr + l0_stride * lora_index + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, mask=offset_k[None, :] < K - k * BLOCK_K, other=0) + tiled_b = tl.load(b_ptr, mask=offset_k[:, None] < K - k * BLOCK_K, other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset + c_ptr = out_ptr + offset_cm[:, None] * cm_stride + offset_cn[None, :] * cn_stride + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < (slice_offset + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): + """_summary_ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + slice_offst (int): output_tensor's offst + slice_size (int): current output_tensor's size + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output.. + """ + # print("!!! inputs", inputs.shape) + # print("!!! lora_b_weights", lora_b_weights.shape) + # print("!!! output_tensor", output_tensor.shape) + # print("!!! b_seq_start_loc", b_seq_start_loc) + # print("!!! seq_len_tensor", seq_len_tensor) + # print("!!! lora_indices_tensor", lora_indices_tensor) + # print("!!! batches", batches) + # print("!!! 
max_seq_length", max_seq_length) + # print("!!! slice_offset", slice_offset) + # print("!!! slice_size", slice_size) + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return diff --git a/server/lorax_server/utils/ops/sgmv_shrink.py b/server/lorax_server/utils/ops/sgmv_shrink.py new file mode 100644 index 000000000..80ea5921c --- /dev/null +++ b/server/lorax_server/utils/ops/sgmv_shrink.py @@ -0,0 +1,190 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from lorax_server.utils.ops.libentry import libentry + + +@libentry() +@triton.jit +def _sgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + xm_stride, # hidden_size + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. + The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, + introducing SPLIT-K can improve performance + """ + pid = tl.program_id(axis=0) + pid_sk = tl.program_id(axis=1) + cur_batch = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride + b_ptr = lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + offset_k[:, None] * lora_n_stride + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + tiled_a = tl.load(a_ptr, mask=offset_k[None, :] < k_remaining, other=0.0) + tiled_b = tl.load(b_ptr, mask=offset_k[:, None] < k_remaining, other=0.0) + accumulator += tl.dot(tiled_a, tiled_b) + + a_ptr += BLOCK_K * SPLIT_K * xk_stride + b_ptr += BLOCK_K * SPLIT_K * lora_n_stride + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = out_ptr + offset_cm[:, None] * cm_stride + offset_cn[None, :] * cn_stride + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + scaling: float, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + scaling (float): Scaling factor. + """ + # print("!!! inputs", inputs.shape) + # print("!!! lora_a_weights", lora_a_weights.shape) + # print("!!! output_tensor", output_tensor.shape) + # print("!!! b_seq_start_loc", b_seq_start_loc) + # print("!!! seq_len_tensor", seq_len_tensor) + # print("!!! lora_indices_tensor", lora_indices_tensor) + # print("!!! batch_size", batches) + # print("!!! max_seq_length", max_seq_length) + # print("!!! 
scaling", scaling) + + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 32 + SPLIT_K = 8 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + SPLIT_K, + batches, + ) + + _sgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + ) + return diff --git a/server/lorax_server/utils/ops/utils.py b/server/lorax_server/utils/ops/utils.py new file mode 100644 index 000000000..4460188b8 --- /dev/null +++ b/server/lorax_server/utils/ops/utils.py @@ -0,0 +1,41 @@ +import functools +from typing import Dict + + +@functools.lru_cache +def _get_op_configs(op_type: str, batch: int, hidden_size: int): + # TODO: add optimal configurations + return None + + +def _check_divisibility(hidden_size: int): + # The bgmv_expand kernel requires that the hidden_size be divisible by + # the number below. + divisibility = [2, 4, 8, 16, 32, 64] + divisibility.sort(reverse=True) + for div in divisibility: + if hidden_size % div == 0: + return div + # hidden_size is an odd number + return 1 + + +def _get_default_config(op_type: str, batch: int, hidden_size: int): + if op_type == "expand": + return {"BLOCK_N": 256, "SPLIT_N": _check_divisibility(hidden_size), "num_warps": 8} + else: + return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8} + + +def get_lora_op_configs(op_type: str, batch: int, hidden_size: int) -> Dict[str, int]: + """Inspired by `fused_moe_kernel` + The return value will be a dictionary mapping an irregular grid of batch + sizes and hidden_size to configurations of the bgmv-related kernel. + NOTE: It currently only supports the default configuration. We plan to + generate optimal configurations for different hardware in the future using + scripts similar to `benchmark_moe.py`. 
+ """ + config = _get_op_configs(op_type, batch, hidden_size) + if not config: + config = _get_default_config(op_type, batch, hidden_size) + return config diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 9e41d4595..7d7b4c82c 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -48,7 +48,7 @@ def reshape_and_cache( elif SYSTEM == "xpu": ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slots) else: - torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, 'auto', 1.0, 1.0) + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0) def attention( @@ -138,7 +138,7 @@ def attention( block_size, max_s, None, - 'auto', + "auto", 1.0, 1.0, ) @@ -172,7 +172,7 @@ def attention( block_size, max_s, None, - 'auto', + "auto", 1.0, 1.0, ) diff --git a/server/lorax_server/utils/punica.py b/server/lorax_server/utils/punica.py new file mode 100644 index 000000000..fe5869e00 --- /dev/null +++ b/server/lorax_server/utils/punica.py @@ -0,0 +1,793 @@ +import os +import warnings +from functools import lru_cache +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from loguru import logger + +from lorax_server.utils.ops.bgmv_expand import bgmv_expand +from lorax_server.utils.ops.bgmv_expand_slice import bgmv_expand_slice +from lorax_server.utils.ops.bgmv_shrink import bgmv_shrink +from lorax_server.utils.ops.sgmv_expand import sgmv_expand +from lorax_server.utils.ops.sgmv_expand_slice import sgmv_expand_slice +from lorax_server.utils.ops.sgmv_shrink import sgmv_shrink + +if TYPE_CHECKING: + from lorax_server.adapters.weights import AdapterBatchMetadata + +try: + import punica_kernels as _kernels + + HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) +except ImportError: + warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") + _kernels = None + HAS_SGMV = False + + +LORAX_PUNICA_TRITON_DISABLED = bool(os.environ.get("LORAX_PUNICA_TRITON_DISABLED", "")) +if LORAX_PUNICA_TRITON_DISABLED: + logger.info("LORAX_PUNICA_TRITON_DISABLED is set, disabling Punica Trion kernels.") + + +MIN_SGMV_RANK = 8 +MIN_RANK_CUSTOM = 16 +MAX_RANK_CUSTOM = 128 +SGMV_BLOCK_SIZE = 16 +BGMV_MAX_RANK = 128 + + +def has_sgmv() -> bool: + return HAS_SGMV + + +def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: + """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" + if not has_sgmv(): + return t + + # tensor parallelism will result in effective rank being divided by world_size, + # so we need to scale the min rank to offset that effect + min_rank = MIN_SGMV_RANK * world_size + return pad_to_min_rank(t, dim, min_rank) + + +def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor: + # if we're at or below the min rank, pad up to the min rank + # otherwise, pad to the nearest multiple of the block size + current_rank = t.size(dim) + target_rank = ( + min_rank + if current_rank <= min_rank + else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE + ) + if current_rank == target_rank: + return t + + pad_size = target_rank - current_rank + + # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + pad = [0, 0] * t.dim() + pad[(t.dim() - dim - 1) * 2 + 1] = 
pad_size + pad = tuple(pad) + + return F.pad(t, pad, mode="constant", value=0.0) + + +def use_cutlass_shrink(lora_rank: int) -> bool: + return lora_rank < MIN_RANK_CUSTOM + + +def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: + if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: + return t.transpose(0, 1) + return t + + +# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py +def add_lora_sgmv_cutlass( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.Tensor, + s_end: torch.Tensor, + layer_idx: int, + lora_rank: int, +): + """ + Semantics: + y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H1]`. + wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H2]`. + s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. + s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. + layer_idx: Layer index of the weight matrices. + """ + if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: + # Custom SGMV shrink only supports rank 16, 32, 64, 128 + _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank) + return + + tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) + tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) + + +def _add_lora_sgmv_cutlass_legacy( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +): + tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +@lru_cache(maxsize=1) +def get_tmp_tensor(device: torch.device) -> torch.Tensor: + return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) + + +@lru_cache(maxsize=32) +def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: + tmp_size = _kernels.sgmv_cutlass_tmp_size(size) + return torch.empty((tmp_size,), dtype=torch.uint8, device=device) + + +def get_tmp_expand_size(size: int) -> int: + return _kernels.sgmv_cutlass_tmp_size(size) + + +def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + if use_cutlass_shrink(lora_rank): + tmp = get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp + else: + tmp_shrink = get_tmp_tensor(device) + tmp_expand = get_tmp_tensor_for_size(nsegments, device) + return tmp_shrink, tmp_expand + + +def lora_a_sgmv_cutlass( + x: torch.Tensor, + tmp: torch.Tensor, + wa_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +) -> 
torch.Tensor: + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + else: + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + return v + + +def lora_b_sgmv_cutlass( + y: torch.Tensor, + v: torch.Tensor, + tmp: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, +): + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +""" +Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + +Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + v: Shape: `[B, R]`. Temporary vector. + x: Shape: `[B, H1]`. Input vectors. + wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. + wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. +""" + + +def add_lora_a_bgmv( + v: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) + + +def add_lora_b_bgmv( + y: torch.Tensor, + v: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) + + +def segmented_matmul( + y: torch.Tensor, + x: torch.Tensor, + w: List[torch.Tensor], + b: List[torch.Tensor], + s_start: torch.IntTensor, + s_end: torch.IntTensor, +): + for i in range(len(w)): + if s_end[i] - s_start[i] <= 0: + continue + + xi = x[s_start[i] : s_end[i]] + wi = w[i] + bi = b[i] + y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi) + + +def compute_meta(token_lora_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. 
If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, batch_size, max_length, no_lora) + + +# TODO see if this can be vectorized +def convert_mapping( + meta: "AdapterBatchMetadata", + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context=None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). 
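For concreteness, a small worked example of what compute_meta derives from token-level adapter indices (the values are illustrative):

import torch

# Token-level adapter indices: three tokens on LoRA 0, two on LoRA 1, one with no LoRA.
token_lora = torch.tensor([0, 0, 0, 1, 1, -1])
lora_indices, seq_lens = torch.unique_consecutive(token_lora, return_counts=True)
# lora_indices -> tensor([ 0,  1, -1]); seq_lens -> tensor([3, 2, 1])
starts = torch.zeros_like(seq_lens)
starts[1:] = torch.cumsum(seq_lens, dim=0)[:-1]
# starts -> tensor([0, 3, 5]); batch_size after grouping is 3; max_length is 3; no_lora is False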
+ """ + index_mapping_indices: List[int] = meta.adapter_indices.tolist() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), device="cuda", dtype=torch.long) + prompt_mapping = meta.adapter_list.copy() + lora_idx = None + for i in range(len(index_mapping_indices)): + lora_idx = index_mapping_indices[i] + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get(index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, device="cuda", dtype=torch.long) + embeddings_indices = torch.stack( + [ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ] + ) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange(0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded) + ) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) + + +# Source: https://github.com/vllm-project/vllm/blob/main/vllm/lora/punica.py +class PunicaWrapper: + """ + PunicaWrapper is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica kernel. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, device: str, enabled: bool): + self._token_lora_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) + self._embeddings_indices = torch.empty(2, max_num_batched_tokens, dtype=torch.long, device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) + + # 5 is the number of indicies tensors. 
+ # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device) + self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device) + self._lora_indices_per_batch = torch.empty(max_batches, dtype=torch.long, device=device) + self.max_batch_size = max_batches + self.max_length: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + self.enabled = enabled + + def update_metadata( + self, + meta: "AdapterBatchMetadata", + prefill: bool, + ): + # token_lora_indices is adapter_indices - 1 to account for base model offset + base_indices = meta.adapter_indices - 1 + + self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices) + # self._token_lora_indices = base_indices + self.indices_len[0] = base_indices.shape[-1] + + if prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self._token_lora_indices, base_indices.shape[-1]) + self.is_prefill = True + else: + self.is_prefill = False + + def _update_base_metadata( + self, + meta: "AdapterBatchMetadata", + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context=None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + meta, + max_loras, + vocab_size, + extra_vocab_size, + long_lora_context, + ) + self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_(sampler_indices_padded) + self._embeddings_indices[: embeddings_indices.shape[0], : embeddings_indices.shape[1]].copy_(embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[: long_lora_offsets_tensor.shape[0]].copy_(long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor, indices_len: int) -> None: + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, batch_size, max_length, no_lora) = compute_meta( + token_lora_tensor[:indices_len] + ) + + self._seq_start_locs[: b_seq_start_tensor.shape[0]].copy_(b_seq_start_tensor) + self._seq_lengths[: seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[: lora_indices_tensor.shape[0]].copy_(lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.no_lora = no_lora + + @property + def prefill_metadata(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions + 2. seq_lengths: Tensor of sequence lengths + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: batch size after clustering identical lora indices + 5. 
+    @property
+    def prefill_metadata(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+        """
+        This property provides a convenient way to access the necessary
+        metadata for prefill-related kernel computations.
+            1. seq_start_locs: Tensor of sequence start positions
+            2. seq_lengths: Tensor of sequence lengths
+            3. lora_indices_per_batch: Tensor of lora indices, and an index of
+               -1 means no lora should be applied.
+            4. batch_size: batch size after clustering identical lora indices
+            5. max_length: The maximum sequence length in the batch
+        """
+        return (
+            self._seq_start_locs[: self.batch_size],
+            self._seq_lengths[: self.batch_size],
+            self._lora_indices_per_batch[: self.batch_size],
+            self.batch_size,
+            self.max_length,
+        )
+
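# Illustrative note (not part of the patch): assuming compute_meta groups runs of
# consecutive tokens that share a LoRA index, as in the upstream vLLM implementation,
# token_lora_indices = [0, 0, 1, 1] would give a prefill_metadata of roughly:
#
#   seq_start_locs         = [0, 2]
#   seq_lengths            = [2, 2]
#   lora_indices_per_batch = [0, 1]
#   batch_size             = 2
#   max_length             = 2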
+    @property
+    def token_lora_indices(self) -> torch.Tensor:
+        """
+        This property provides the lora indices corresponding to each token
+        in the batch. An index of -1 means no lora should be applied.
+        """
+        token_lora_len = self.indices_len[0]
+        return self._token_lora_indices[:token_lora_len]
+
+    @property
+    def sampler_indices(self) -> torch.Tensor:
+        """
+        This property is used to access the lora indices specifically for
+        LogitsProcessorWithLoRA
+        """
+        sampler_indices_len = self.indices_len[1]
+        return self._sampler_indices[:sampler_indices_len]
+
+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices
+        """
+        indices_padded_len = self.indices_len[2]
+        return self._sampler_indices_padded[:indices_padded_len]
+
+    @property
+    def embeddings_indices(self) -> torch.Tensor:
+        """
+        This property provides access to the indices used for lora embeddings,
+        specifically for VocabParallelEmbeddingWithLoRA
+        """
+        embeddings_indices_len = self.indices_len[3]
+        return self._embeddings_indices[:, :embeddings_indices_len]
+
+    @property
+    def long_lora_indices(self) -> torch.Tensor:
+        """
+        This property provides access to the indices used for long context
+        lora, specifically for LinearScalingRotaryEmbeddingWithLora
+        """
+        long_lora_len = self.indices_len[4]
+        return self._long_lora_indices[:long_lora_len]
+
+    def shrink_prefill(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        scale: float,
+    ):
+        # No LoRA request, so return directly
+        if self.no_lora:
+            return
+        sgmv_shrink(
+            x,
+            w_t_all,
+            y,
+            *self.prefill_metadata,
+            scale,
+        )
+
+    def shrink_decode(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        scale: float,
+    ):
+        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
+
+    def expand_prefill(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        add_input: bool,
+    ):
+        # No LoRA request, so return directly
+        if self.no_lora:
+            return
+        sgmv_expand(
+            x,
+            w_t_all,
+            y,
+            *self.prefill_metadata,
+            add_input,
+        )
+
+    def expand_decode(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        add_input: bool,
+    ):
+        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
+
+    def expand_slice_prefill(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        y_offset: Optional[int],
+        y_slice_size: Optional[int],
+        add_input: bool,
+    ):
+        # No LoRA request, so return directly
+        if self.no_lora:
+            return
+        sgmv_expand_slice(
+            x,
+            w_t_all,
+            y,
+            *self.prefill_metadata,
+            y_offset,
+            y_slice_size,
+            add_input,
+        )
+
+    def expand_slice_decode(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        y_offset: Optional[int],
+        y_slice_size: Optional[int],
+        add_input: bool,
+    ):
+        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_input)
+
+    def add_shrink(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        scale: float,
+    ):
+        """
+        Perform the `y += x @ w_t_all` computation, which is suitable for the
+        GEMM of lora_a.
+        When `is_prefill` is true, it indicates that it is currently the
+        prefill stage, and the `shrink_prefill` function should be called.
+        Otherwise, it is the decode stage, and the shrink_decode function
+        should be called.
+        """
+        shrink_fun: Callable = self.shrink_prefill if self.is_prefill else self.shrink_decode
+        shrink_fun(y, x, w_t_all, scale)
+
+    def add_expand(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        add_input: bool = True,
+    ):
+        """
+        Perform the `y += x @ w_t_all` computation, which is suitable for the
+        GEMM of lora_b.
+        When `is_prefill` is true, it indicates that it is currently the
+        prefill stage, and the `expand_prefill` function should be called.
+        Otherwise, it is the decode stage, and the expand_decode function
+        should be called.
+        """
+
+        expand_fun: Callable = self.expand_prefill if self.is_prefill else self.expand_decode
+        expand_fun(y, x, w_t_all, add_input)
+
+    def add_expand_slice(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        y_offset: Optional[int],
+        y_slice_size: Optional[int],
+        add_input: bool = True,
+    ):
+        """
+        Similar to `add_expand`
+        """
+
+        expand_slice_fun: Callable = self.expand_slice_prefill if self.is_prefill else self.expand_slice_decode
+        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
+
+    def add_lora(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        wa_t_all: torch.Tensor,
+        wb_t_all: torch.Tensor,
+        scale: float,
+        y_offset: Optional[int] = None,
+        y_slice_size: Optional[int] = None,
+        *,
+        buffer: Optional[torch.Tensor] = None,
+        callback: Optional[Callable] = None,
+    ):
+        """
+        Semantics:
+            y[i] += (
+                x[i].unsqueeze(0)
+                @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+                @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+                * scale
+            ).squeeze(0)
+        Args:
+            y (torch.Tensor): Output tensor. Will be changed in-place.
+            x (torch.Tensor): Input tensor
+            wa_t_all (torch.Tensor): lora_a's weight
+            wb_t_all (torch.Tensor): lora_b's weight
+            scale (float): Scaling factor.
+            y_offset (Optional[int], optional): Offset to apply to the starting
+                column of y.
+            y_slice_size (Optional[int], optional): Size of the y column slice.
+            buffer (Optional[torch.Tensor], optional): Defaults to None.
+        """
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        x = x.view(-1, x.shape[-1])
+        r = wb_t_all.size(-1)
+        if buffer is None:
+            # We set the buffer to be float32 by default, refer to:
+            # https://github.com/triton-lang/triton/issues/1387
+            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
+
+        self.add_shrink(buffer, x, wa_t_all, scale)
+
+        if callback is not None:
+            # callback used to aggregate intermediate results (e.g., allreduce, allgather)
+            buffer = callback(buffer)
+
+        if y_offset is None and y_slice_size is None:
+            self.add_expand(y, buffer, wb_t_all, add_input=True)
+        else:
+            self.add_expand_slice(y, buffer, wb_t_all, y_offset, y_slice_size, add_input=True)
+        y = y.view_as(y_org)
+
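# Illustrative sketch (not part of the patch): a dense PyTorch reference for the add_lora
# semantics documented above, using small hypothetical shapes. The punica kernels compute
# the same thing batched; this loop is only meant to make the math concrete.
import torch

num_loras, layers, h1, h2, r = 2, 1, 8, 8, 4
x = torch.randn(3, h1)                              # one row per token
y = torch.zeros(3, h2)
wa_t_all = torch.randn(num_loras, layers, r, h1)    # lora_a weights, transposed layout
wb_t_all = torch.randn(num_loras, layers, h2, r)    # lora_b weights, transposed layout
indices = torch.tensor([0, 1, 0])                   # per-token LoRA index
scale, layer_idx = 1.0, 0

for i in range(x.size(0)):
    y[i] += (
        x[i].unsqueeze(0)
        @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        * scale
    ).squeeze(0)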
+ """ + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + # TODO fuse these kernels + for slice_idx in range(len(output_slices)): + self.add_lora( + y, x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], scale, offset_left, output_slices[slice_idx] + ) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + ) -> None: + """ + LogitsProcessorWithLoRA always using bgmv + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) + + bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) + y = y.view_as(y_org) diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py deleted file mode 100644 index 6efb2647f..000000000 --- a/server/lorax_server/utils/sgmv.py +++ /dev/null @@ -1,236 +0,0 @@ -import os -import warnings -from functools import lru_cache -from typing import List, Tuple - -import torch -import torch.nn.functional as F - -try: - import punica_kernels as _kernels - - HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) -except ImportError: - warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") - _kernels = None - HAS_SGMV = False - - -MIN_SGMV_RANK = 8 -MIN_RANK_CUSTOM = 16 -MAX_RANK_CUSTOM = 128 -SGMV_BLOCK_SIZE = 16 -BGMV_MAX_RANK = 128 - - -def has_sgmv() -> bool: - return HAS_SGMV - - -def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: - """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" - if not has_sgmv(): - return t - - # tensor parallelism will result in effective rank being divided by world_size, - # so we need to scale the min rank to offset that effect - min_rank = MIN_SGMV_RANK * world_size - - # if we're at or below the min rank, pad up to the min rank - # otherwise, pad to the nearest multiple of the block size - current_rank = t.size(dim) - target_rank = ( - min_rank - if current_rank <= min_rank - else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE - ) - if current_rank == target_rank: - return t - - pad_size = target_rank - current_rank - - # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - pad = [0, 0] * t.dim() - pad[(t.dim() - dim - 1) * 2 + 1] = pad_size - pad = tuple(pad) - - return F.pad(t, pad, mode="constant", value=0.0) - - -def use_cutlass_shrink(lora_rank: int) -> bool: - return lora_rank < MIN_RANK_CUSTOM - - -def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: - if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: - return t.transpose(0, 1) - return t - - -# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py -def add_lora_sgmv_cutlass( - y: torch.Tensor, - x: torch.Tensor, - wa_ptr: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.Tensor, - s_end: torch.Tensor, - layer_idx: int, - lora_rank: int, -): - """ - Semantics: - y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) - - Args: - y: Shape: `[B, H2]`. Output vectors. 
Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ - Weight matrix shape: `[num_layers, R, H1]`. - wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ - Weight matrix shape: `[num_layers, R, H2]`. - s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. - s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. - layer_idx: Layer index of the weight matrices. - """ - if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: - # Custom SGMV shrink only supports rank 16, 32, 64, 128 - _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank) - return - - tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) - tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) - tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) - - -def _add_lora_sgmv_cutlass_legacy( - y: torch.Tensor, - x: torch.Tensor, - wa_ptr: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, - lora_rank: int, -): - tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) - tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) - - -@lru_cache(maxsize=1) -def get_tmp_tensor(device: torch.device) -> torch.Tensor: - return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) - - -@lru_cache(maxsize=32) -def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: - tmp_size = _kernels.sgmv_cutlass_tmp_size(size) - return torch.empty((tmp_size,), dtype=torch.uint8, device=device) - - -def get_tmp_expand_size(size: int) -> int: - return _kernels.sgmv_cutlass_tmp_size(size) - - -def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: - if use_cutlass_shrink(lora_rank): - tmp = get_tmp_tensor_for_size(nsegments, device) - return tmp, tmp - else: - tmp_shrink = get_tmp_tensor(device) - tmp_expand = get_tmp_tensor_for_size(nsegments, device) - return tmp_shrink, tmp_expand - - -def lora_a_sgmv_cutlass( - x: torch.Tensor, - tmp: torch.Tensor, - wa_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, - lora_rank: int, -) -> torch.Tensor: - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - else: - _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - return v - - -def lora_b_sgmv_cutlass( - y: torch.Tensor, - v: torch.Tensor, - tmp: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, -): - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) - - -""" -Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - -Args: - y: 
Shape: `[B, H2]`. Output vectors. Will be changed in-place. - v: Shape: `[B, R]`. Temporary vector. - x: Shape: `[B, H1]`. Input vectors. - wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. - wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. -""" - - -def add_lora_a_bgmv( - v: torch.Tensor, - x: torch.Tensor, - wa_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, -): - _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) - - -def add_lora_b_bgmv( - y: torch.Tensor, - v: torch.Tensor, - wb_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, -): - _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) - - -def segmented_matmul( - y: torch.Tensor, - x: torch.Tensor, - w: List[torch.Tensor], - b: List[torch.Tensor], - s_start: torch.IntTensor, - s_end: torch.IntTensor, -): - for i in range(len(w)): - if s_end[i] - s_start[i] <= 0: - continue - - xi = x[s_start[i] : s_end[i]] - wi = w[i] - bi = b[i] - y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi) diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index 130cdc6e5..5566208b8 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -6,10 +6,13 @@ WARMUP = False SPECULATIVE_TOKENS = 0 +NGRAM = False +LORAX_PROFILER_DIR = os.environ.get("LORAX_PROFILER_DIR", None) PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", "")) CHUNKED_PREFILL = bool(os.environ.get("CHUNKED_PREFILL", "")) +LORAX_SPECULATION_MAX_BATCH_SIZE = int(os.environ.get("LORAX_SPECULATION_MAX_BATCH_SIZE", 32)) # Always use flashinfer when prefix caching is enabled FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) or PREFIX_CACHING @@ -21,6 +24,9 @@ logger.info(f"Prefix caching = {PREFIX_CACHING}") logger.info(f"Chunked prefill = {CHUNKED_PREFILL}") +if LORAX_PROFILER_DIR: + logger.info(f"Torch profiling enabled, output dir = {LORAX_PROFILER_DIR}") + SUPPORTS_CHUNKING: Optional[bool] = None MAX_PREFILL_TOKENS: Optional[int] = None @@ -50,15 +56,21 @@ def warmup_mode(): set_warmup(False) -def set_speculative_tokens(value: int): +def set_speculative_tokens(value: int, use_ngram: bool): global SPECULATIVE_TOKENS + global NGRAM SPECULATIVE_TOKENS = value + NGRAM = use_ngram def get_speculative_tokens() -> int: return SPECULATIVE_TOKENS +def use_ngram() -> bool: + return NGRAM + + def set_supports_chunking(supports_chunking: bool): global SUPPORTS_CHUNKING SUPPORTS_CHUNKING = supports_chunking diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index f9cf935f9..c3231caea 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -22,6 +22,7 @@ OutlinesLogitsProcessor, static_warper, ) +from lorax_server.utils.state import use_ngram from lorax_server.utils.watermark import WatermarkLogitsProcessor @@ -419,7 +420,7 @@ def __call__( if speculative_scores is not None: # Only use greedy sampling for speculative tokens speculative_ids = Greedy()(speculative_scores) - else: + elif use_ngram(): speculative_ids = ngram_speculate(input_ids, next_ids, accepted_ids, speculate) return next_ids, next_logprobs, accepted_ids, speculative_ids diff --git a/server/lorax_server/utils/torch_utils.py b/server/lorax_server/utils/torch_utils.py index 682402cc1..4e3567720 100644 --- a/server/lorax_server/utils/torch_utils.py +++ 
b/server/lorax_server/utils/torch_utils.py @@ -15,14 +15,16 @@ def is_quantized(quantize): def is_fp8_supported(): - return torch.cuda.is_available() and \ - (torch.cuda.get_device_capability()[0] >= 9) or \ - (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) + return ( + torch.cuda.is_available() + and (torch.cuda.get_device_capability()[0] >= 9) + or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) + ) def is_fp8_kv(quantize): - return quantize and quantize == 'fp8-kv' + return quantize and quantize == "fp8-kv" def is_fp8(quantize): - return quantize and quantize.startswith('fp8') + return quantize and quantize.startswith("fp8") diff --git a/server/tests/adapters/test_medusa.py b/server/tests/adapters/test_medusa.py index bc808d1a9..fe9274cf2 100644 --- a/server/tests/adapters/test_medusa.py +++ b/server/tests/adapters/test_medusa.py @@ -30,6 +30,7 @@ def test_batched_medusa_weights(default_causal_lm: CausalLM): meta = AdapterBatchMetadata( adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64), + adapter_list=[0, 1, 0, 1], adapter_set={0, 1}, adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64), segment_indices=[0, 1, 0, 1], diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index 05f810365..70f3e5985 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -9,7 +9,7 @@ from lorax_server.adapters.types import LORA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights from lorax_server.utils.lora import LM_HEAD -from lorax_server.utils.sgmv import MIN_RANK_CUSTOM +from lorax_server.utils.punica import MIN_RANK_CUSTOM class FakeAdapterWeights(AdapterWeights): @@ -74,6 +74,7 @@ def test_batched_lora_weights(lora_ranks: List[int]): meta = AdapterBatchMetadata( adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64), + adapter_list=[0, 1, 0, 1], adapter_set={0, 1}, adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64), segment_indices=[0, 1, 0, 1], @@ -149,6 +150,7 @@ def test_batched_lora_weights_decode( meta = AdapterBatchMetadata( adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64), + adapter_list=adapter_indices, adapter_set=set(adapter_indices), adapter_segments=torch.tensor(segments, dtype=torch.int64), segment_indices=segment_indices, @@ -193,7 +195,8 @@ def test_batched_lora_weights_no_segments(): meta = AdapterBatchMetadata( adapter_indices=torch.tensor([0, 0, 0, 0], dtype=torch.int64), - adapter_set={0, 1}, + adapter_list=[0], + adapter_set={0}, adapter_segments=torch.tensor([0, 4], dtype=torch.int64), segment_indices=[0], ) diff --git a/server/tests/utils/test_sgmv.py b/server/tests/utils/test_sgmv.py index 0c535f1b1..5b94270a0 100644 --- a/server/tests/utils/test_sgmv.py +++ b/server/tests/utils/test_sgmv.py @@ -3,7 +3,7 @@ import pytest import torch -from lorax_server.utils.sgmv import ( +from lorax_server.utils.punica import ( get_tmp_tensors, has_sgmv, lora_a_sgmv_cutlass, diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index 2da49c049..1e4380c23 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -80,7 +80,12 @@ def test_deterministic_tokens_temperature_zero(default_causal_lm, default_causal attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] adapter_data = AdapterBatchData.from_meta( - batch.adapter_meta, 
default_causal_lm.layer_to_adapter_weights, prefill=True, prefill_head_indices=None + meta=batch.adapter_meta, + weights=default_causal_lm.layer_to_adapter_weights, + layer_to_lora_weights={}, + punica_wrapper=None, + prefill=True, + prefill_head_indices=None, ) logits, _ = default_causal_lm.forward(