Skip to content

Commit

Permalink
Move multinomial hypothesis tests to test_distribution_ops, resolve c…
Browse files Browse the repository at this point in the history
…onflicts with master.
  • Loading branch information
tongxin committed Aug 7, 2024
2 parents ce37522 + 2b74ff3 commit 376344c
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 55 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/python-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ on:
jobs:
container-unit-test:
runs-on: [self-hosted, docker]
timeout-minutes: 50
container:
image: localhost:5000/flag-gems-ci:v1.0
ports:
Expand All @@ -22,6 +21,9 @@ jobs:
- name: checkout-code
uses: actions/checkout@v4

- name: check-gpu-free
run: tests/scripts/gpu_check.sh

- name: unit_test-flag-gems
shell: bash
run: |
Expand Down Expand Up @@ -63,7 +65,6 @@ jobs:
container-model-test:
runs-on: [self-hosted, docker]
timeout-minutes: 5
container:
image: localhost:5000/flag-gems-ci:v1.0
ports:
Expand All @@ -73,6 +74,9 @@ jobs:
- name: checkout-code
uses: actions/checkout@v4

- name: check-gpu-free
run: tests/scripts/gpu_check.sh

- name: examples-flag-gems
run: |
CUDA_VISIBLE_DEVICES=5 pytest -s examples/model_bert_test.py
55 changes: 55 additions & 0 deletions tests/scripts/gpu_check.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/bin/bash

# Configuration parameters
memory_usage_max=30000 # Maximum memory usage limit (MB)
sleep_time=120 # Wait time (seconds), default is 2 minutes

# Get the number of GPUs
gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)

if [ "$gpu_count" -eq 0 ]; then
echo "No GPUs detected. Please ensure you have NVIDIA GPUs installed and properly configured."
exit 1
fi

echo "Detected $gpu_count GPUs."

nvidia-smi

while true; do
# Query GPU memory usage and total memory
memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits 2>/dev/null)
memory_total=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null)

# Check if nvidia-smi command was successful
if [ $? -ne 0 ]; then
echo "Failed to query GPU memory information. Please check if nvidia-smi is working correctly."
exit 1
fi

# Convert query results to arrays
IFS=$'\n' read -d '' -r -a memory_usage_array <<< "$memory_usage"
IFS=$'\n' read -d '' -r -a memory_total_array <<< "$memory_total"

need_wait=false

# Check the available memory for each GPU
for ((i=0; i<$gpu_count; i++)); do
memory_usage_i=${memory_usage_array[$i]}
memory_total_i=${memory_total_array[$i]}
memory_remin_i=$((memory_total_i - memory_usage_i))

if [ $memory_remin_i -lt $memory_usage_max ]; then
need_wait=true
break
fi
done

if [ "$need_wait" = false ]; then
echo "All GPUs have sufficient available memory. Proceeding with execution."
break
fi

echo "GPU memory is insufficient, waiting for $sleep_time seconds before retrying..."
sleep $sleep_time
done
51 changes: 50 additions & 1 deletion tests/test_distribution_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
import pytest
import scipy
import torch

import flag_gems
Expand All @@ -20,7 +22,7 @@ def test_accuracy_normal(shape, dtype):


@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_uniform(shape, dtype):
x = torch.randn(size=shape, dtype=dtype, device="cuda")
with flag_gems.use_gems():
Expand All @@ -36,3 +38,50 @@ def test_accuracy_exponential_(shape, dtype):
with flag_gems.use_gems():
x.exponential_()
assert x.min() > 0


@pytest.mark.parametrize("shape", [(1000,), (100, 1000)])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("n_samples", [2048])
def test_accuracy_multinomial_with_replacement(shape, dtype, n_samples):
dist = torch.zeros(size=shape, dtype=dtype, device="cuda")
Index = [5, 13, 42]
dist[..., Index] = 1
with flag_gems.use_gems():
res_out = torch.multinomial(dist, n_samples, True)
index, tally = torch.unique(res_out, sorted=True, return_counts=True)
assert index.tolist() == Index
# Do a simple Chi-square test
tally = np.array(tally.tolist())
expected = tally
expected[:] = tally.mean()
observed = tally
chi2, pvalue = scipy.stats.chisquare(observed, expected)
assert pvalue > 0.05


@pytest.mark.parametrize("pool", [100])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("n_samples", [10])
def test_accuracy_multinomial_without_replacement(pool, dtype, n_samples):
n_draws = 1000
dist = torch.zeros(size=(pool,), dtype=dtype, device="cuda").broadcast_to(
n_draws, pool
)
indices = torch.randint(0, pool, (50,), device="cuda").unique()
dist[:, indices] = 1
with flag_gems.use_gems():
res_out = torch.multinomial(dist, n_samples, False)
# Verifies uniqueness
for draw in range(n_draws):
assert res_out[draw].unique().size(0) == res_out.size(1)
# Chi-square tests
samples, count = res_out.unique(return_counts=True)
dist = dist[0][samples]
dist = dist / dist.sum()
# The expected number of samples must equal the observed number of samples exactly
observed_samples = n_samples * n_draws
expected_count = torch.round(dist * n_samples * n_draws)
expected_count[0] += observed_samples - expected_count.sum()
chi2, pvalue = scipy.stats.chisquare(count.tolist(), expected_count.tolist())
assert pvalue > 0.05
42 changes: 10 additions & 32 deletions tests/test_special_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Optional

