Commit 603777a ("poc")
JessicaXYWang committed Oct 15, 2024 (1 parent: f0c2b00)

Showing 1 changed file with 122 additions and 40 deletions:
core/src/main/python/synapse/ml/phi3/Phi3Transform.py
The changed region (hunk @@ -6,85 +6,167 @@), as the file reads after the commit:
# The commit page collapses the file's first five lines; judging from usage,
# they would include imports along these lines:
from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, StringType

from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from transformers import AutoTokenizer, AutoModelForCausalLM
from pyspark import keyword_only
import re


class Peekable:
    """Iterator wrapper that can look ahead without consuming elements."""

    def __init__(self, iterable):
        self._iterator = iter(iterable)
        self._cache = []

    def __iter__(self):
        return self

    def __next__(self):
        if self._cache:
            return self._cache.pop(0)
        return next(self._iterator)

    def peek(self, n=1):
        """Peek at the next n elements without consuming them."""
        while len(self._cache) < n:
            try:
                self._cache.append(next(self._iterator))
            except StopIteration:
                break
        if n == 1:
            return self._cache[0] if self._cache else None
        return self._cache[:n]


class ModelParam:
    """Keyword arguments forwarded to model.generate()."""

    def __init__(self, **kwargs):
        self.param = dict(kwargs)

    def get_param(self):
        return self.param


class ModelConfig:
    """Keyword arguments forwarded to AutoModelForCausalLM.from_pretrained()."""

    def __init__(self, **kwargs):
        self.config = dict(kwargs)

    def get_config(self):
        return self.config


def camel_to_snake(text):
    # e.g. "maxNewTokens" -> "max_new_tokens"; not yet referenced in this hunk
    return re.sub(r'(?<!^)(?=[A-Z])', '_', text).lower()


class HuggingFaceCausalLM(Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable):
    modelName = Param(Params._dummy(), "modelName", "model name", typeConverter=TypeConverters.toString)
    inputCol = Param(Params._dummy(), "inputCol", "input column", typeConverter=TypeConverters.toString)
    outputCol = Param(Params._dummy(), "outputCol", "output column", typeConverter=TypeConverters.toString)
    modelParam = Param(Params._dummy(), "modelParam", "Model Parameters")
    modelConfig = Param(Params._dummy(), "modelConfig", "Model configuration")

    @keyword_only
    def __init__(self, modelName=None, inputCol=None, outputCol=None):
        super(HuggingFaceCausalLM, self).__init__()
        self._setDefault(
            modelName=modelName,
            inputCol=inputCol,
            outputCol=outputCol,
            modelParam=ModelParam(),
            modelConfig=ModelConfig(),
        )
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, modelName=None, inputCol=None, outputCol=None):
        # The keyword arguments must appear in the signature, otherwise the
        # @keyword_only call from __init__ raises a TypeError.
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setModelName(self, value):
        return self._set(modelName=value)

    def getModelName(self):
        return self.getOrDefault(self.modelName)

    def setInputCol(self, value):
        return self._set(inputCol=value)

    def getInputCol(self):
        return self.getOrDefault(self.inputCol)

    def setOutputCol(self, value):
        return self._set(outputCol=value)

    def getOutputCol(self):
        return self.getOrDefault(self.outputCol)

    def setModelParam(self, **kwargs):
        param = ModelParam(**kwargs)
        return self._set(modelParam=param)

    def getModelParam(self):
        return self.getOrDefault(self.modelParam)

    def setModelConfig(self, **kwargs):
        config = ModelConfig(**kwargs)
        return self._set(modelConfig=config)

    def getModelConfig(self):
        return self.getOrDefault(self.modelConfig)

    def _predict_single_complete(self, prompt, model, tokenizer):
        """Plain text completion: feed the raw prompt to the model."""
        param = self.getModelParam().get_param()
        # Move the prompt tokens to the model's device, as the chat path does.
        inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(inputs, **param)
        decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return decoded_output

    def _predict_single_chat(self, prompt, model, tokenizer):
        """Chat-style generation: wrap the prompt in the tokenizer's chat template."""
        param = self.getModelParam().get_param()
        chat = [{"role": "user", "content": prompt}]
        formatted_chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        tokenized_chat = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)
        inputs = {key: tensor.to(model.device) for key, tensor in tokenized_chat.items()}
        merged_inputs = {**inputs, **param}
        outputs = model.generate(**merged_inputs)
        # Decode only the newly generated tokens, skipping the prompt.
        decoded_output = tokenizer.decode(outputs[0][inputs["input_ids"].size(1):], skip_special_tokens=True)
        return decoded_output

    @property
    def schema(self):
        return StructType([StructField(self.getOutputCol(), StringType(), True)])

    def _process_partition(self, iterator, task):
        """Load the model once per partition and generate a result for every row."""
        peekable_iterator = Peekable(iterator)
        # peek() returns None on an empty partition rather than raising
        # StopIteration; skip it so the model is not loaded needlessly.
        if peekable_iterator.peek() is None:
            return
        model_name = self.getModelName()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model_config = self.getModelConfig().get_config()
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
        model.eval()

        for row in peekable_iterator:
            prompt = row[self.getInputCol()]
            if task == "chat":
                result = self._predict_single_chat(prompt, model, tokenizer)
            elif task == "complete":
                result = self._predict_single_complete(prompt, model, tokenizer)
            else:
                raise ValueError(f"Unknown task: {task}")
            row_dict = row.asDict()
            row_dict[self.getOutputCol()] = result
            yield Row(**row_dict)

    def _transform(self, dataset):
        input_schema = dataset.schema
        output_schema = StructType(input_schema.fields + [StructField(self.getOutputCol(), StringType(), True)])
        result_rdd = dataset.rdd.mapPartitions(lambda partition: self._process_partition(partition, "chat"))
        result_df = result_rdd.toDF(output_schema)
        return result_df

    def complete(self, dataset):
        input_schema = dataset.schema
        output_schema = StructType(input_schema.fields + [StructField(self.getOutputCol(), StringType(), True)])
        result_rdd = dataset.rdd.mapPartitions(lambda partition: self._process_partition(partition, "complete"))
        result_df = result_rdd.toDF(output_schema)
        return result_df
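For orientation, a minimal usage sketch of the new transformer. It is not part of the commit: the SparkSession, the prompts DataFrame, and the "microsoft/Phi-3-mini-4k-instruct" checkpoint are all assumptions. Keyword arguments given to setModelParam are forwarded to model.generate, and those given to setModelConfig to AutoModelForCausalLM.from_pretrained.

    # Hypothetical usage (assumes a running Spark cluster and a Phi-3 checkpoint).
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("What is Apache Spark?",)], ["prompt"])

    llm = (
        HuggingFaceCausalLM()
        .setModelName("microsoft/Phi-3-mini-4k-instruct")  # hypothetical checkpoint
        .setInputCol("prompt")
        .setOutputCol("answer")
    )
    llm.setModelParam(max_new_tokens=100, do_sample=True, temperature=0.7)  # generate() kwargs
    llm.setModelConfig(trust_remote_code=True)  # from_pretrained() kwargs

    chat_df = llm.transform(df)       # chat-style generation via the chat template
    completion_df = llm.complete(df)  # raw text completion
    chat_df.show(truncate=False)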

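One design point worth noting: _process_partition loads a separate tokenizer and model for each Spark partition, so the partition count bounds both parallelism and memory use. A tuning sketch (the number is illustrative, not from the commit):

    # Fewer partitions mean fewer model copies; match this to the number of
    # workers that can each hold the model in memory.
    result = llm.transform(df.repartition(2))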

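The Peekable helper exists so an empty partition can be detected before the expensive model load. A small standalone demonstration of its semantics (illustrative only):

    it = Peekable([1, 2, 3])
    assert it.peek() == 1         # look ahead without consuming
    assert it.peek(2) == [1, 2]   # peek at several elements
    assert list(it) == [1, 2, 3]  # cached elements are still yielded in order
    assert Peekable([]).peek() is None  # empty source: None, not StopIteration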