Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix local contract upgrade #20296

Open
wants to merge 12 commits into
base: main-2.x
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ private[lf] object Pretty {
char('\'') + text(p) + char('\'')

def prettyParties(p: Set[Party]): Doc =
char('{') & intercalate(char(','), p.map(prettyParty)) & char('{')
char('{') & intercalate(char(','), p.map(prettyParty)) & char('}')

def prettyDamlException(error: interpretation.Error): Doc = {
import interpretation.Error._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ private[lf] object SBuiltin {
coid: V.ContractId,
ifaceId: TypeConName,
)(k: SAny => Control[Question.Update]): Control[Question.Update] = {
fetchAny(machine, None, coid) { (_, srcContract) =>
fetchAny(machine, None, coid) { (_, srcContract, _) =>
val (tplId, arg) = getSAnyContract(ArrayList.single(srcContract), 0)
ensureTemplateImplementsInterface(machine, ifaceId, coid, tplId) {
viewInterface(machine, ifaceId, tplId, arg) { srcView =>
Expand All @@ -1218,12 +1218,12 @@ private[lf] object SBuiltin {
coid: V.ContractId,
interfaceId: TypeConName,
)(k: SAny => Control[Question.Update]): Control[Question.Update] =
fetchAny(machine, None, coid) { (maybePkgName, srcContract) =>
fetchAny(machine, None, coid) { (maybePkgName, srcContractAny, srcContract) =>
maybePkgName match {
case None =>
crash(s"unexpected contract instance without packageName")
case Some(pkgName) =>
val (srcTplId, srcArg) = getSAnyContract(ArrayList.single(srcContract), 0)
val (srcTplId, srcArg) = getSAnyContract(ArrayList.single(srcContractAny), 0)
ensureTemplateImplementsInterface(machine, interfaceId, coid, srcTplId) {
viewInterface(machine, interfaceId, srcTplId, srcArg) { srcView =>
resolvePackageName(machine, pkgName) { pkgId =>
Expand Down Expand Up @@ -1251,7 +1251,7 @@ private[lf] object SBuiltin {
if (dstTplId == srcTplId)
k(SAny(Ast.TTyCon(dstTplId), dstArg))
else {
validateContractInfo(machine, coid, dstTplId, contract) { () =>
checkContractUpgradable(coid, srcContract, contract) { () =>
executeExpression(machine, SEPreventCatch(dstView)) {
dstViewValue =>
if (srcViewValue != dstViewValue) {
Expand Down Expand Up @@ -1306,7 +1306,7 @@ private[lf] object SBuiltin {
machine: UpdateMachine,
): Control[Question.Update] = {
val coid = getSContractId(args, 0)
fetchAny(machine, optTargetTemplateId, coid) { (_, sv) =>
fetchAny(machine, optTargetTemplateId, coid) { (_, sv, _) =>
Control.Value(sv)
}
}
Expand Down Expand Up @@ -2369,7 +2369,7 @@ private[lf] object SBuiltin {
optTargetTemplateId: Option[TypeConName],
coid: V.ContractId,
)(f: SValue => Control[Question.Update]): Control[Question.Update] = {
fetchAny(machine, optTargetTemplateId, coid) { (_, fetched) =>
fetchAny(machine, optTargetTemplateId, coid) { (_, fetched, _) =>
// The SBCastAnyContract check can never fail when the upgrading feature flag is enabled.
// This is because the contract got up/down-graded when imported by importValue.

Expand All @@ -2386,16 +2386,61 @@ private[lf] object SBuiltin {
}
}

/** Checks that the metadata of [original] and [recomputed] are the same, fails with a [Control.Error] if not. */
private def checkContractUpgradable(
coid: V.ContractId,
original: ContractInfo,
recomputed: ContractInfo,
)(
k: () => Control[Question.Update]
): Control[Question.Update] = {

def check[T](getter: ContractInfo => T, desc: String): Option[String] =
Option.when(getter(recomputed) != getter(original))(
s"$desc mismatch: $original vs $recomputed"
)

List(
check(_.signatories, "signatories"),
// This definition of observers allows observers to lose parties that are signatories
check(c => c.stakeholders -- c.signatories, "observers"),
check(_.keyOpt.map(_.maintainers), "key maintainers"),
check(_.keyOpt.map(_.globalKey.key), "key value"),
).flatten match {
case Nil => k()
case errors =>
Control.Error(
IE.Upgrade(
IE.Upgrade.ValidationFailed(
coid = coid,
srcTemplateId = original.templateId,
dstTemplateId = recomputed.templateId,
signatories = recomputed.signatories,
observers = recomputed.observers,
keyOpt = recomputed.keyOpt.map(_.globalKeyWithMaintainers),
msg = errors.mkString("['", "', '", "']"),
)
)
)
}
}

// This is the core function which fetches a contract given its coid.
// Regardless of it being a local, disclosed or global contract
private def fetchAny(
machine: UpdateMachine,
optTargetTemplateId: Option[TypeConName],
coid: V.ContractId,
)(f: (Option[Ref.PackageName], SValue) => Control[Question.Update]): Control[Question.Update] = {
)(
f: (Option[Ref.PackageName], SValue, ContractInfo) => Control[Question.Update]
): Control[Question.Update] = {

def importContract(coinst: V.ContractInstance) = {
val V.ContractInstance(_, srcTmplId, coinstArg) = coinst
def importContract(
srcContract: ContractInfo
) = {
val srcTmplId = srcContract.templateId
val coinst =
V.ContractInstance(srcContract.packageName, srcContract.templateId, srcContract.arg)
val (upgradingIsEnabled, dstTmplId) = optTargetTemplateId match {
case Some(tycon) if coinst.upgradable =>
(true, tycon)
Expand All @@ -2411,19 +2456,19 @@ private[lf] object SBuiltin {
dstTmplId.packageId,
language.Reference.Template(dstTmplId),
) { () =>
importValue(machine, dstTmplId, coinstArg) { templateArg =>
importValue(machine, dstTmplId, srcContract.arg) { templateArg =>
getContractInfo(
machine,
coid,
dstTmplId,
templateArg,
allowCatchingContractInfoErrors = false,
) { contract =>
ensureContractActive(machine, coid, contract.templateId) {
) { dstContract =>
ensureContractActive(machine, coid, dstContract.templateId) {

machine.checkContractVisibility(coid, contract)
machine.checkContractVisibility(coid, dstContract)
machine.enforceLimitAddInputContract()
machine.enforceLimitSignatoriesAndObservers(coid, contract)
machine.enforceLimitSignatoriesAndObservers(coid, dstContract)

// In Validation mode, we always call validateContractInfo
// In Submission mode, we only call validateContractInfo when src != dest
Expand All @@ -2435,12 +2480,11 @@ private[lf] object SBuiltin {
upgradingIsEnabled && (srcTmplId.packageId != dstTmplId.packageId)
}
if (needValidationCall) {

validateContractInfo(machine, coid, srcTmplId, contract) { () =>
f(contract.packageName, contract.any)
checkContractUpgradable(coid, srcContract, dstContract) { () =>
f(dstContract.packageName, dstContract.any, dstContract)
}
} else {
f(contract.packageName, contract.any)
f(dstContract.packageName, dstContract.any, dstContract)
}
}
}
Expand All @@ -2461,58 +2505,32 @@ private[lf] object SBuiltin {
if (optTargetTemplateId.forall(_ == templateId)) {
// If the local contract has the same package ID as the target template ID, then we don't need to
// import its value and validate its contract info again.
f(contract.packageName, SValue.SAnyContract(templateId, templateArg))
f(contract.packageName, SValue.SAnyContract(templateId, templateArg), contract)
} else {
importContract(V.ContractInstance(contract.packageName, templateId, contract.arg))
importContract(contract)
}
}
}
case None =>
machine.lookupGlobalContract(coid)(importContract)
machine.lookupGlobalContract(coid)(coinst =>
machine.ensurePackageIsLoaded(
coinst.template.packageId,
language.Reference.Template(coinst.template),
) { () =>
importValue(machine, coinst.template, coinst.arg) { templateArg =>
getContractInfo(
machine,
coid,
coinst.template,
templateArg,
allowCatchingContractInfoErrors = false,
)(importContract)
}
}
)
}
}

private def validateContractInfo(
machine: UpdateMachine,
coid: V.ContractId,
srcTemplateId: Ref.Identifier,
contract: ContractInfo,
)(
continue: () => Control[Question.Update]
): Control[Question.Update] = {

val keyOpt: Option[GlobalKeyWithMaintainers] = contract.keyOpt match {
case None => None
case Some(cachedKey) =>
Some(cachedKey.globalKeyWithMaintainers)
}
machine.needUpgradeVerification(
location = NameOf.qualifiedNameOfCurrentFunc,
coid = coid,
signatories = contract.signatories,
observers = contract.observers,
keyOpt = keyOpt,
continue = {
case None =>
continue()
case Some(msg) =>
Control.Error(
IE.Upgrade(
IE.Upgrade.ValidationFailed(
coid = coid,
srcTemplateId = srcTemplateId,
dstTemplateId = contract.templateId,
signatories = contract.signatories,
observers = contract.observers,
keyOpt = keyOpt,
msg = msg,
)
)
)
},
)
}

private def importValue[Q](machine: Machine[Q], templateId: TypeConName, coinstArg: V)(
f: SValue => Control[Q]
): Control[Q] = {
Expand Down
Loading