import numpy as np
import pytest
import scipy
import torch

import flag_gems
Expand Down Expand Up @@ -228,39 +226,19 @@ def test_accuracy_multinomial_with_replacement(shape, dtype, n_samples):
dist[..., Index] = 1
with flag_gems.use_gems():
res_out = torch.multinomial(dist, n_samples, True)
index, tally = torch.unique(res_out, sorted=True, return_counts=True)
assert index.tolist() == Index
# Do a simple Chi-square test
tally = np.array(tally.tolist())
expected = tally
expected[:] = tally.mean()
observed = tally
chi2, pvalue = scipy.stats.chisquare(observed, expected)
assert pvalue > 0.05


@pytest.mark.parametrize("pool", [100])
assert torch.all(torch.isin(res_out, torch.tensor(Index, device="cuda")))


@pytest.mark.parametrize("pool", [100, 2048])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("n_samples", [10])
def test_accuracy_multinomial_without_replacement(pool, dtype, n_samples):
n_draws = 1000
dist = torch.zeros(size=(pool,), dtype=dtype, device="cuda").broadcast_to(
def test_accuracy_multinomial_without_replacement(pool, dtype):
n_draws = 10
dist = torch.rand(size=(pool,), dtype=dtype, device="cuda").broadcast_to(
n_draws, pool
)
indices = torch.randint(0, pool, (50,), device="cuda").unique()
dist[:, indices] = 1
n_samples = pool
with flag_gems.use_gems():
res_out = torch.multinomial(dist, n_samples, False)
# Verifies uniqueness
for draw in range(n_draws):
assert res_out[draw].unique().size(0) == res_out.size(1)
# Chi-square tests
samples, count = res_out.unique(return_counts=True)
dist = dist[0][samples]
dist = dist / dist.sum()
# The expected number of samples must equal the observed number of samples exactly
observed_samples = n_samples * n_draws
expected_count = torch.round(dist * n_samples * n_draws)
expected_count[0] += observed_samples - expected_count.sum()
chi2, pvalue = scipy.stats.chisquare(count.tolist(), expected_count.tolist())
assert pvalue > 0.05
sorted_samples, _ = res_out.sort(dim=1)
assert torch.all(sorted_samples == torch.arange(pool, device="cuda"))
41 changes: 23 additions & 18 deletions tests/test_specific_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
"native_dropout": ("test_accuracy_dropout",),
}

op_name_2_unit_test_maps = {
op_name_to_unit_test_maps = {
"test_blas_ops.py": blas_ops_ut_map,
"test_reduction_ops.py": reduction_ops_ut_map,
"test_unary_pointwise_ops.py": unary_pointwise_ops_ut_map,
Expand All @@ -145,34 +145,39 @@
parser.add_argument("--name", type=str, help="test for a specific op")
args = parser.parse_args()

op_nums = 0
op_list = []
for item in op_name_to_unit_test_maps.values():
op_nums = op_nums + len(item)
for op in item.keys():
op_list.append(op)
print(f"Here is the sorted op list with {op_nums} ops:")
op_list = sorted(op_list)
print(op_list)

final_result = 0
if args.all:
op_nums = 0
op_list = []
for item in op_name_2_unit_test_maps.values():
op_nums = op_nums + len(item)
for op in item.keys():
op_list.append(op)
print(f"{op_nums} ops to test, and here is the sorted op list:")
op_list = sorted(op_list)
print(op_list)

for file_name, collection in op_name_2_unit_test_maps.items():
for file_name, collection in op_name_to_unit_test_maps.items():
for op, uts in collection.items():
for ut in uts:
cmd = f"{file_name}::{ut}"
result = pytest.main(["-s", cmd])
print("final_result: ", final_result)
exit(final_result)

if args.name:
exec_flag = False
for file_name, collection in op_name_2_unit_test_maps.items():
if args.name not in op_list:
logging.fatal(f"No op named {args.name} found! Check the name and list!")
exit(1)

for file_name, collection in op_name_to_unit_test_maps.items():
for op, uts in collection.items():
if op == args.name:
print(op)
for ut in uts:
cmd = f"{file_name}::{ut}"
print(cmd)
result = pytest.main(["-s", cmd])
exec_flag = True

if exec_flag is False:
logging.fatal(f"No op named {args.name} found! Check the name and list!")
final_result += result
print("final_result: ", final_result)
exit(final_result)
4 changes: 2 additions & 2 deletions tests/test_tensor_constructor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_rand(shape, dtype):
with flag_gems.use_gems():
res_out = torch.rand(shape, dtype=dtype, device="cuda")
Expand All @@ -21,7 +21,7 @@ def test_accuracy_rand(shape, dtype):


@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_randn(shape, dtype):
with flag_gems.use_gems():
res_out = torch.randn(shape, dtype=dtype, device="cuda")
Expand Down

0 comments on commit 376344c

Please sign in to comment.