Skip to content

Commit

Permalink
add set device type
Browse files Browse the repository at this point in the history
  • Loading branch information
JessicaXYWang committed Oct 21, 2024
1 parent f2ab308 commit 3ee9168
Showing 1 changed file with 26 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def __init__(self, **kwargs):

def get_config(self):
    """Return the configuration object stored in ``self.config``.

    The stored object itself is returned, not a copy.
    """
    current_config = self.config
    return current_config

def set_config(self, **kwargs):
    """Merge the given keyword arguments into ``self.config``.

    Existing keys are overwritten; new keys are added. Returns ``None``.
    """
    target = self.config
    target.update(kwargs)

def camel_to_snake(text):
    """Convert a camelCase / PascalCase identifier to snake_case.

    An underscore is inserted before every uppercase letter that is not at
    the very start of the string, then the whole result is lower-cased.
    Consecutive capitals are each split (``"ABC"`` -> ``"a_b_c"``).
    """
    with_separators = re.sub(r'(?<!^)(?=[A-Z])', '_', text)
    return with_separators.lower()
Expand All @@ -58,6 +61,8 @@ class HuggingFaceCausalLM(Transformer, HasInputCol, HasOutputCol, DefaultParamsR
# Param declarations for this transformer. Each Param is created against
# Params._dummy() and carries a name and a short description.

# Name of the output column; converted with TypeConverters.toString.
outputCol = Param(Params._dummy(), "outputCol", "output column", typeConverter=TypeConverters.toString)
# Model parameters object (no type converter); _predict_single_complete reads
# it via getModelParam().get_param() and forwards the result to model.generate.
modelParam = Param(Params._dummy(), "modelParam", "Model Parameters")
# Model configuration object (no type converter); read via getModelConfig().
modelConfig = Param(Params._dummy(), "modelConfig", "Model configuration")
# Device map, stored as a string; read via getDeviceMap().
deviceMap = Param(Params._dummy(), "deviceMap", "device map", typeConverter=TypeConverters.toString)
# Torch dtype, stored as a string; read via getTorchDtype().
torchDtype =Param(Params._dummy(), "torchDtype", "torch dtype", typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self,
modelName=None,
Expand Down Expand Up @@ -113,9 +118,21 @@ def setModelConfig(self, **kwargs):
def getModelConfig(self):
    """Return the value of the ``modelConfig`` param, looked up via ``getOrDefault``."""
    config_param = self.modelConfig
    return self.getOrDefault(config_param)

def setDeviceMap(self, value):
    """Set the ``deviceMap`` param to ``value``.

    Returns whatever ``self._set`` returns, so calls can be chained when
    ``_set`` returns the instance.
    """
    updates = {"deviceMap": value}
    return self._set(**updates)

def getDeviceMap(self):
    """Return the value of the ``deviceMap`` param, looked up via ``getOrDefault``."""
    device_map_param = self.deviceMap
    return self.getOrDefault(device_map_param)

def setTorchDtype(self, value):
    """Set the ``torchDtype`` param to ``value``.

    Returns whatever ``self._set`` returns, so calls can be chained when
    ``_set`` returns the instance.
    """
    updates = {"torchDtype": value}
    return self._set(**updates)

def getTorchDtype(self):
    """Return the value of the ``torchDtype`` param, looked up via ``getOrDefault``."""
    torch_dtype_param = self.torchDtype
    return self.getOrDefault(torch_dtype_param)

def _predict_single_complete(self, prompt, model, tokenizer):
    """Generate a completion for a single prompt and return the decoded text.

    Args:
        prompt: Text passed to ``tokenizer``.
        model: Object exposing ``device`` and ``generate(inputs, **kwargs)``.
        tokenizer: Callable returning an object with ``input_ids``; must also
            provide ``batch_decode``.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    param = self.getModelParam().get_param()
    # Tokenize exactly once and place the ids on the model's device; the
    # earlier duplicate tokenizer call was dead code whose result was
    # immediately overwritten.
    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(inputs, **param)
    decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return decoded_output
Expand Down Expand Up @@ -143,7 +160,15 @@ def _process_partition(self, iterator, task):
return
model_name = self.getModelName()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Work on a copy: get_config() hands back the stored config object, and the
# per-load overrides below must not leak into it.
model_config = dict(self.getModelConfig().get_config())

device_map = self.getDeviceMap()
if device_map:
    model_config["device_map"] = device_map

torch_dtype = self.getTorchDtype()
if torch_dtype:
    # The keyword is "torch_dtype"; the previous "tourch_dtype" spelling
    # meant the requested dtype was never passed under the intended name.
    model_config["torch_dtype"] = torch_dtype

model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
# Switch to inference mode before generating.
model.eval()

Expand Down

0 comments on commit 3ee9168

Please sign in to comment.