Skip to content

Fix v2 for MoE #1548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions gptqmodel/looper/native_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..looper.named_module import NamedModule
from ..models import BaseGPTQModel
from ..quantization.config import QuantizeConfig
from ..quantization.gptq import CPU, DEVICE_1
from ..quantization.gptq import CPU, DEVICE_1, DEVICE_2
from ..utils.logger import setup_logger

log = setup_logger()
Expand Down Expand Up @@ -75,7 +75,7 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
inp = inp[0].detach()

if self.qcfg.v2_memory_device == "auto":
v2_memory_device = DEVICE_1
v2_memory_device = DEVICE_2
elif self.qcfg.v2_memory_device == "cpu":
# slower but >= 4x vram memory reduction
v2_memory_device = CPU
Expand All @@ -84,7 +84,7 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
elif isinstance(self.qcfg.v2_memory_device, torch.device):
v2_memory_device = self.qcfg.v2_memory_device
else:
v2_memory_device = DEVICE_1
v2_memory_device = DEVICE_2

self.native_inp_caches[name] += [inp.to(device=v2_memory_device)]
del inp, out
Expand Down
1 change: 1 addition & 0 deletions gptqmodel/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DEVICE_0 = auto_select_torch_device(index=0)
# device_1 may be same as device_0 if there is only 1 visible/active device
DEVICE_1 = auto_select_torch_device(index=1)
DEVICE_2 = auto_select_torch_device(index=2)

lock = threading.Lock()

Expand Down
41 changes: 40 additions & 1 deletion gptqmodel/quantization/gptqv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@

from ..looper.named_module import NamedModule
from ..quantization import QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.torch import torch_compile, torch_sync
from .gptq import DEVICE_1, GPTQ

log = setup_logger()

class GPTQv2(GPTQ):
def __init__(self, module: NamedModule, qcfg: Optional[QuantizeConfig]=None):
Expand Down Expand Up @@ -72,9 +74,46 @@ def __init__(self, module: NamedModule, qcfg: Optional[QuantizeConfig]=None):
# self.dXXT.addmm_((native_inp.T-reshaped_inp.T), reshaped_inp, beta=beta, alpha=alpha)
# del native_inp, reshaped_inp

def find_closest_native_input(self, inp):
if not self.native_inps:
return None

# only match with exact same shape
shape_matches = []
for i, native_inp in enumerate(self.native_inps):
if native_inp.shape == inp.shape:
shape_matches.append((i, native_inp))

# then find the closest tensor value match
if shape_matches:
closest_idx = -1
min_diff = float('inf')
for i, native_inp in shape_matches:
native_inp = native_inp.to(device=inp.device)
diff = (native_inp - inp).abs().sum().item()
if diff < min_diff:
min_diff = diff
closest_idx = i
if closest_idx != -1:
return self.native_inps.pop(closest_idx)

# no match found
return None

def process_batch(self, inp):
inp = inp.to(device=DEVICE_1, dtype=torch.float32)
native_inp = self.native_inps.pop(0).to(device=DEVICE_1, dtype=torch.float32)

# not compatible with Moe
# native_inp = self.native_inps.pop(0).to(device=DEVICE_1, dtype=torch.float32)

native_inp = self.find_closest_native_input(inp)

if native_inp is None:
log.error(f"Skipping input of shape `{inp.shape}` as it not matched to native_inputs. If this is MoE model, this is safe to ignore.")
return

native_inp = native_inp.to(device=inp.device)

if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
native_inp = native_inp.unsqueeze(0)
Expand Down
4 changes: 2 additions & 2 deletions gptqmodel/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ def auto_select_torch_device(index: int = 0):
if HAS_CUDA:
# defensive check
if index > 0 and torch.cuda.device_count() <= index :
index = 0
index = torch.cuda.device_count() - 1
device = torch.device(f"cuda:{index}")
elif HAS_XPU:
# defensive check
if index > 0 and torch.xpu.device_count() <= index:
index = 0
index = torch.xpu.device_count() - 1
device = torch.device(f"xpu:{index}")
elif HAS_MPS:
device = torch.device("mps") # mps has no index
Expand Down