Add Mirostat sampling

turboderp committed Nov 12, 2023
1 parent 7a512fd commit 8e29e00
Showing 5 changed files with 150 additions and 30 deletions.
16 changes: 11 additions & 5 deletions examples/multiple_caches.py
@@ -34,9 +34,12 @@
 
 # Create some sampling settings
 
-settings = ExLlamaV2Sampler.Settings()
-settings.temperature = 0.8
-settings.top_p = 0.75
+settings_proto = ExLlamaV2Sampler.Settings()
+settings_proto.temperature = 0.8
+settings_proto.top_p = 0.75
+# settings_proto.mirostat = True
+# settings_proto.mirostat_tau = 5
+# settings_proto.top_k = 1000
 
 # Define some prompts to inference in parallel
@@ -50,10 +53,11 @@
 
 max_parallel_seqs = 3
 
-# Active sequences and corresponding caches
+# Active sequences and corresponding caches and settings
 
 input_ids = []
 caches = []
+settings = []
 
 # Stats
@@ -80,6 +84,7 @@
         model.forward(ids[:, :-1], cache, preprocess_only = True)
         input_ids.append(ids)
         caches.append(cache)
+        settings.append(settings_proto.clone())  # Need individual settings per prompt to support Mirostat
 
         total_prompt_tokens += ids.shape[-1] -1
         prompt_time += time.time() - time_begin
@@ -97,7 +102,7 @@
         r = random.random()
         for i in range(len(input_ids)):
 
-            token, _, _ = ExLlamaV2Sampler.sample(logits[i:i+1, :, :], settings, input_ids[i], r, tokenizer)
+            token, _, _ = ExLlamaV2Sampler.sample(logits[i:i+1, :, :], settings[i], input_ids[i], r, tokenizer)
             input_ids[i] = torch.cat([input_ids[i], token], dim = 1)
             total_gen_tokens += 1
@@ -116,6 +121,7 @@
 
             input_ids.pop(i)
             caches.pop(i)
+            settings.pop(i)
 
 # Stats
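Mirostat makes the sampler stateful: each Settings object carries a running mu value per sequence, which is why the example now clones a prototype instead of sharing one Settings instance across sequences. A minimal sketch of that intent (illustrative only, not part of the diff; it assumes exllamav2 is installed and that Settings.clone() behaves as added in this commit):

from exllamav2.generator import ExLlamaV2Sampler

settings_proto = ExLlamaV2Sampler.Settings()
settings_proto.mirostat = True
settings_proto.mirostat_tau = 5

# One independent copy per active sequence: each clone keeps its own
# mirostat_mu, so feedback from one sequence never shifts another's mu.
per_seq_settings = [settings_proto.clone() for _ in range(3)]
assert all(s.mirostat_mu is None for s in per_seq_settings)  # mu is initialized lazily on first sample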
56 changes: 56 additions & 0 deletions exllamav2/exllamav2_ext/cpp/sampling.cpp
@@ -482,6 +482,62 @@ int tfs_cpu
     return k;
 }
 
+int mirostat_pre_cpu
+(
+    const int num_candidates,
+    float* temp_probs,
+    int* temp_indices,
+    float mirostat_mu,
+    float mirostat_tau,
+    float mirostat_eta
+)
+{
+    //TIME_START;
+
+    // If mu not yet initialized, initialize here
+
+    float mu = mirostat_mu;
+    if (mu == 0.0f) mu = mirostat_tau * 2.0f;
+
+    // Discard tokens with surprise greater than mu
+
+    int nc = sort_descending(num_candidates, temp_probs, temp_indices, num_candidates);
+
+    float target_prob = powf(2, -mu);
+    int k = 1;
+    for (; k < nc; k++)
+    {
+        if (-log2(temp_probs[k]) > mu) break;
+    }
+
+    //TIME_STOP;
+
+    return k;
+}
+
+float mirostat_post_cpu
+(
+    const int num_candidates,
+    float* temp_probs,
+    int* temp_indices,
+    float mirostat_mu,
+    float mirostat_tau,
+    float mirostat_eta
+)
+{
+    // If mu not yet initialized, initialize here
+
+    float mu = mirostat_mu;
+    if (mu == 0.0f) mu = mirostat_tau * 2.0f;
+
+    // Adjust mu based on probability of final choice
+
+    float observed_surprise = -log2(temp_probs[0]);
+    mu += mirostat_eta * (mirostat_tau - observed_surprise);
+
+    return mu;
+}
+
 int typical_cpu
 (
     const int num_candidates,
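For reference, the rule these two helpers implement: a candidate token with probability p has surprise s = -log2(p); mirostat_pre_cpu discards candidates whose surprise exceeds the running mu (lazily initialized to 2 * tau), and mirostat_post_cpu then nudges mu toward the target surprise tau via mu += eta * (tau - s_chosen). A pure-Python restatement of the same arithmetic (a sketch, not the extension's API; it assumes probs is already sorted in descending order, as after sort_descending above):

import math

def mirostat_pre(probs, mu, tau):
    # Lazy init, mirroring the C++: mu == 0 means "not yet initialized"
    if mu == 0.0: mu = 2.0 * tau
    k = 1  # always keep at least the top token
    while k < len(probs) and -math.log2(probs[k]) <= mu:
        k += 1
    return k  # keep the k lowest-surprise tokens

def mirostat_post(chosen_prob, mu, tau, eta):
    if mu == 0.0: mu = 2.0 * tau
    observed_surprise = -math.log2(chosen_prob)
    return mu + eta * (tau - observed_surprise)  # move mu toward the target tau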
20 changes: 20 additions & 0 deletions exllamav2/exllamav2_ext/cpp/sampling.h
@@ -87,6 +87,26 @@ int typical_cpu
     float typical
 );
 
+int mirostat_pre_cpu
+(
+    const int num_candidates,
+    float* temp_probs,
+    int* temp_indices,
+    float mirostat_mu,
+    float mirostat_tau,
+    float mirostat_eta
+);
+
+float mirostat_post_cpu
+(
+    const int num_candidates,
+    float* temp_probs,
+    int* temp_indices,
+    float mirostat_mu,
+    float mirostat_tau,
+    float mirostat_eta
+);
+
 int multinomial_cpu
 (
     const int num_candidates,
21 changes: 19 additions & 2 deletions exllamav2/exllamav2_ext/ext.cpp
@@ -742,7 +742,7 @@ void apply_rep_penalty
     }
 }
 
-void sample_basic
+std::vector<float> sample_basic
 (
     torch::Tensor logits,           // shape [bsz, vocab_size]
     float temperature,
@@ -754,7 +754,11 @@ void sample_basic
     float random,
    torch::Tensor output_tokens,    // shape [bsz, 1]
    torch::Tensor output_probs,     // shape [bsz, 1]
-    torch::Tensor logit_filter      // shape [bsz, vocab_size]
+    torch::Tensor logit_filter,     // shape [bsz, vocab_size]
+    bool mirostat,
+    std::vector<float>& mirostat_mu,
+    float mirostat_tau,
+    float mirostat_eta
 )
 {
     TORCH_CHECK_DTYPE(logits, kFloat);
@@ -830,13 +834,26 @@
             normalize_cpu(num_candidates, temp_probs);
         }
 
+        if (mirostat)
+        {
+            num_candidates = mirostat_pre_cpu(num_candidates, temp_probs, temp_indices, mirostat_mu[i], mirostat_tau, mirostat_eta);
+            normalize_cpu(num_candidates, temp_probs);
+        }
+
         num_candidates = multinomial_cpu(num_candidates, temp_probs, temp_indices, random);
         output_tokens[i] = temp_indices[0];
         output_probs[i] = temp_probs[0];
 
+        if (mirostat)
+        {
+            mirostat_mu[i] = mirostat_post_cpu(num_candidates, temp_probs, temp_indices, mirostat_mu[i], mirostat_tau, mirostat_eta);
+        }
     }
 
     free(temp_probs);
     free(temp_indices);
 
+    return mirostat_mu;
 }


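The ordering in sample_basic matters: the surprise filter runs before the multinomial draw, the surviving probabilities are renormalized, and only the sampled token's probability feeds the mu update (after multinomial_cpu, index 0 holds the chosen token). A self-contained sketch of one batch row under those rules, with the helper logic inlined rather than the extension's actual calls:

import math, random

def sample_row_mirostat(probs, mu, tau = 1.5, eta = 0.1, rng = random.random):
    # probs: candidate probabilities, sorted in descending order
    if mu == 0.0: mu = 2.0 * tau                      # lazy init, as in the C++
    k = 1                                             # keep tokens with surprise <= mu
    while k < len(probs) and -math.log2(probs[k]) <= mu:
        k += 1
    kept = probs[:k]
    total = sum(kept)
    kept = [p / total for p in kept]                  # renormalize the survivors
    r, acc, chosen = rng(), 0.0, 0                    # multinomial draw
    for chosen, p in enumerate(kept):
        acc += p
        if acc > r: break
    mu += eta * (tau - (-math.log2(kept[chosen])))    # update mu from the chosen token
    return chosen, kept[chosen], mu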
67 changes: 44 additions & 23 deletions exllamav2/generator/sampler.py
@@ -18,6 +18,11 @@ class Settings:
         tfs = 0
         typical = 0
 
+        mirostat = False
+        mirostat_tau = 1.5
+        mirostat_eta = 0.1
+        mirostat_mu = None  # (re)initialized from mirostat_tau on first sample
+
         token_bias = None
 
         filters = []
@@ -26,14 +31,26 @@
         def clone(self):
 
             c = ExLlamaV2Sampler.Settings()
-            c.temperature = self.temperature
-            c.top_k = self.top_k
-            c.top_p = self.top_p
 
             c.token_repetition_penalty = self.token_repetition_penalty
             c.token_repetition_range = self.token_repetition_range
             c.token_repetition_decay = self.token_repetition_decay
 
+            c.temperature = self.temperature
+            c.top_k = self.top_k
+            c.top_p = self.top_p
+            c.min_p = self.min_p
+            c.tfs = self.tfs
+            c.typical = self.typical
+
+            c.mirostat = self.mirostat
+            c.mirostat_tau = self.mirostat_tau
+            c.mirostat_eta = self.mirostat_eta
+            c.mirostat_mu = None if self.mirostat_mu is None else self.mirostat_mu.copy()
+
             c.token_bias = self.token_bias
             c.filters = [f.clone() for f in self.filters]
 
             return c
@@ -126,36 +143,40 @@ def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor,
         # if logit_filter[0, i].item():
         #     print(i)
 
+        # Begin Mirostat
+
+        if settings.mirostat:
+            if settings.mirostat_mu is None:
+                settings.mirostat_mu = [0.0] * batch_size
+
         # Sampling
 
         batch_size = logits.shape[0]
 
         output_tokens = torch.empty((batch_size, 1), device = "cpu", dtype = torch.long)
         output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)
-        ext_c.sample_basic(logits,
-                           settings.temperature,
-                           settings.top_k,
-                           settings.top_p,
-                           settings.min_p,
-                           settings.tfs,
-                           settings.typical,
-                           random,
-                           output_tokens,
-                           output_probs,
-                           logit_filter)
+
+        m = ext_c.sample_basic(logits,
+                               settings.temperature,
+                               settings.top_k,
+                               settings.top_p,
+                               settings.min_p,
+                               settings.tfs,
+                               settings.typical,
+                               random,
+                               output_tokens,
+                               output_probs,
+                               logit_filter,
+                               settings.mirostat,
+                               settings.mirostat_mu if settings.mirostat else [],
+                               settings.mirostat_tau,
+                               settings.mirostat_eta)
+
+        if settings.mirostat: settings.mirostat_mu = m
 
         # Stop condition from filters
 
         end_filter = False
         if len(settings.filters) > 0 and output_tokens[0].item() in end_tokens: end_filter = True
 
         return output_tokens, output_probs, end_filter
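To round out the picture, a hedged usage sketch of the new Python surface (not part of the diff; model setup is elided, and logits, sequence_ids, and tokenizer are assumed to come from a forward pass as in examples/multiple_caches.py):

import random
from exllamav2.generator import ExLlamaV2Sampler

settings = ExLlamaV2Sampler.Settings()
settings.mirostat = True
settings.mirostat_tau = 5.0   # target surprise; higher values allow more diverse output
settings.mirostat_eta = 0.1   # learning rate for the mu update
settings.top_k = 1000         # leave plenty of candidates for the surprise filter

# Inside the generation loop, exactly as in the updated example:
token, prob, end_filter = ExLlamaV2Sampler.sample(logits, settings, sequence_ids, random.random(), tokenizer)

# settings.mirostat_mu now holds one mu per batch row; passing the same
# Settings object into the next call carries the feedback loop forward.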