Add statistical sampling tests #73

Closed
brandonwillard opened this issue Oct 10, 2024 · 5 comments · Fixed by #77
Labels: enhancement (New feature or request)

Comments

@brandonwillard
Member

We need simple tests that confirm some expectations regarding sampling via an FSM.

For example, such a test might use a simple vocabulary over "0", "1", and eos, enumerate all the distinct token sequences of a given length (fixing the probabilities/logits, of course), and, with that, draw samples and assert something about the empirical probabilities.
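
A minimal sketch of what such a test could look like, using only the Python standard library. The token probabilities, the length cap, and the helper names `enumerate_sequences` and `sample_sequence` are illustrative assumptions, not part of this project's API. It enumerates every distinct outcome, computes its exact probability, and compares that against empirical frequencies from a seeded sampler:

```python
import itertools
import math
import random
from collections import Counter

# Hypothetical fixed next-token probabilities (illustrative values only).
PROBS = {"0": 0.5, "1": 0.3, "eos": 0.2}
MAX_LEN = 3  # generation stops at eos or after MAX_LEN tokens


def enumerate_sequences(max_len):
    """Return {sequence: exact probability} for every distinct outcome."""
    out = {}
    for n in range(1, max_len + 1):
        # Sequences that terminate with eos at position n.
        for body in itertools.product("01", repeat=n - 1):
            seq = body + ("eos",)
            out[seq] = math.prod(PROBS[t] for t in seq)
    # Sequences cut off at max_len without ever emitting eos.
    for seq in itertools.product("01", repeat=max_len):
        out[seq] = math.prod(PROBS[t] for t in seq)
    return out


def sample_sequence(rng, max_len):
    """Draw tokens independently until eos or the length cap."""
    tokens, weights = zip(*PROBS.items())
    seq = []
    while len(seq) < max_len:
        tok = rng.choices(tokens, weights=weights)[0]
        seq.append(tok)
        if tok == "eos":
            break
    return tuple(seq)


exact = enumerate_sequences(MAX_LEN)
rng = random.Random(0)  # fixed seed keeps the check reproducible
n_samples = 20_000
counts = Counter(sample_sequence(rng, MAX_LEN) for _ in range(n_samples))

# Largest gap between an empirical frequency and its exact probability.
max_abs_error = max(abs(counts[s] / n_samples - p) for s, p in exact.items())
```

A test would then assert that `max_abs_error` stays below a tolerance chosen from the sample size, and that no sampled sequence falls outside the enumerated set.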

N.B. This could be done after #67 (and possibly #69, #71) to avoid more refactoring.

@brandonwillard added the enhancement label on Oct 10, 2024
@dpsimpson
Contributor

dpsimpson commented Oct 16, 2024

To start with, let's take a simple generative model over the three-symbol alphabet $\{0, 1, \text{eos}\}$ with the following transition table, where columns index the previous token $x_{t-1}$ and rows index the next token $x_t$:

| $x_t$ (row), $x_{t-1}$ (column) | 0 | 1 | eos |
|-----|-----|-----|-----|
| 0   | 0.2 | 0.3 | 0   |
| 1   | 0.5 | 0.4 | 0   |
| eos | 0.3 | 0.3 | 0   |

and $P(x_0 = 0) = 0.2$, $P(x_0 = 1) = 0.8$.

Then we want to do guided generation given the regex "11[01]+|0[01]*", which corresponds to the DFA

[Figure: simple_regex-2, a diagram of the DFA for this regex]

We want to construct some probabilistic tests for sampling from this DFA with these transition probabilities. The idea is to compute some statistical quantities of a generation and compare simulations to the exact quantity.

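Let's look at the length of the generation. Once we are in the penultimate state of the DFA, the number of additional tokens emitted before the transition to the terminal state follows a negative-binomial(1, 0.3) distribution, since eos has probability 0.3 after either symbol in the transition table. That distribution has mean $0.7/0.3 \approx 2.33$. Hence, the expected length of a generated string (counting the tokens emitted before eos) is

$$ 0.2 \cdot 1 + 0.8 \cdot (2 + 1) + \frac{0.7}{0.3} \approx 4.93 $$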

(The 2 + 1 term accounts for the "11" prefix plus the mandatory first match of "[01]+", so a path starting with a 1 is always at least 3 tokens long.)
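
The expected length can also be checked by simulation. The following is a minimal sketch, not this project's sampler: the DFA is written out by hand with hypothetical state names, and it assumes that guided generation masks tokens the DFA disallows and renormalizes the remaining probabilities. Lengths count the tokens emitted before eos, and the reference value is computed directly from the transition table (eos probability 0.3).

```python
import random
import re

PATTERN = r"11[01]+|0[01]*"

# P(next | previous), read column-wise from the transition table.
TRANSITIONS = {
    "0": {"0": 0.2, "1": 0.5, "eos": 0.3},
    "1": {"0": 0.3, "1": 0.4, "eos": 0.3},
}
INITIAL = {"0": 0.2, "1": 0.8}

# Hand-written DFA for PATTERN; the state names are hypothetical.
# Each state maps an allowed token to the next state, and "eos" is
# allowed only in the accepting state.
DFA = {
    "start": {"0": "accept", "1": "seen_1"},
    "seen_1": {"1": "seen_11"},
    "seen_11": {"0": "accept", "1": "accept"},
    "accept": {"0": "accept", "1": "accept", "eos": "done"},
}


def sample_string(rng):
    """Sample one guided generation and return the tokens before eos."""
    state, prev, out = "start", None, []
    while True:
        probs = INITIAL if prev is None else TRANSITIONS[prev]
        # Assumed masking rule: drop disallowed tokens, renormalize the rest
        # (rng.choices normalizes the weights itself).
        allowed = [t for t in DFA[state] if probs.get(t, 0.0) > 0.0]
        weights = [probs[t] for t in allowed]
        tok = rng.choices(allowed, weights=weights)[0]
        if tok == "eos":
            return out
        out.append(tok)
        state, prev = DFA[state][tok], tok


rng = random.Random(0)  # fixed seed keeps the check reproducible
samples = [sample_string(rng) for _ in range(20_000)]
mean_length = sum(len(s) for s in samples) / len(samples)

# Every sample should be a full match of the regex.
all_match = all(re.fullmatch(PATTERN, "".join(s)) for s in samples)

# Exact value from the table: 0.2 * 1 + 0.8 * 3 + (1 - p) / p with p = 0.3.
p_eos = TRANSITIONS["0"]["eos"]
expected_length = 0.2 * 1 + 0.8 * 3 + (1 - p_eos) / p_eos
```

With 20,000 draws the sample mean should sit well within a tenth of `expected_length`. A CI test would use far fewer samples and a correspondingly wider tolerance.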

This example can be expanded in various ways. For instance, we can look at more complex functionals (e.g., the number of 1s generated). We can also look at a more complex regex. Any thoughts about how complex we should get?
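
As an illustration of one such functional, the expected number of 1s can be computed exactly from a small linear system. This is a sketch under the same assumption as before, that disallowed tokens are masked and the remaining probabilities renormalized. The function names are hypothetical.

```python
# P(next | previous), read column-wise from the transition table.
TRANSITIONS = {
    "0": {"0": 0.2, "1": 0.5, "eos": 0.3},
    "1": {"0": 0.3, "1": 0.4, "eos": 0.3},
}
INITIAL = {"0": 0.2, "1": 0.8}


def expected_future_ones():
    """Expected number of 1s still to come in the accepting state,
    keyed by the previous token.

        e0 = p00 * e0 + p01 * (1 + e1)
        e1 = p10 * e0 + p11 * (1 + e1)
    """
    p00, p01 = TRANSITIONS["0"]["0"], TRANSITIONS["0"]["1"]
    p10, p11 = TRANSITIONS["1"]["0"], TRANSITIONS["1"]["1"]
    # Rearranged as a 2x2 system A @ [e0, e1] = b, solved by Cramer's rule.
    a11, a12, b1 = 1 - p00, -p01, p01
    a21, a22, b2 = -p10, 1 - p11, p11
    det = a11 * a22 - a12 * a21
    return {
        "0": (b1 * a22 - a12 * b2) / det,
        "1": (a11 * b2 - a21 * b1) / det,
    }


def expected_ones():
    """Expected number of 1s in a string generated under 11[01]+|0[01]*."""
    e = expected_future_ones()
    # Path "0...": the first token is 0, then the accepting state.
    via_zero = e["0"]
    # Path "11[01]...": two 1s, then a third token drawn from {0, 1}
    # with eos masked out and the remaining probabilities renormalized.
    w0, w1 = TRANSITIONS["1"]["0"], TRANSITIONS["1"]["1"]
    third = (w0 * e["0"] + w1 * (1 + e["1"])) / (w0 + w1)
    via_one = 2 + third
    return INITIAL["0"] * via_zero + INITIAL["1"] * via_one


total_ones = expected_ones()
```

The same recursion extends to other additive functionals by changing the per-token reward from "1 if the token is 1" to something else.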

@brandonwillard
Member Author

That sounds great!

> Any thoughts about how complex we should get?

For now, we need the simplest tests possible for addressing statistical consistency. Our CI runs are already getting pretty long.

@dpsimpson
Contributor

Do you think these should be in a separate set of tests, given that they are stochastic and, therefore, somewhat flakey?

@brandonwillard
Member Author

brandonwillard commented Oct 16, 2024

> Do you think these should be in a separate set of tests, given that they are stochastic and, therefore, somewhat flakey?

These tests can start in their own module, if that's what you're asking; otherwise, we'll fix the seeds to avoid the flakiness for now.

@dpsimpson
Contributor

dpsimpson commented Oct 17, 2024

Leaving this here just in case we need it in the future.

Let $X \sim \text{Bernoulli}(0.8)$ and $Y \sim \text{Neg-Binom}(1, 0.3)$ be independent. Then the length is

$$(1-X) + 3X + Y = 1 + 2X + Y$$

which has mean $1 + 1.6 + 0.7/0.3 \approx 4.93$ and variance

$$4 \cdot 0.2 \cdot 0.8 + 0.7/0.3^2 \approx 8.418$$

The standard deviation of the mean of 100 samples is therefore $\sqrt{8.418/100} \approx 0.29$, i.e. about 0.3.
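
The arithmetic can be reproduced in a few lines of standard-library Python, using the eos probability 0.3 that appears in the mean and variance above. The three-standard-error tolerance at the end is an illustrative choice, not something decided in this thread.

```python
import math

p_one = 0.8  # P(x_0 = 1), so X ~ Bernoulli(p_one)
p_eos = 0.3  # per-step eos probability in the accepting state
n = 100      # number of samples averaged in the test

# Length = 1 + 2X + Y with X and Y independent, where Y counts the
# non-eos tokens drawn before eos (failures before the first success).
mean = 1 + 2 * p_one + (1 - p_eos) / p_eos
var = 4 * p_one * (1 - p_one) + (1 - p_eos) / p_eos**2
std_err = math.sqrt(var / n)

# Illustrative tolerance: accept a sample mean within three standard errors.
tolerance = 3 * std_err
```

With these numbers the tolerance comes out a little under 0.9, which gives a sense of how loose an assertion on 100 samples has to be.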

But I'm freezing the seed.
