next_token_scores = self.apply_warp(next_token_scores)
probs = npsoftmax(next_token_scores.astype(np.float64), axis=1)
# Caution: the multinomial sampling below may raise
#   ValueError: sum(pvals[:-1].astype(np.float64)) > 1.0. The pvals array is cast
#   to 64-bit floating point prior to checking the sum. Precision changes when
#   casting may cause problems even if the sum of the original pvals is valid.
next_token = npmultinominal2D(probs).astype(input_ids.dtype)
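For context, that ValueError comes from NumPy's multinomial sampler, which casts pvals to float64 and then requires the probabilities to sum to at most 1; probabilities computed in lower precision can drift past that bound after the cast. A minimal guard (the helper name `safe_probs` is hypothetical, not from the original code) is to compute the softmax in float64 and renormalize before sampling:

```python
import numpy as np

def safe_probs(scores):
    # Softmax in float64; subtracting the row max keeps exp() from overflowing.
    x = scores.astype(np.float64)
    x = x - x.max(axis=1, keepdims=True)
    e = np.exp(x)
    p = e / e.sum(axis=1, keepdims=True)
    # Renormalize so each row sums to 1 within float64 precision,
    # satisfying np.random.multinomial's pvals check.
    return p / p.sum(axis=1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0]], dtype="float32")
probs = safe_probs(scores)
sample = np.random.multinomial(1, probs[0])  # no ValueError
```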
Replacing the lines above with the approach below noticeably reduces the cost of computing next_token.
The new code:
next_token = post_process(next_token_scores)
def post_process(tensor, topk=3):
    # flatten the logits and sample the next token from the top-k only
    tensor = tensor.reshape([-1]).astype("float32")
    tensor = warp_temperature(tensor, 1.0)
    topk_vals, topk_idxs = warp_topk1(tensor, topk)
    # softmax over just the k kept scores, then draw one sample
    probs = npsoftmax(topk_vals, axis=0)
    max_idx = np.random.multinomial(1, probs).argmax()
    next_token = topk_idxs[max_idx]
    next_token = np.array([next_token], dtype="int64").reshape([-1, 1])
    return next_token
def warp_topk1(tensor, topk):
    tensor_1d = tensor.reshape([-1])
    topk_vals, topk_idxs = get_topk(tensor_1d, topk)
    return topk_vals, topk_idxs
def get_topk(tensor_1d, topk=3):
    # values in topk_vals are kept in descending order
    topk_vals = [-float("Inf")] * topk
    topk_idxs = [0] * topk
    for idx, elem in enumerate(tensor_1d):
        if elem > topk_vals[topk - 1]:
            for i in range(topk):
                # find where the current value should be placed,
                # then right-shift topk_vals to make room for it
                if elem > topk_vals[i]:
                    # right shift
                    for j in reversed(range(i, topk - 1)):
                        topk_vals[j + 1] = topk_vals[j]
                        topk_idxs[j + 1] = topk_idxs[j]
                    topk_vals[i] = elem
                    topk_idxs[i] = idx
                    break
    return topk_vals, topk_idxs
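For comparison, the same top-k selection can be done with NumPy's built-ins: np.argpartition places the k largest elements (unordered) in the last k slots in O(n), and a final argsort orders them descending. This is a sketch of a drop-in alternative to get_topk, not part of the original issue:

```python
import numpy as np

def get_topk_np(tensor_1d, topk=3):
    # argpartition puts the indices of the topk largest values
    # (in arbitrary order) in the last topk slots
    idx = np.argpartition(tensor_1d, -topk)[-topk:]
    # order those topk indices by descending value, matching get_topk
    idx = idx[np.argsort(tensor_1d[idx])[::-1]]
    return tensor_1d[idx].tolist(), idx.tolist()

vals, idxs = get_topk_np(np.array([0.1, 0.5, 0.2, 0.9, 0.3]), topk=3)
# vals == [0.9, 0.5, 0.3], idxs == [3, 1, 4]
```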
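Putting it together, here is a self-contained sketch of the proposed post_process. The warp_temperature and npsoftmax implementations are not shown in the issue, so the stand-in definitions below are assumptions (standard temperature scaling and a numerically stable softmax), and get_topk is replaced with its vectorized equivalent; the float64 renormalization guards against the multinomial ValueError mentioned above:

```python
import numpy as np

def warp_temperature(tensor, temperature):
    # stand-in: standard temperature scaling of logits (assumption)
    return tensor / temperature

def npsoftmax(x, axis=0):
    # stand-in: numerically stable softmax (assumption)
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def get_topk(tensor_1d, topk=3):
    # vectorized equivalent of the issue's insertion-sort top-k
    idx = np.argpartition(tensor_1d, -topk)[-topk:]
    idx = idx[np.argsort(tensor_1d[idx])[::-1]]
    return tensor_1d[idx], idx

def post_process(tensor, topk=3):
    tensor = tensor.reshape([-1]).astype("float32")
    tensor = warp_temperature(tensor, 1.0)
    topk_vals, topk_idxs = get_topk(tensor, topk)
    probs = npsoftmax(topk_vals.astype("float64"), axis=0)
    probs = probs / probs.sum()  # guard against the multinomial pvals check
    max_idx = np.random.multinomial(1, probs).argmax()
    return np.array([topk_idxs[max_idx]], dtype="int64").reshape([-1, 1])

scores = np.random.randn(1, 100).astype("float32")
next_token = post_process(scores, topk=3)  # shape (1, 1), one of the top-3 ids
```

The sampled token is always one of the top-k ids, which is the point of the rewrite: the softmax and multinomial draw run over k values instead of the full vocabulary.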