From eb7aea2d93d1bde4aad9177d0eac1367da490ff2 Mon Sep 17 00:00:00 2001 From: Cameron Fairchild Date: Thu, 13 Oct 2022 15:50:02 -0400 Subject: [PATCH] [Hotfix] Fix CUDA Reg update block (#954) * bump version * fix block update * . * verify new helper * remove uneeded comment --- bittensor/__init__.py | 2 +- bittensor/utils/__init__.py | 133 +++++++++++------- .../bittensor_tests/utils/test_utils.py | 82 +++++++++++ 3 files changed, 168 insertions(+), 49 deletions(-) diff --git a/bittensor/__init__.py b/bittensor/__init__.py index e316c1bbf5..ef5a9a9a72 100644 --- a/bittensor/__init__.py +++ b/bittensor/__init__.py @@ -19,7 +19,7 @@ from prometheus_client import Info # Bittensor code and protocol version. -__version__ = '3.4.0' +__version__ = '3.4.1' version_split = __version__.split(".") __version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2])) diff --git a/bittensor/utils/__init__.py b/bittensor/utils/__init__.py index 8482d5a336..a9e2144d86 100644 --- a/bittensor/utils/__init__.py +++ b/bittensor/utils/__init__.py @@ -530,25 +530,17 @@ def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, pass # check for new block - block_number = subtensor.get_current_block() - if block_number != old_block_number: - old_block_number = block_number - # update block information - block_hash = subtensor.substrate.get_block_hash( block_number) - while block_hash == None: - block_hash = subtensor.substrate.get_block_hash( block_number) - block_bytes = block_hash.encode('utf-8')[2:] - difficulty = subtensor.difficulty - - update_curr_block(curr_diff, curr_block, curr_block_num, block_number, block_bytes, difficulty, check_block) - # Set new block events for each solver - for worker in solvers: - worker.newBlockEvent.set() - - # update stats - curr_stats.block_number = block_number - curr_stats.block_hash = block_hash - curr_stats.difficulty = difficulty + old_block_number = check_for_newest_block_and_update( + subtensor = subtensor, + old_block_number=old_block_number, + curr_diff=curr_diff, + curr_block=curr_block, + curr_block_num=curr_block_num, + curr_stats=curr_stats, + update_curr_block=update_curr_block, + check_block=check_block, + solvers=solvers + ) num_time = 0 for _ in range(len(solvers)*2): @@ -636,6 +628,66 @@ def __exit__(self, *args): # restore the old start method multiprocessing.set_start_method(self._old_start_method, force=True) +def check_for_newest_block_and_update( + subtensor: 'bittensor.Subtensor', + old_block_number: int, + curr_diff: multiprocessing.Array, + curr_block: multiprocessing.Array, + curr_block_num: multiprocessing.Value, + update_curr_block: Callable, + check_block: 'multiprocessing.Lock', + solvers: List[Solver], + curr_stats: RegistrationStatistics + ) -> int: + """ + Checks for a new block and updates the current block information if a new block is found. + + Args: + subtensor (:obj:`bittensor.Subtensor`, `required`): + The subtensor object to use for getting the current block. + old_block_number (:obj:`int`, `required`): + The old block number to check against. + curr_diff (:obj:`multiprocessing.Array`, `required`): + The current difficulty as a multiprocessing array. + curr_block (:obj:`multiprocessing.Array`, `required`): + Where the current block is stored as a multiprocessing array. + curr_block_num (:obj:`multiprocessing.Value`, `required`): + Where the current block number is stored as a multiprocessing value. + update_curr_block (:obj:`Callable`, `required`): + A function that updates the current block. + check_block (:obj:`multiprocessing.Lock`, `required`): + A mp lock that is used to check for a new block. + solvers (:obj:`List[Solver]`, `required`): + A list of solvers to update the current block for. + curr_stats (:obj:`RegistrationStatistics`, `required`): + The current registration statistics to update. + + Returns: + (int) The current block number. + """ + block_number = subtensor.get_current_block() + if block_number != old_block_number: + old_block_number = block_number + # update block information + block_hash = subtensor.substrate.get_block_hash( block_number) + while block_hash == None: + block_hash = subtensor.substrate.get_block_hash( block_number) + block_bytes = block_hash.encode('utf-8')[2:] + difficulty = subtensor.difficulty + + update_curr_block(curr_diff, curr_block, curr_block_num, block_number, block_bytes, difficulty, check_block) + # Set new block events for each solver + + for worker in solvers: + worker.newBlockEvent.set() + + # update stats + curr_stats.block_number = block_number + curr_stats.block_hash = block_hash + curr_stats.difficulty = difficulty + + return old_block_number + def solve_for_difficulty_fast_cuda( subtensor: 'bittensor.Subtensor', wallet: 'bittensor.Wallet', output_in_place: bool = True, update_interval: int = 50_000, TPB: int = 512, dev_id: Union[List[int], int] = 0, n_samples: int = 5, alpha_: float = 0.70, log_verbose: bool = False ) -> Optional[POWSolution]: """ @@ -680,13 +732,6 @@ def solve_for_difficulty_fast_cuda( subtensor: 'bittensor.Subtensor', wallet: 'b curr_block_num = multiprocessing.Value('i', 0, lock=True) # int curr_diff = multiprocessing.Array('Q', [0, 0], lock=True) # [high, low] - def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: multiprocessing.Lock): - with lock: - curr_block_num.value = block_number - for i in range(64): - curr_block[i] = block_bytes[i] - registration_diff_pack(diff, curr_diff) - # Establish communication queues stopEvent = multiprocessing.Event() stopEvent.clear() @@ -712,7 +757,7 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu old_block_number = block_number # Set to current block - update_curr_block(block_number, block_bytes, difficulty, check_block) + update_curr_block(curr_diff, curr_block, curr_block_num, block_number, block_bytes, difficulty, check_block) # Set new block events for each solver to start at the initial block for worker in solvers: @@ -755,27 +800,19 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu except Empty: # No solution found, try again pass - - if block_number != old_block_number: - old_block_number = block_number - # update block information - block_hash = subtensor.substrate.get_block_hash( block_number) - while block_hash == None: - block_hash = subtensor.substrate.get_block_hash( block_number) - block_bytes = block_hash.encode('utf-8')[2:] - difficulty = subtensor.difficulty - - update_curr_block(block_number, block_bytes, difficulty, check_block) - # Set new block events for each solver - - for worker in solvers: - worker.newBlockEvent.set() - - - # update stats - curr_stats.block_number = block_number - curr_stats.block_hash = block_hash - curr_stats.difficulty = difficulty + + # check for new block + old_block_number = check_for_newest_block_and_update( + subtensor = subtensor, + curr_diff=curr_diff, + curr_block=curr_block, + curr_block_num=curr_block_num, + old_block_number=old_block_number, + curr_stats=curr_stats, + update_curr_block=update_curr_block, + check_block=check_block, + solvers=solvers + ) num_time = 0 # Get times for each solver diff --git a/tests/unit_tests/bittensor_tests/utils/test_utils.py b/tests/unit_tests/bittensor_tests/utils/test_utils.py index 030cdbfb83..1220e6836d 100644 --- a/tests/unit_tests/bittensor_tests/utils/test_utils.py +++ b/tests/unit_tests/bittensor_tests/utils/test_utils.py @@ -251,6 +251,88 @@ def test_registration_diff_pack_unpack_over_32_bits(): bittensor.utils.registration_diff_pack(fake_diff, mock_diff) assert bittensor.utils.registration_diff_unpack(mock_diff) == fake_diff +class TestUpdateCurrentBlockDuringRegistration(unittest.TestCase): + def test_check_for_newest_block_and_update_same_block(self): + # if the block is the same, the function should return the same block number + subtensor = MagicMock() + current_block_num: int = 1 + subtensor.get_current_block = MagicMock( return_value=current_block_num ) + + self.assertEqual(bittensor.utils.check_for_newest_block_and_update( + subtensor, + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + ), current_block_num) + + def test_check_for_newest_block_and_update_new_block(self): + # if the block is new, the function should return the new block_number + mock_block_hash = '0xba7ea4eb0b16dee271dbef5911838c3f359fcf598c74da65a54b919b68b67279' + + current_block_num: int = 1 + current_diff: int = 0 + + mock_substrate = MagicMock( + get_block_hash=MagicMock( + return_value=mock_block_hash + ), + + ) + subtensor = MagicMock( + substrate=mock_substrate, + difficulty=current_diff + 1, # new diff + ) + subtensor.get_current_block = MagicMock( return_value=current_block_num + 1 ) # new block + + mock_update_curr_block = MagicMock() + + mock_solvers = [ + MagicMock( + newBlockEvent=MagicMock( + set=MagicMock() + ) + ), + MagicMock( + newBlockEvent=MagicMock( + set=MagicMock() + ) + )] + + mock_curr_stats = MagicMock( + block_number=current_block_num, + block_hash=b'', + difficulty=0, + ) + + self.assertEqual(bittensor.utils.check_for_newest_block_and_update( + subtensor, + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + mock_update_curr_block, + MagicMock(), + mock_solvers, + mock_curr_stats, + ), current_block_num + 1) + + # check that the update_curr_block function was called + mock_update_curr_block.assert_called_once() + + # check that the solvers got the event + for solver in mock_solvers: + solver.newBlockEvent.set.assert_called_once() + + # check the stats were updated + self.assertEqual(mock_curr_stats.block_number, current_block_num + 1) + self.assertEqual(mock_curr_stats.block_hash, mock_block_hash) + self.assertEqual(mock_curr_stats.difficulty, current_diff + 1) + class TestGetBlockWithRetry(unittest.TestCase): def test_get_block_with_retry_network_error_exit(self): mock_subtensor = MagicMock(