Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove torch from the weights calls #2543

Merged
merged 4 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions bittensor/core/extrinsics/async_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from bittensor.core.settings import version_as_int
from bittensor.utils import format_error_message
from bittensor.utils.btlogging import logging
from bittensor.utils.registration import torch, use_torch

if TYPE_CHECKING:
from bittensor_wallet import Wallet
from bittensor.core.async_subtensor import AsyncSubtensor
from bittensor.utils.registration import torch


async def _do_set_weights(
Expand Down Expand Up @@ -106,16 +106,10 @@ async def set_weights_extrinsic(
success (bool): Flag is ``true`` if extrinsic was finalized or included in the block. If we did not wait for finalization / inclusion, the response is ``true``.
"""
# First convert types.
if use_torch():
if isinstance(uids, list):
uids = torch.tensor(uids, dtype=torch.int64)
if isinstance(weights, list):
weights = torch.tensor(weights, dtype=torch.float32)
else:
if isinstance(uids, list):
uids = np.array(uids, dtype=np.int64)
if isinstance(weights, list):
weights = np.array(weights, dtype=np.float32)
if isinstance(uids, list):
uids = np.array(uids, dtype=np.int64)
if isinstance(weights, list):
weights = np.array(weights, dtype=np.float32)

# Reformat and normalize.
weight_uids, weight_vals = weight_utils.convert_weights_and_uids_for_emit(
Expand Down
16 changes: 5 additions & 11 deletions bittensor/core/extrinsics/commit_reveal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from bittensor.utils import format_error_message
from bittensor.utils.btlogging import logging
from bittensor.utils.networking import ensure_connected
from bittensor.utils.registration import torch, use_torch
from bittensor.utils.weight_utils import convert_weights_and_uids_for_emit

if TYPE_CHECKING:
from bittensor_wallet import Wallet
from bittensor.core.subtensor import Subtensor
from bittensor.utils.registration import torch


@ensure_connected
Expand Down Expand Up @@ -105,16 +105,10 @@ def commit_reveal_v3_extrinsic(
"""
try:
# Convert uids and weights
if use_torch():
if isinstance(uids, list):
uids = torch.tensor(uids, dtype=torch.int64)
if isinstance(weights, list):
weights = torch.tensor(weights, dtype=torch.float32)
else:
if isinstance(uids, list):
uids = np.array(uids, dtype=np.int64)
if isinstance(weights, list):
weights = np.array(weights, dtype=np.float32)
if isinstance(uids, list):
uids = np.array(uids, dtype=np.int64)
if isinstance(weights, list):
weights = np.array(weights, dtype=np.float32)

# Reformat and normalize.
uids, weights = convert_weights_and_uids_for_emit(uids, weights)
Expand Down
47 changes: 0 additions & 47 deletions tests/unit_tests/extrinsics/test_async_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,53 +278,6 @@ async def test_set_weights_extrinsic_exception(subtensor, mocker):
assert message == "Unexpected error"


@pytest.mark.asyncio
async def test_set_weights_extrinsic_if_use_torch(subtensor, mocker):
"""Tests set_weights_extrinsic when use_torch is True."""
# Preps
fake_wallet = mocker.Mock(autospec=Wallet)
fake_netuid = 1
fake_uids = [1, 2, 3]
fake_weights = [0.1, 0.2, 0.7]

mocked_use_torch = mocker.patch.object(
async_weights, "use_torch", return_value=True
)
mocked_torch_tensor = mocker.patch.object(
async_weights.torch, "tensor", return_value=mocker.Mock()
)

mocked_do_set_weights = mocker.patch.object(
async_weights, "_do_set_weights", return_value=(False, "Test error message")
)
mocked_convert_weights_and_uids_for_emit = mocker.patch.object(
async_weights.weight_utils,
"convert_weights_and_uids_for_emit",
return_value=(mocker.Mock(), mocker.Mock()),
)

# Call
result, message = await async_weights.set_weights_extrinsic(
subtensor=subtensor,
wallet=fake_wallet,
netuid=fake_netuid,
uids=fake_uids,
weights=fake_weights,
wait_for_inclusion=True,
wait_for_finalization=True,
)

# Asserts
mocked_do_set_weights.assert_called_once()
mocked_use_torch.assert_called_once()
mocked_convert_weights_and_uids_for_emit.assert_called()
mocked_torch_tensor.assert_called_with(
fake_weights, dtype=async_weights.torch.float32
)
assert result is False
assert message == "Test error message"


@pytest.mark.asyncio
async def test_do_commit_weights_success(subtensor, mocker):
"""Tests _do_commit_weights when the commit is successful."""
Expand Down
3 changes: 0 additions & 3 deletions tests/unit_tests/extrinsics/test_commit_reveal.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def test_commit_reveal_v3_extrinsic_success_with_torch(mocker, subtensor, hyperp
fake_reveal_round = 1

# Mocks
mocker.patch.object(commit_reveal, "use_torch", return_value=True)

mocked_uids = mocker.Mock()
mocked_weights = mocker.Mock()
Expand Down Expand Up @@ -233,7 +232,6 @@ def test_commit_reveal_v3_extrinsic_success_with_numpy(mocker, subtensor, hyperp
fake_uids = np.array([1, 2, 3], dtype=np.int64)
fake_weights = np.array([0.1, 0.2, 0.7], dtype=np.float32)

mocker.patch.object(commit_reveal, "use_torch", return_value=False)
mock_convert = mocker.patch.object(
commit_reveal,
"convert_weights_and_uids_for_emit",
Expand Down Expand Up @@ -282,7 +280,6 @@ def test_commit_reveal_v3_extrinsic_response_false(mocker, subtensor, hyperparam
fake_reveal_round = 1

# Mocks
mocker.patch.object(commit_reveal, "use_torch", return_value=True)
mocker.patch.object(
commit_reveal,
"convert_weights_and_uids_for_emit",
Expand Down
Loading