Skip to content

Commit

Permalink
sql: implement CREATE OR REPLACE FUNCTION
Browse files Browse the repository at this point in the history
This commit adds to the `REPLACE` path to the
`CREATE OR REPLACE FUNCTION` statement. Major changes are
(1) fetch function with same signature if exists, and validate
(2) remove refereces before replacing the function, and then
    add new references.

Release note: None
  • Loading branch information
chengxiong-ruan committed Aug 16, 2022
1 parent 5d7e232 commit d71f834
Show file tree
Hide file tree
Showing 6 changed files with 848 additions and 622 deletions.
10 changes: 5 additions & 5 deletions pkg/sql/catalog/funcdesc/func_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func NewMutableFunctionDescriptor(
parentID descpb.ID,
parentSchemaID descpb.ID,
name string,
argNum int,
args []descpb.FunctionDescriptor_Argument,
returnType *types.T,
returnSet bool,
privs *catpb.PrivilegeDescriptor,
Expand All @@ -71,7 +71,7 @@ func NewMutableFunctionDescriptor(
ID: id,
ParentID: parentID,
ParentSchemaID: parentSchemaID,
Args: make([]descpb.FunctionDescriptor_Argument, 0, argNum),
Args: args,
ReturnType: descpb.FunctionDescriptor_ReturnType{
Type: returnType,
ReturnSet: returnSet,
Expand Down Expand Up @@ -417,9 +417,9 @@ func (desc *Mutable) SetDeclarativeSchemaChangerState(state *scpb.DescriptorStat
desc.DeclarativeSchemaChangerState = state
}

// AddArgument adds a function argument to argument list.
func (desc *Mutable) AddArgument(arg descpb.FunctionDescriptor_Argument) {
desc.Args = append(desc.Args, arg)
// AddArguments adds function arguments to argument list.
func (desc *Mutable) AddArguments(args ...descpb.FunctionDescriptor_Argument) {
desc.Args = append(desc.Args, args...)
}

// SetVolatility sets the volatility attribute.
Expand Down
94 changes: 58 additions & 36 deletions pkg/sql/crdb_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2591,7 +2591,6 @@ CREATE TABLE crdb_internal.create_function_statements (
)
`,
populate: func(ctx context.Context, p *planner, db catalog.DatabaseDescriptor, addRow func(...tree.Datum) error) error {
flags := tree.ObjectLookupFlags{CommonLookupFlags: tree.CommonLookupFlags{AvoidLeased: true}}
var dbDescs []catalog.DatabaseDescriptor
if db == nil {
var err error
Expand All @@ -2602,45 +2601,68 @@ CREATE TABLE crdb_internal.create_function_statements (
} else {
dbDescs = append(dbDescs, db)
}
for _, db := range dbDescs {
err := forEachSchema(ctx, p, db, func(sc catalog.SchemaDescriptor) error {
var fnIDs []descpb.ID
fnIDToScName := make(map[descpb.ID]string)
fnIDToScID := make(map[descpb.ID]descpb.ID)
fnIDToDBName := make(map[descpb.ID]string)
fnIDToDBID := make(map[descpb.ID]descpb.ID)
for _, curDB := range dbDescs {
err := forEachSchema(ctx, p, curDB, func(sc catalog.SchemaDescriptor) error {
return sc.ForEachFunctionOverload(func(overload descpb.SchemaDescriptor_FunctionOverload) error {
fnDesc, err := p.Descriptors().GetImmutableFunctionByID(ctx, p.txn, overload.ID, flags)
if err != nil {
return err
}
treeNode, err := fnDesc.ToCreateExpr()
treeNode.FuncName.ObjectNamePrefix = tree.ObjectNamePrefix{
ExplicitSchema: true,
SchemaName: tree.Name(sc.GetName()),
}
if err != nil {
return err
}
for i := range treeNode.Options {
if body, ok := treeNode.Options[i].(tree.FunctionBodyStr); ok {
stmtStrs := strings.Split(string(body), "\n")
for i := range stmtStrs {
stmtStrs[i] = "\t" + stmtStrs[i]
}
fnIDs = append(fnIDs, overload.ID)
fnIDToScName[overload.ID] = sc.GetName()
fnIDToScID[overload.ID] = sc.GetID()
fnIDToDBName[overload.ID] = curDB.GetName()
fnIDToDBID[overload.ID] = curDB.GetID()
return nil
})
})
if err != nil {
return err
}
}

p := &treeNode.Options[i]
// Add two new lines just for better formatting.
*p = "\n" + tree.FunctionBodyStr(strings.Join(stmtStrs, "\n")) + "\n"
}
fnDescs, err := p.Descriptors().GetImmutableDescriptorsByID(
ctx, p.txn, tree.CommonLookupFlags{Required: true, AvoidLeased: true}, fnIDs...,
)
if err != nil {
return err
}

for _, desc := range fnDescs {
fnDesc := desc.(catalog.FunctionDescriptor)
if err != nil {
return err
}
treeNode, err := fnDesc.ToCreateExpr()
treeNode.FuncName.ObjectNamePrefix = tree.ObjectNamePrefix{
ExplicitSchema: true,
SchemaName: tree.Name(fnIDToScName[fnDesc.GetID()]),
}
if err != nil {
return err
}
for i := range treeNode.Options {
if body, ok := treeNode.Options[i].(tree.FunctionBodyStr); ok {
stmtStrs := strings.Split(string(body), "\n")
for i := range stmtStrs {
stmtStrs[i] = "\t" + stmtStrs[i]
}
p := &treeNode.Options[i]
// Add two new lines just for better formatting.
*p = "\n" + tree.FunctionBodyStr(strings.Join(stmtStrs, "\n")) + "\n"
}
}

return addRow(
tree.NewDInt(tree.DInt(db.GetID())), // database_id
tree.NewDString(db.GetName()), // database_name
tree.NewDInt(tree.DInt(sc.GetID())), // schema_id
tree.NewDString(sc.GetName()), // schema_name
tree.NewDInt(tree.DInt(fnDesc.GetID())), // function_id
tree.NewDString(fnDesc.GetName()), //function_name
tree.NewDString(tree.AsString(treeNode)), // create_statement
)
})
})
err = addRow(
tree.NewDInt(tree.DInt(fnIDToDBID[fnDesc.GetID()])), // database_id
tree.NewDString(fnIDToDBName[fnDesc.GetID()]), // database_name
tree.NewDInt(tree.DInt(fnIDToScID[fnDesc.GetID()])), // schema_id
tree.NewDString(fnIDToScName[fnDesc.GetID()]), // schema_name
tree.NewDInt(tree.DInt(fnDesc.GetID())), // function_id
tree.NewDString(fnDesc.GetName()), //function_name
tree.NewDString(tree.AsString(treeNode)), // create_statement
)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit d71f834

Please sign in to comment.