diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index e4ccb2fd..659feac3 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -118,6 +118,8 @@ pub struct Branch { /// If None, no sampling is performed. /// If Some(set), only tokens from the set are allowed. pub sample_mask: Option, + /// Override temperature for sampling. It may or may not be sticky. + pub temperature: Option, /// Describes what to do after sampling. /// If no sampling, there should be exactly one splice, with empty `when_sampled`. pub splices: Vec, @@ -127,6 +129,7 @@ impl Clone for Branch { fn clone(&self) -> Self { Branch { sample_mask: self.sample_mask.clone(), + temperature: self.temperature, splices: self.splices.clone(), } } @@ -139,6 +142,7 @@ impl Branch { { Branch { sample_mask: self.sample_mask.as_ref().map(f), + temperature: self.temperature, splices: self.splices.clone(), } } @@ -146,6 +150,7 @@ impl Branch { pub fn splice(backtrack: u32, ff_tokens: Vec) -> Self { Branch { sample_mask: None, + temperature: None, splices: vec![Splice { when_sampled: vec![], backtrack, @@ -174,9 +179,14 @@ impl MidProcessResult { } pub fn sample(set: SimpleVob) -> Self { + Self::sample_with_temp(set, None) + } + + pub fn sample_with_temp(set: SimpleVob, temperature: Option) -> Self { MidProcessResult { branches: vec![Branch { sample_mask: Some(set), + temperature: temperature, splices: vec![], }], } diff --git a/controllers/guidance_ctrl/src/earley/from_guidance.rs b/controllers/guidance_ctrl/src/earley/from_guidance.rs index 98c4d92e..86084b25 100644 --- a/controllers/guidance_ctrl/src/earley/from_guidance.rs +++ b/controllers/guidance_ctrl/src/earley/from_guidance.rs @@ -16,6 +16,7 @@ pub struct NodeProps { pub commit_point: bool, pub capture_name: String, pub max_tokens: i32, + pub temperature: f32, } impl NodeProps { @@ -28,6 +29,7 @@ impl NodeProps { commit_point: n.commit_point, capture_name: n.capture_name.to_string(), max_tokens: 
n.max_tokens, + temperature: 0.0, }, OneOffunction_type::select(n) => NodeProps { nullable: n.nullable, @@ -36,6 +38,7 @@ impl NodeProps { commit_point: n.commit_point, capture_name: n.capture_name.to_string(), max_tokens: n.max_tokens, + temperature: 0.0, }, OneOffunction_type::byte(n) => NodeProps { nullable: n.nullable, @@ -44,6 +47,7 @@ impl NodeProps { commit_point: n.commit_point, capture_name: n.capture_name.to_string(), max_tokens: i32::MAX, + temperature: n.temperature, }, OneOffunction_type::byte_range(n) => NodeProps { nullable: false, // n.nullable, @@ -52,6 +56,7 @@ impl NodeProps { commit_point: n.commit_point, capture_name: n.capture_name.to_string(), max_tokens: i32::MAX, + temperature: n.temperature, }, OneOffunction_type::model_variable(n) => NodeProps { nullable: n.nullable, @@ -60,6 +65,7 @@ impl NodeProps { commit_point: n.commit_point, capture_name: n.capture_name.to_string(), max_tokens: i32::MAX, + temperature: 0.0, }, OneOffunction_type::None => { panic!("None function type in guidance::Grammar") @@ -87,6 +93,7 @@ impl NodeProps { } else { Some(self.capture_name.clone()) }, + temperature: self.temperature, } } } diff --git a/controllers/guidance_ctrl/src/earley/grammar.rs b/controllers/guidance_ctrl/src/earley/grammar.rs index 10b33505..d070f94e 100644 --- a/controllers/guidance_ctrl/src/earley/grammar.rs +++ b/controllers/guidance_ctrl/src/earley/grammar.rs @@ -51,13 +51,14 @@ impl ModelVariable { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq)] pub struct SymbolProps { pub max_tokens: usize, pub commit_point: bool, pub capture_name: Option, pub hidden: bool, pub model_variable: Option, + pub temperature: f32, } impl Default for SymbolProps { @@ -68,6 +69,7 @@ impl Default for SymbolProps { max_tokens: usize::MAX, model_variable: None, capture_name: None, + temperature: 0.0, } } } diff --git a/controllers/guidance_ctrl/src/earley/parser.rs b/controllers/guidance_ctrl/src/earley/parser.rs index 
92dfe392..8397ba84 100644 --- a/controllers/guidance_ctrl/src/earley/parser.rs +++ b/controllers/guidance_ctrl/src/earley/parser.rs @@ -335,6 +335,19 @@ impl Parser { .unwrap_or(usize::MAX) } + pub fn temperature(&self) -> f32 { + let mut temp = 0.0f32; + for i in self.curr_row().item_indices() { + let item = self.scratch.items[i]; + let sym = self.grammar.sym_idx_at(item.rule_idx()); + let data = self.grammar.sym_data(sym); + if data.is_terminal { + temp = temp.max(data.props.temperature); + } + } + temp + } + pub fn apply_tokens( &mut self, trie: &TokTrie, diff --git a/controllers/guidance_ctrl/src/tokenparser.rs b/controllers/guidance_ctrl/src/tokenparser.rs index e3a5361a..0615d905 100644 --- a/controllers/guidance_ctrl/src/tokenparser.rs +++ b/controllers/guidance_ctrl/src/tokenparser.rs @@ -190,6 +190,6 @@ impl TokenParser { return MidProcessResult::stop(); } - return MidProcessResult::sample(set); + return MidProcessResult::sample_with_temp(set, Some(self.parser.temperature())); } } diff --git a/controllers/jsctrl/src/jsctrl.rs b/controllers/jsctrl/src/jsctrl.rs index 9ede4ce1..67e7a635 100644 --- a/controllers/jsctrl/src/jsctrl.rs +++ b/controllers/jsctrl/src/jsctrl.rs @@ -334,6 +334,7 @@ mod aici_mod { let splices: Vec = b.get2("splices"); Branch { sample_mask: sample_mask.map(|ts| ts.inner), + temperature: None, splices: splices .into_iter() .map(|s| Splice { diff --git a/controllers/pyctrl/src/pyctrl.rs b/controllers/pyctrl/src/pyctrl.rs index 3f2284d8..c91342fe 100644 --- a/controllers/pyctrl/src/pyctrl.rs +++ b/controllers/pyctrl/src/pyctrl.rs @@ -651,6 +651,7 @@ impl AiciCtrl for Runner { Branch { sample_mask, + temperature: None, splices, } }); diff --git a/py/pyaici/comms.py b/py/pyaici/comms.py index ae58ac98..9dfbce02 100644 --- a/py/pyaici/comms.py +++ b/py/pyaici/comms.py @@ -279,6 +279,7 @@ def from_json(obj: dict) -> "Splice": @dataclass class Branch: mask: Optional[int] = None + temperature: Optional[float] = None splices: List[Splice] = 
field(default_factory=list) def find_splice(self, token: Optional[int]) -> Optional[Splice]: @@ -296,6 +297,7 @@ def is_splice(self) -> bool: def from_json(obj: dict) -> "Branch": return Branch( mask=obj.get("sample_mask"), + temperature=obj.get("temperature"), splices=[Splice.from_json(q) for q in obj["splices"]], ) diff --git a/rllm/rllm-base/src/engine.rs b/rllm/rllm-base/src/engine.rs index 029f7baa..d32f8506 100644 --- a/rllm/rllm-base/src/engine.rs +++ b/rllm/rllm-base/src/engine.rs @@ -398,10 +398,13 @@ impl RllmEngine { } let shm = &self.aicirt.as_mut().unwrap().bin_shm; - let slice = shm.slice_at_byte_offset::(mid_res.first_mask_byte_offset, - mid_res.mask_num_elts * mid_res.num_masks); + let slice = shm.slice_at_byte_offset::( + mid_res.first_mask_byte_offset, + mid_res.mask_num_elts * mid_res.num_masks, + ); Ok(( - self.tmodel.new_bias(slice, mid_res.num_masks, mid_res.mask_num_elts), + self.tmodel + .new_bias(slice, mid_res.num_masks, mid_res.mask_num_elts), seq_id_mapping, )) } @@ -511,6 +514,9 @@ impl RllmEngine { Some(b) => { let seq_idx = b.sample_mask.unwrap(); aici_bias.apply(&mut logits, seq_idx); + if let Some(t) = b.temperature { + sg.logits_processor.temperature = Some(t); + } } None => {} } diff --git a/scripts/test-guidance.sh b/scripts/test-guidance.sh index e206b514..bc0436e3 100755 --- a/scripts/test-guidance.sh +++ b/scripts/test-guidance.sh @@ -9,4 +9,4 @@ fi export AZURE_GUIDANCE_URL cd $(dirname $0)/../py/guidance -pytest --selected_model azure_guidance tests/models/test_azure_guidance.py +pytest --selected_model azure_guidance tests/models/test_azure_guidance.py "$@"