From 80d2998684e6e43027eba8e8b97a35865594273e Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 24 Apr 2024 13:55:03 +0100 Subject: [PATCH] Fix inference (#31) --- amt/inference/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/amt/inference/model.py b/amt/inference/model.py index c302614..e8390f4 100644 --- a/amt/inference/model.py +++ b/amt/inference/model.py @@ -354,6 +354,7 @@ def __init__( ] ) self.ln = nn.LayerNorm(n_state) + self.output = nn.Linear(n_state, n_vocab, bias=False) self.register_buffer("causal_mask", None, persistent=False) def forward( @@ -376,9 +377,7 @@ def forward( ) x = self.ln(x) - logits = ( - x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) - ).float() + logits = self.output(x) return logits