Skip to content

Commit

Permalink
Merge pull request #2543 from opentensor/feat/thewhaleking/remove-tor…
Browse files Browse the repository at this point in the history
…ch-from-cr3

Remove torch from the weights calls
  • Loading branch information
thewhaleking authored Dec 16, 2024
2 parents bcc98aa + f0be03d commit daf1065
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 72 deletions.
16 changes: 5 additions & 11 deletions bittensor/core/extrinsics/async_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from bittensor.core.settings import version_as_int
from bittensor.utils import format_error_message
from bittensor.utils.btlogging import logging
from bittensor.utils.registration import torch, use_torch

if TYPE_CHECKING:
from bittensor_wallet import Wallet
from bittensor.core.async_subtensor import AsyncSubtensor
from bittensor.utils.registration import torch


async def _do_set_weights(
Expand Down Expand Up @@ -106,16 +106,10 @@ async def set_weights_extrinsic(
success (bool): Flag is ``true`` if extrinsic was finalized or included in the block. If we did not wait for finalization / inclusion, the response is ``true``.
"""
# First convert types.
if use_torch():
if isinstance(uids, list):
uids = torch.tensor(uids, dtype=torch.int64)
if isinstance(weights, list):
weights = torch.tensor(weights, dtype=torch.float32)
else:
if isinstance(uids, list):
uids = np.array(uids, dtype=np.int64)
if isinstance(weights, list):
weights = np.array(weights, dtype=np.float32)
if isinstance(uids, list):
uids = np.array(uids, dtype=np.int64)
if isinstance(weights, list):
weights = np.array(weights, dtype=np.float32)

# Reformat and normalize.
weight_uids, weight_vals = weight_utils.convert_weights_and_uids_for_emit(
Expand Down
16 changes: 5 additions & 11 deletions bittensor/core/extrinsics/commit_reveal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from bittensor.utils import format_error_message
from bittensor.utils.btlogging import logging
from bittensor.utils.networking import ensure_connected
from bittensor.utils.registration import torch, use_torch
from bittensor.utils.weight_utils import convert_weights_and_uids_for_emit

if TYPE_CHECKING:
from bittensor_wallet import Wallet
from bittensor.core.subtensor import Subtensor
from bittensor.utils.registration import torch


@ensure_connected
Expand Down Expand Up @@ -105,16 +105,10 @@ def commit_reveal_v3_extrinsic(
"""
try:
# Convert uids and weights
if use_torch():
if isinstance(uids, list):
uids = torch.tensor(uids, dtype=torch.int64)
if isinstance(weights, list):
weights = torch.tensor(weights, dtype=torch.float32)
else:
if isinstance(uids, list):
uids = np.array(uids, dtype=np.int64)
if isinstance(weights, list):
weights = np.array(weights, dtype=np.float32)
if isinstance(uids, list):
uids = np.array(uids, dtype=np.int64)
if isinstance(weights, list):
weights = np.array(weights, dtype=np.float32)

# Reformat and normalize.
uids, weights = convert_weights_and_uids_for_emit(uids, weights)
Expand Down
47 changes: 0 additions & 47 deletions tests/unit_tests/extrinsics/test_async_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,53 +278,6 @@ async def test_set_weights_extrinsic_exception(subtensor, mocker):
assert message == "Unexpected error"


@pytest.mark.asyncio
async def test_set_weights_extrinsic_if_use_torch(subtensor, mocker):
"""Tests set_weights_extrinsic when use_torch is True."""
# Preps
fake_wallet = mocker.Mock(autospec=Wallet)
fake_netuid = 1
fake_uids = [1, 2, 3]
fake_weights = [0.1, 0.2, 0.7]

mocked_use_torch = mocker.patch.object(
async_weights, "use_torch", return_value=True
)
mocked_torch_tensor = mocker.patch.object(
async_weights.torch, "tensor", return_value=mocker.Mock()
)

mocked_do_set_weights = mocker.patch.object(
async_weights, "_do_set_weights", return_value=(False, "Test error message")
)
mocked_convert_weights_and_uids_for_emit = mocker.patch.object(
async_weights.weight_utils,
"convert_weights_and_uids_for_emit",
return_value=(mocker.Mock(), mocker.Mock()),
)

# Call
result, message = await async_weights.set_weights_extrinsic(
subtensor=subtensor,
wallet=fake_wallet,
netuid=fake_netuid,
uids=fake_uids,
weights=fake_weights,
wait_for_inclusion=True,
wait_for_finalization=True,
)

# Asserts
mocked_do_set_weights.assert_called_once()
mocked_use_torch.assert_called_once()
mocked_convert_weights_and_uids_for_emit.assert_called()
mocked_torch_tensor.assert_called_with(
fake_weights, dtype=async_weights.torch.float32
)
assert result is False
assert message == "Test error message"


@pytest.mark.asyncio
async def test_do_commit_weights_success(subtensor, mocker):
"""Tests _do_commit_weights when the commit is successful."""
Expand Down
3 changes: 0 additions & 3 deletions tests/unit_tests/extrinsics/test_commit_reveal.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def test_commit_reveal_v3_extrinsic_success_with_torch(mocker, subtensor, hyperp
fake_reveal_round = 1

# Mocks
mocker.patch.object(commit_reveal, "use_torch", return_value=True)

mocked_uids = mocker.Mock()
mocked_weights = mocker.Mock()
Expand Down Expand Up @@ -233,7 +232,6 @@ def test_commit_reveal_v3_extrinsic_success_with_numpy(mocker, subtensor, hyperp
fake_uids = np.array([1, 2, 3], dtype=np.int64)
fake_weights = np.array([0.1, 0.2, 0.7], dtype=np.float32)

mocker.patch.object(commit_reveal, "use_torch", return_value=False)
mock_convert = mocker.patch.object(
commit_reveal,
"convert_weights_and_uids_for_emit",
Expand Down Expand Up @@ -282,7 +280,6 @@ def test_commit_reveal_v3_extrinsic_response_false(mocker, subtensor, hyperparam
fake_reveal_round = 1

# Mocks
mocker.patch.object(commit_reveal, "use_torch", return_value=True)
mocker.patch.object(
commit_reveal,
"convert_weights_and_uids_for_emit",
Expand Down

0 comments on commit daf1065

Please sign in to comment.