Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
hellozmz committed Feb 23, 2024
1 parent 44bcb64 commit cd98f2d
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions deeplink_ext/patch_internlm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2024, DeepLink.

import os

def _patch_internlm():
import importlib.util
Expand Down Expand Up @@ -86,9 +86,14 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs):
_find_or_mock_module("xentropy_cuda_lib")
_find_or_mock_module("flash_attn_cuda")
_find_flash_attn()
_patch_flash_attn()
_patch_ops()
ban_patch_internlm_flash_attn = bool(os.environ.get("BAN_PATCH_INTERNLM_FLASH_ATTN", False))
if not ban_patch_internlm_flash_attn:
_patch_flash_attn()
ban_patch_internlm_op = bool(os.environ.get("BAN_PATCH_INTERNLM_OP", False))
if not ban_patch_internlm_op:
_patch_ops()
print("[deeplink_ext] patched diopi implementation of internlm\n", end="")


_patch_internlm()
ban_patch_internlm = bool(os.environ.get("BAN_PATCH_INTERNLM", False))
if not ban_patch_internlm:
_patch_internlm()

0 comments on commit cd98f2d

Please sign in to comment.