forked from facebookresearch/chameleon
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathalignment.py
79 lines (63 loc) · 2.33 KB
/
alignment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
import torch
class PromptAlignment(ABC):
    """Interface for turning a ragged batch of prompts into one tensor.

    A concrete strategy supplies three operations: pick an integer index
    from the prompt lengths, pad the token-id lists into a rectangular
    tensor, and post-process a working tensor against the padded original.
    """

    @abstractmethod
    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return an index computed from the lengths of the prompts.

        NOTE(review): presumably the position at which generation begins;
        confirm against the caller.

        Args:
            input_ids: One list of token ids per prompt.
        """

    @abstractmethod
    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        """Pad the prompts to a common length and return them as a tensor.

        Args:
            input_ids: One list of token ids per prompt.
        """

    @abstractmethod
    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``inputs`` after reconciling it with ``original_inputs``.

        Args:
            inputs: The working tensor of token ids.
            original_inputs: The tensor the working tensor is reconciled
                against (presumably the output of ``prepare_inputs``;
                confirm against the caller).
        """
class AlignPromptRight(PromptAlignment):
    """Right-align prompts: shorter prompts are padded on the left.

    After padding, every prompt ends at the same column, so the index
    returned by ``start_index`` is the length of the longest prompt.
    """

    def __init__(self, pad_id: int):
        """Store the token id used to fill the left-hand padding.

        Args:
            pad_id: Token id written into padding positions.
        """
        self.pad_id = pad_id

    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return the length of the longest prompt in the batch.

        Args:
            input_ids: One list of token ids per prompt.

        Raises:
            ValueError: If ``input_ids`` is empty (``max`` of nothing).
        """
        return max(len(sublist) for sublist in input_ids)

    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.LongTensor:
        """Left-pad every prompt with ``pad_id`` to the longest length.

        Args:
            input_ids: One list of token ids per prompt.

        Returns:
            An int64 tensor with one row per prompt and one column per
            position of the longest prompt.

        Raises:
            ValueError: If ``input_ids`` is empty (``max`` of nothing).
        """
        max_length = max(len(sublist) for sublist in input_ids)
        return torch.tensor(
            [
                [self.pad_id] * (max_length - len(sublist)) + sublist
                for sublist in input_ids
            ],
            # Pin the dtype: with inference, a batch whose prompts are all
            # empty (e.g. [[]]) would come back as float32 instead of int64.
            dtype=torch.long,
        )

    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``inputs`` unchanged; this strategy does no post-processing.

        Args:
            inputs: The working tensor of token ids.
            original_inputs: Unused by this strategy.
        """
        return inputs
class AlignPromptLeft(PromptAlignment):
    """Left-align prompts: shorter prompts are padded on the right.

    All prompts start at column 0, so the index returned by ``start_index``
    is the length of the shortest prompt. ``postprocess_inputs`` copies the
    non-padding prompt tokens back over the working tensor.
    """

    def __init__(self, pad_id: int = -1):
        """Store the token id used to fill the right-hand padding.

        Args:
            pad_id: Token id written into padding positions. It also acts
                as the "not part of the prompt" sentinel in
                ``postprocess_inputs``, so it should not collide with a
                real token id.
        """
        self.pad_id = pad_id

    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return the length of the shortest prompt in the batch.

        Args:
            input_ids: One list of token ids per prompt.

        Raises:
            ValueError: If ``input_ids`` is empty (``min`` of nothing).
        """
        return min(len(sublist) for sublist in input_ids)

    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        """Right-pad every prompt with ``pad_id`` to the longest length.

        Args:
            input_ids: One list of token ids per prompt.

        Returns:
            An int64 tensor with one row per prompt and one column per
            position of the longest prompt.

        Raises:
            ValueError: If ``input_ids`` is empty (``max`` of nothing).
        """
        max_length = max(len(sublist) for sublist in input_ids)
        return torch.tensor(
            [
                sublist + [self.pad_id] * (max_length - len(sublist))
                for sublist in input_ids
            ],
            # Pin the dtype: with inference, a batch whose prompts are all
            # empty (e.g. [[]]) would come back as float32 instead of int64.
            dtype=torch.long,
        )

    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Overwrite positions of ``inputs`` that belong to a prompt.

        While ``inputs`` is no wider than ``original_inputs``, every
        position whose original value is not ``pad_id`` is reset to that
        original value. Once ``inputs`` is wider, nothing is changed.

        Args:
            inputs: The working tensor of token ids; modified in place.
                Assumed to have the same number of rows as
                ``original_inputs`` (TODO confirm against the caller).
            original_inputs: Padded prompt tensor, indexed as 2-D
                (rows x positions).

        Returns:
            The same ``inputs`` tensor object that was passed in.
        """
        max_init_len = original_inputs.shape[1]
        if inputs.shape[1] <= max_init_len:
            # Compare only the columns that exist in the working tensor.
            original_inputs_limited = original_inputs[:, : inputs.shape[1]]
            # True wherever the original holds a real prompt token.
            mask = original_inputs_limited != self.pad_id
            inputs[mask] = original_inputs_limited[mask]
        return inputs