Merge pull request #87 from RaoFoundation/dev
Release 2.2.2
Sid-Data-Universe authored Mar 22, 2024
2 parents d2faaec + b4d1207 commit 706f659
Showing 3 changed files with 76 additions and 3 deletions.
2 changes: 1 addition & 1 deletion constants/__init__.py
@@ -13,7 +13,7 @@
 # Project Constants.
 # ---------------------------------

-__version__ = "2.2.1"
+__version__ = "2.2.2"
 version_split = __version__.split(".")
 __spec_version__ = (
     (1000 * int(version_split[0]))
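The version bump also changes __spec_version__, which is derived from the version string. The diff is truncated after the major-version term, so the completion below is only a sketch that assumes the conventional 1000/10/1 weighting for major/minor/patch, not necessarily the repository's exact expression:

version_split = "2.2.2".split(".")
spec_version = (
    (1000 * int(version_split[0]))  # 2 -> 2000 (shown in the diff)
    + (10 * int(version_split[1]))  # 2 -> 20 (assumed weighting)
    + (1 * int(version_split[2]))  # 2 -> 2 (assumed weighting)
)
print(spec_version)  # 2022

Under that assumption, releases 2.2.1 and 2.2.2 map to spec versions 2021 and 2022.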
74 changes: 73 additions & 1 deletion pretrain/validation.py
@@ -24,6 +24,7 @@
 import constants
 import traceback
 import bittensor as bt
+import pretrain as pt


 def iswin(loss_i, loss_j, block_i, block_j):
@@ -82,8 +83,66 @@ def compute_wins(
     return wins, win_rate


+def check_for_reasonable_output(
+    model, input1: torch.Tensor, input2: torch.Tensor
+) -> bool:
+    """Checks that a model generates reasonable outputs for two given inputs.
+
+    Args:
+        model (torch.nn.Module): The model whose outputs are to be checked. Already loaded to device.
+        input1 (torch.Tensor): Tokenized input1 to check. Already loaded to device.
+        input2 (torch.Tensor): Tokenized input2 to check. Already loaded to device.
+
+    Returns:
+        bool: True if the model generates reasonable outputs.
+    """
+    # Generate 30 tokens of output from the model for each prompt.
+    output_length = 30
+    tokenizer = pt.model.get_tokenizer()
+    # Only take the last 30 tokens, since otherwise we would also get the prompt ids.
+    generate_id1s = model.generate(
+        input1,
+        min_new_tokens=output_length,
+        max_new_tokens=output_length,
+        pad_token_id=tokenizer.eos_token_id,
+    )[:, -output_length:]
+    generate_id2s = model.generate(
+        input2,
+        min_new_tokens=output_length,
+        max_new_tokens=output_length,
+        pad_token_id=tokenizer.eos_token_id,
+    )[:, -output_length:]
+
+    # Check if too many of the generated ids are the same between the two outputs.
+    if torch.sum(torch.eq(generate_id1s, generate_id2s)).item() >= output_length / 3:
+        bt.logging.info(
+            f"Model with config {model.config} had too much overlap between generated outputs."
+        )
+        return False
+
+    # Check if either response is internally too repetitive.
+    for tensor in [generate_id1s, generate_id2s]:
+        # Find unique elements and their counts.
+        _, counts = torch.unique(tensor, return_counts=True)
+        # Find the index of the maximum count.
+        max_count_index = torch.argmax(counts)
+        # Extract the count of the most common element.
+        most_common_count = counts[max_count_index].item()
+
+        if most_common_count > output_length / 3:
+            bt.logging.info(
+                f"Model with config {model.config} had too much repetition in generated output."
+            )
+            return False
+
+    # Passed all the checks, return True.
+    return True
+
+
 def compute_losses(
-    model, batches: typing.List[torch.Tensor], device: str
+    model,
+    batches: typing.List[torch.Tensor],
+    device: str,
 ) -> typing.List[float]:
     """
     Computes the losses for a given model on provided batches.
@@ -99,6 +158,19 @@
     model.to(device)
     model.eval()

+    # First check that the model generates reasonable-looking outputs.
+    # Grab 100 tokens from the first two batches as 'prompts'. (1 x sequence-length tensors.)
+    prompt_length = 100
+    falcon_token_inputs_1 = (batches[0][:, :prompt_length]).to(device)
+    falcon_token_inputs_2 = (batches[1][:, :prompt_length]).to(device)
+
+    if not check_for_reasonable_output(
+        model, falcon_token_inputs_1, falcon_token_inputs_2
+    ):
+        return [math.inf for _ in batches]
+
+    # Everything looks good! Continue to computing the actual losses.
+
     # Iterate over each page and corresponding batches
     losses = []
     for batch in batches:
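To make the two heuristics concrete, here is a self-contained toy demonstration of the overlap and repetition checks, using hand-built token tensors in place of real model generations; the output_length / 3 thresholds mirror the diff, while the tensors themselves are hypothetical.

import torch

output_length = 30

# Hypothetical "generated ids" for two different prompts (1 x output_length).
generate_id1s = torch.randint(0, 50000, (1, output_length))
generate_id2s = generate_id1s.clone()
generate_id2s[:, 15:] = torch.randint(0, 50000, (1, 15))  # second half differs

# Overlap check: fail if a third or more of the positions match exactly.
overlap = torch.sum(torch.eq(generate_id1s, generate_id2s)).item()
print(overlap >= output_length / 3)  # True here: the first 15 positions match

# Repetition check: fail if any single token accounts for more than a third of a response.
for tensor in [generate_id1s, generate_id2s]:
    _, counts = torch.unique(tensor, return_counts=True)
    most_common_count = counts[torch.argmax(counts)].item()
    print(most_common_count > output_length / 3)  # almost surely False for random ids

A model that fails either check short-circuits compute_losses with math.inf for every batch, so it can never beat a model that produces real losses in compute_wins.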
3 changes: 2 additions & 1 deletion utilities/miner_iterator.py
@@ -17,7 +17,8 @@ def __init__(self, miner_uids: List[int]):
         self.miner_uids = sorted(copy.deepcopy(miner_uids))
         # Start the index at a random position. This helps ensure that miners with high UIDs aren't penalized if
         # the validator restarts frequently.
-        self.index = random.randint(0, len(self.miner_uids) - 1)
+        # Temporarily hard-code the start index to 200 to more quickly restart on the relevant models.
+        self.index = 200
         self.lock = threading.Lock()

     def __iter__(self):
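The diff shows only __init__, so the snippet below is a hypothetical re-creation of the iterator's behavior; the __next__ method is an assumption (wrapping the index modulo the UID list length), included only to illustrate why a fixed start of 200 is safe even when fewer than 200 UIDs are registered.

import copy
import threading
from typing import List

class MinerIteratorSketch:
    """Hypothetical sketch of a thread-safe cyclic iterator over miner UIDs."""

    def __init__(self, miner_uids: List[int]):
        self.miner_uids = sorted(copy.deepcopy(miner_uids))
        self.index = 200  # Hard-coded start, as in the diff.
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self) -> int:
        # Assumed behavior: wrap modulo the list length so any start index works.
        with self.lock:
            uid = self.miner_uids[self.index % len(self.miner_uids)]
            self.index = (self.index + 1) % len(self.miner_uids)
            return uid

With 256 registered UIDs, for example, iteration begins at miner_uids[200] rather than at a random position.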
