diff --git a/server/lorax_server/models/flash_phi.py b/server/lorax_server/models/flash_phi.py index ad4065920..0d39b2c9f 100644 --- a/server/lorax_server/models/flash_phi.py +++ b/server/lorax_server/models/flash_phi.py @@ -4,6 +4,7 @@ import torch.distributed from opentelemetry import trace from transformers import AutoTokenizer +from transformers.models.phi.modeling_phi import PhiConfig from lorax_server.models import FlashCausalLM from lorax_server.models.custom_modeling.flash_phi_modeling import ( @@ -14,7 +15,6 @@ MLP_FC1, MLP_FC2, FlashPhiForCausalLM, - PhiConfig, ) from lorax_server.utils import ( Weights,