Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] rename variance parameter to std (standard deviation) #4118

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions mojo/stdlib/src/random/random.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,17 @@ fn rand[
return


fn randn_float64(mean: Float64 = 0.0, variance: Float64 = 1.0) -> Float64:
"""Returns a random double sampled from a Normal(mean, variance) distribution.
fn randn_float64(mean: Float64 = 0.0, std: Float64 = 1.0) -> Float64:
"""Returns a random double sampled from a Normal(mean, std) distribution.

Args:
mean: Normal distribution mean.
variance: Normal distribution variance.
std: Normal distribution standard deviation.

Returns:
A random float64 sampled from Normal(mean, variance).
A random float64 sampled from Normal(mean, std).
"""
return external_call["KGEN_CompilerRT_NormalDouble", Float64](
mean, variance
)
return external_call["KGEN_CompilerRT_NormalDouble", Float64](mean, std)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uhm, can you really just change that? if the underlying function expects a variance and you give it a standard deviation... you'll get another distribution than what you're expecting. You'd have to square the value for this to make sense.

Copy link
Contributor

@soraros soraros Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the bug was that the underlying function expects a std, and the argument was wrongly named?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, didn't read the linked issue



fn randn[
Expand All @@ -202,9 +200,9 @@ fn randn[
ptr: UnsafePointer[Scalar[type]],
size: Int,
mean: Float64 = 0.0,
variance: Float64 = 1.0,
std: Float64 = 1.0,
):
"""Fills memory with random values from a Normal(mean, variance) distribution.
"""Fills memory with random values from a Normal(mean, std) distribution.

Constraints:
The type should be floating point.
Expand All @@ -216,11 +214,11 @@ fn randn[
ptr: The pointer to the memory area to fill.
size: The number of elements to fill.
mean: Normal distribution mean.
variance: Normal distribution variance.
std: Normal distribution standard deviation.
"""

for i in range(size):
ptr[i] = randn_float64(mean, variance).cast[type]()
ptr[i] = randn_float64(mean, std).cast[type]()
return


Expand Down
47 changes: 45 additions & 2 deletions mojo/stdlib/test/random/test_random.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,51 @@ def test_random():
),
)

var random_normal = randn_float64(0, 1)
# it's quite hard to verify that the values returned are forming a normal distribution

def test_seed_normal():
seed(42)
# verify `randn_float64` values are normally distributed
var num_samples = 1000
var samples = List[Float64](capacity=num_samples)
for _ in range(num_samples):
samples.append(randn_float64(0, 2))

var sum: Float64 = 0.0
for sample in samples:
sum += sample[]

var mean: Float64 = sum / num_samples

var sum_sq: Float64 = 0.0
for sample in samples:
sum_sq += (sample[] - mean) ** 2

var variance = sum_sq / num_samples

# Calculate absolute differences (errors)
var mean_error = abs(mean)
var variance_error = abs(variance - 4)

var mean_tolerance: Float64 = 0.06 # SE_μ = σ / √n
assert_true(
mean_error < mean_tolerance,
String(
"Mean error ",
mean_error,
" is above the accepted tolerance ",
mean_tolerance,
),
)
var variance_tolerance: Float64 = 0.57 # SE_S² = √(2 * σ^4 / (n - 1))
assert_true(
variance_error < variance_tolerance,
String(
"Variance error ",
variance_error,
" is above the accepted tolerance ",
variance_tolerance,
),
)


def test_seed():
Expand Down