Skip to content

Commit

Permalink
fix onmt converter (#1581)
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s authored Dec 7, 2023
1 parent 83caf67 commit 4f8a4f3
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions python/ctranslate2/converters/opennmt_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
num_heads = getattr(opt, "heads", 8)
num_kv = getattr(opt, "num_kv", 0)
if num_kv == num_heads:
if num_kv == num_heads or num_kv == 0:
num_kv = None
rotary_dim = 0 if with_rotary else None
rotary_interleave = getattr(opt, "rotary_interleave", True)
ffn_glu = activation_fn == "silu"
sliding_window = getattr(opt, "sliding_window", 0)

Expand All @@ -119,7 +120,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
alibi=with_alibi,
rms_norm=opt.layer_norm == "rms",
rotary_dim=rotary_dim,
rotary_interleave=True,
rotary_interleave=rotary_interleave,
multi_query_attention=getattr(opt, "multiquery", False),
num_heads_kv=num_kv,
sliding_window=sliding_window,
Expand Down Expand Up @@ -329,7 +330,7 @@ def set_linear(spec, variables, scope):
spec.weight = _get_variable(variables, "%s.weight" % scope)
bias = variables.get("%s.bias" % scope)
if bias is not None:
spec.bias = bias.numpy()
spec.bias = bias


def set_embeddings(spec, variables, scope):
Expand All @@ -341,7 +342,7 @@ def set_position_encodings(spec, variables, scope):


def _get_variable(variables, name):
return variables[name].numpy()
return variables[name]


def main():
Expand Down

0 comments on commit 4f8a4f3

Please sign in to comment.