Skip to content

Commit

Permalink
add aicirt --cap-fork; drop inference_cap req
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Apr 12, 2024
1 parent 7b8aa6d commit 08f21d1
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 38 deletions.
2 changes: 1 addition & 1 deletion aicirt/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde_json::Value;

pub type ModuleInstId = usize;

#[derive(Serialize, Deserialize, Default, Clone)]
#[derive(Serialize, Deserialize, Clone)]
pub struct InferenceCapabilities {
#[serde(default)]
pub backtrack: bool,
Expand Down
26 changes: 10 additions & 16 deletions aicirt/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ struct Cli {
#[arg(long)]
bench: bool,

/// Fork test
/// Allow fork() in controllers.
#[arg(long)]
fork: bool,
cap_fork: bool,

/// Enable futex comms
#[arg(long, default_value_t = false)]
Expand Down Expand Up @@ -660,11 +660,6 @@ impl Stepper {
Ok(())
}

fn inference_caps(&mut self, caps: InferenceCapabilities) -> Result<Value> {
self.globals.inference_caps = caps;
Ok(json!({}))
}

fn aici_mid_process(&mut self, req: AiciMidProcessReq) -> Result<AiciMidProcessResp> {
let block_elts = self.globals.tokrx_info.vocab_size as usize;
let mut outputs = HashMap::default();
Expand Down Expand Up @@ -822,14 +817,7 @@ impl Exec for Stepper {
#[inline(never)]
fn exec(&mut self, json: Value, _auth: AuthInfo) -> Result<Value> {
match json["op"].as_str() {
Some("inference_caps") => Ok(self.inference_caps(serde_json::from_value(json)?)?),
Some("tokens") => {
let caps = &self.globals.inference_caps;
if !caps.backtrack || !caps.ff_tokens {
bail_user!("need at least backtrack and ff_tokens inference_caps")
}
Ok(json!({ "vocab_size": self.globals.tokrx_info.vocab_size }))
}
Some("tokens") => Ok(json!({ "vocab_size": self.globals.tokrx_info.vocab_size })),
Some("mid_process") => Ok(serde_json::to_value(
&self.aici_mid_process(serde_json::from_value(json)?)?,
)?),
Expand Down Expand Up @@ -1119,9 +1107,15 @@ fn main() -> () {
return ();
}

let inference_caps = InferenceCapabilities {
fork: cli.cap_fork,
backtrack: true,
ff_tokens: true,
};

let tokenizer = find_tokenizer(&cli.tokenizer).unwrap();
let token_bytes = tokenizer.token_bytes();
let wasm_ctx = WasmContext::new(limits.clone(), tokenizer).unwrap();
let wasm_ctx = WasmContext::new(inference_caps, limits.clone(), tokenizer).unwrap();

if cli.save_tokenizer.is_some() {
save_tokenizer(&cli);
Expand Down
15 changes: 12 additions & 3 deletions aicirt/src/moduleinstance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ use crate::{
TimerSet, UserError,
};
use aici_abi::{toktree::TokTrie, InitPromptArg, ProcessResultOffset, TokenId};
use aicirt::{api::{InferenceCapabilities, SequenceResult}, bail_user, bintokens::ByteTokenizer, user_error};
use aicirt::{
api::{InferenceCapabilities, SequenceResult},
bail_user,
bintokens::ByteTokenizer,
user_error,
};
use anyhow::{anyhow, ensure, Result};
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc, time::Instant};
Expand All @@ -26,7 +31,11 @@ impl WasmContext {
unsafe { wasmtime::Module::deserialize_file(&self.engine, path) }
}

pub fn new(limits: AiciLimits, tokenizer: ByteTokenizer) -> Result<Self> {
pub fn new(
inference_caps: InferenceCapabilities,
limits: AiciLimits,
tokenizer: ByteTokenizer,
) -> Result<Self> {
let mut cfg = wasmtime::Config::default();
// these are defaults as of 13.0.0, but we specify them anyways for stability
cfg.debug_info(false)
Expand Down Expand Up @@ -75,7 +84,7 @@ impl WasmContext {
tokrx_info: tokenizer.tokrx_info(),
trie_bytes: Arc::new(bytes),
hf_tokenizer: Arc::new(tokenizer.hf_tokenizer),
inference_caps: InferenceCapabilities::default(),
inference_caps,
};

Ok(Self {
Expand Down
10 changes: 4 additions & 6 deletions py/pyaici/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,10 @@ def __init__(
"--bin-size=" + str(bin_size),
"--name=" + pref,
"--server",
] + rtargs
]
if fork_supported:
args.append("--cap-fork")
args += rtargs

print("running: ", args)
self.proc = subprocess.Popen(args)
Expand All @@ -378,11 +381,6 @@ def cleanup():
atexit.register(cleanup)

self.cmd.exec("ping")
self.cmd.exec("inference_caps", {
"backtrack": True,
"ff_tokens": True,
"fork": fork_supported,
})
resp = self.cmd.exec("tokens")
self.vocab_size = resp["data"]["vocab_size"]

Expand Down
16 changes: 4 additions & 12 deletions rllm/rllm-base/src/iface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use aici_abi::{
};
use aicirt::{
api::{
AiciMidProcessReq, AiciMidProcessResp, AuthInfo, GetTagsResp, InferenceCapabilities,
InstantiateReq, MkModuleReq, MkModuleResp, SequenceResult, SetTagsReq, TokensResp,
AiciMidProcessReq, AiciMidProcessResp, AuthInfo, GetTagsResp, InstantiateReq, MkModuleReq,
MkModuleResp, SequenceResult, SetTagsReq, TokensResp,
},
futexshm::ClientChannel,
msgchannel::MessageChannel,
Expand Down Expand Up @@ -154,7 +154,8 @@ impl AiciRtIface {
.arg(&args.bin_size.to_string())
.arg("--name")
.arg(&args.shm_prefix)
.arg("--futex");
.arg("--futex")
.arg("--cap-fork");
for a in &args.add_args {
cmd_bld.arg(a);
}
Expand Down Expand Up @@ -205,15 +206,6 @@ impl AiciRtIface {

let _: Value = r.cmd.exec("ping", json!({}))?;

let _: Value = r.cmd.exec(
"inference_caps",
serde_json::to_value(InferenceCapabilities {
backtrack: true,
ff_tokens: true,
fork: true,
})?,
)?;

let tokens: TokensResp = r
.cmd
.exec("tokens", json!({}))
Expand Down

0 comments on commit 08f21d1

Please sign in to comment.