Skip to content

Commit

Permalink
allow control of temperature
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed May 8, 2024
1 parent d665934 commit fd8c5ee
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 6 deletions.
10 changes: 10 additions & 0 deletions controllers/aici_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ pub struct Branch<S> {
/// If None, no sampling is performed.
/// If Some(set), only tokens from the set are allowed.
pub sample_mask: Option<S>,
/// Overrides the sampling temperature. Whether the override is "sticky" (keeps applying to later sampling steps) is not guaranteed either way.
pub temperature: Option<f32>,
/// Describes what to do after sampling.
/// If no sampling, there should be exactly one splice, with empty `when_sampled`.
pub splices: Vec<Splice>,
Expand All @@ -127,6 +129,7 @@ impl<S: Clone> Clone for Branch<S> {
/// Produces an independent copy of this branch.
///
/// The sample mask and the splice list are cloned; the temperature
/// override is an `Option<f32>` and is simply copied.
fn clone(&self) -> Self {
    // Destructure so that every field is visibly accounted for.
    let Branch {
        sample_mask,
        temperature,
        splices,
    } = self;
    Branch {
        sample_mask: sample_mask.clone(),
        temperature: *temperature,
        splices: splices.clone(),
    }
}
Expand All @@ -139,13 +142,15 @@ impl<S> Branch<S> {
{
Branch {
sample_mask: self.sample_mask.as_ref().map(f),
temperature: self.temperature,
splices: self.splices.clone(),
}
}

pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
Branch {
sample_mask: None,
temperature: None,
splices: vec![Splice {
when_sampled: vec![],
backtrack,
Expand Down Expand Up @@ -174,9 +179,14 @@ impl MidProcessResult {
}

/// Builds a result that samples from the tokens in `set`, without
/// supplying a temperature override.
///
/// Equivalent to calling `sample_with_temp` with `None` as the temperature.
pub fn sample(set: SimpleVob) -> Self {
    let no_override: Option<f32> = None;
    Self::sample_with_temp(set, no_override)
}

pub fn sample_with_temp(set: SimpleVob, temperature: Option<f32>) -> Self {
MidProcessResult {
branches: vec![Branch {
sample_mask: Some(set),
temperature: temperature,
splices: vec![],
}],
}
Expand Down
7 changes: 7 additions & 0 deletions controllers/guidance_ctrl/src/earley/from_guidance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct NodeProps {
pub commit_point: bool,
pub capture_name: String,
pub max_tokens: i32,
pub temperature: f32,
}

impl NodeProps {
Expand All @@ -28,6 +29,7 @@ impl NodeProps {
commit_point: n.commit_point,
capture_name: n.capture_name.to_string(),
max_tokens: n.max_tokens,
temperature: 0.0,
},
OneOffunction_type::select(n) => NodeProps {
nullable: n.nullable,
Expand All @@ -36,6 +38,7 @@ impl NodeProps {
commit_point: n.commit_point,
capture_name: n.capture_name.to_string(),
max_tokens: n.max_tokens,
temperature: 0.0,
},
OneOffunction_type::byte(n) => NodeProps {
nullable: n.nullable,
Expand All @@ -44,6 +47,7 @@ impl NodeProps {
commit_point: n.commit_point,
capture_name: n.capture_name.to_string(),
max_tokens: i32::MAX,
temperature: n.temperature,
},
OneOffunction_type::byte_range(n) => NodeProps {
nullable: false, // n.nullable,
Expand All @@ -52,6 +56,7 @@ impl NodeProps {
commit_point: n.commit_point,
capture_name: n.capture_name.to_string(),
max_tokens: i32::MAX,
temperature: n.temperature,
},
OneOffunction_type::model_variable(n) => NodeProps {
nullable: n.nullable,
Expand All @@ -60,6 +65,7 @@ impl NodeProps {
commit_point: n.commit_point,
capture_name: n.capture_name.to_string(),
max_tokens: i32::MAX,
temperature: 0.0,
},
OneOffunction_type::None => {
panic!("None function type in guidance::Grammar")
Expand Down Expand Up @@ -87,6 +93,7 @@ impl NodeProps {
} else {
Some(self.capture_name.clone())
},
temperature: self.temperature,
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion controllers/guidance_ctrl/src/earley/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@ impl ModelVariable {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq)]
/// Properties attached to a symbol of the grammar.
pub struct SymbolProps {
/// `Default` sets this to `usize::MAX`.
pub max_tokens: usize,
pub commit_point: bool,
/// `None` by default.
pub capture_name: Option<String>,
pub hidden: bool,
/// `None` by default.
pub model_variable: Option<ModelVariable>,
/// Temperature value carried by this symbol; `Default` sets it to 0.0.
/// Because this is an `f32`, the struct cannot derive `Eq` or `Hash`.
pub temperature: f32,
}

impl Default for SymbolProps {
Expand All @@ -68,6 +69,7 @@ impl Default for SymbolProps {
max_tokens: usize::MAX,
model_variable: None,
capture_name: None,
temperature: 0.0,
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions controllers/guidance_ctrl/src/earley/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,19 @@ impl Parser {
.unwrap_or(usize::MAX)
}

/// Returns the largest `temperature` found among the terminal symbols
/// referenced by the items of the current row.
///
/// The fold starts at `0.0`, so that is the result when the row has no
/// terminal items, and the result is never below `0.0`.
pub fn temperature(&self) -> f32 {
    self.curr_row()
        .item_indices()
        .into_iter()
        .map(|idx| {
            let item = self.scratch.items[idx];
            let sym = self.grammar.sym_idx_at(item.rule_idx());
            self.grammar.sym_data(sym)
        })
        // Only terminal symbols contribute a temperature.
        .filter(|data| data.is_terminal)
        .fold(0.0f32, |highest, data| highest.max(data.props.temperature))
}

pub fn apply_tokens(
&mut self,
trie: &TokTrie,
Expand Down
2 changes: 1 addition & 1 deletion controllers/guidance_ctrl/src/tokenparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,6 @@ impl TokenParser {
return MidProcessResult::stop();
}

return MidProcessResult::sample(set);
return MidProcessResult::sample_with_temp(set, Some(self.parser.temperature()));
}
}
1 change: 1 addition & 0 deletions controllers/jsctrl/src/jsctrl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ mod aici_mod {
let splices: Vec<Object> = b.get2("splices");
Branch {
sample_mask: sample_mask.map(|ts| ts.inner),
temperature: None,
splices: splices
.into_iter()
.map(|s| Splice {
Expand Down
1 change: 1 addition & 0 deletions controllers/pyctrl/src/pyctrl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ impl AiciCtrl for Runner {

Branch {
sample_mask,
temperature: None,
splices,
}
});
Expand Down
2 changes: 2 additions & 0 deletions py/pyaici/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def from_json(obj: dict) -> "Splice":
@dataclass
class Branch:
mask: Optional[int] = None
temperature: Optional[float] = None
splices: List[Splice] = field(default_factory=list)

def find_splice(self, token: Optional[int]) -> Optional[Splice]:
Expand All @@ -296,6 +297,7 @@ def is_splice(self) -> bool:
def from_json(obj: dict) -> "Branch":
    """Build a ``Branch`` from its JSON dictionary form.

    ``sample_mask`` and ``temperature`` are optional keys (missing keys
    become ``None``); ``splices`` is required and each entry is converted
    with ``Splice.from_json``.
    """
    mask = obj.get("sample_mask")
    temperature = obj.get("temperature")
    splices = [Splice.from_json(q) for q in obj["splices"]]
    return Branch(mask=mask, temperature=temperature, splices=splices)

Expand Down
12 changes: 9 additions & 3 deletions rllm/rllm-base/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,13 @@ impl<ME: ModelExec> RllmEngine<ME> {
}

let shm = &self.aicirt.as_mut().unwrap().bin_shm;
let slice = shm.slice_at_byte_offset::<f32>(mid_res.first_mask_byte_offset,
mid_res.mask_num_elts * mid_res.num_masks);
let slice = shm.slice_at_byte_offset::<f32>(
mid_res.first_mask_byte_offset,
mid_res.mask_num_elts * mid_res.num_masks,
);
Ok((
self.tmodel.new_bias(slice, mid_res.num_masks, mid_res.mask_num_elts),
self.tmodel
.new_bias(slice, mid_res.num_masks, mid_res.mask_num_elts),
seq_id_mapping,
))
}
Expand Down Expand Up @@ -511,6 +514,9 @@ impl<ME: ModelExec> RllmEngine<ME> {
Some(b) => {
let seq_idx = b.sample_mask.unwrap();
aici_bias.apply(&mut logits, seq_idx);
if let Some(t) = b.temperature {
sg.logits_processor.temperature = Some(t);
}
}
None => {}
}
Expand Down
2 changes: 1 addition & 1 deletion scripts/test-guidance.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ fi
export AZURE_GUIDANCE_URL

cd $(dirname $0)/../py/guidance
pytest --selected_model azure_guidance tests/models/test_azure_guidance.py
# Forward any extra command-line arguments to pytest. The space before "$@"
# matters: without it the first forwarded argument is glued onto the test
# file path.
pytest --selected_model azure_guidance tests/models/test_azure_guidance.py "$@"

0 comments on commit fd8c5ee

Please sign in to comment.