Skip to content

Commit

Permalink
feat: accept abi method reference as a parameter to methodSelector fu…
Browse files Browse the repository at this point in the history
…nction (#108)

* feat: accept abi method reference as a parameter to methodSelector function

* feat: use method names overridden by decorators in methodSelector

* refactor: Store arc4 method config in context and resolve name overrides from there

* refactor: Defer evaulation of function bodies until after contract metadata has been gathered

* test: add execution tests for arc4 method selector

---------

Co-authored-by: Tristan Menzel <[email protected]>
  • Loading branch information
boblat and tristanmenzel authored Feb 20, 2025
1 parent e04b710 commit 763db73
Show file tree
Hide file tree
Showing 315 changed files with 9,303 additions and 3,524 deletions.
2 changes: 1 addition & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"dev:examples": "tsx src/cli.ts build examples --output-awst --output-awst-json",
"dev:approvals": "tsx src/cli.ts build tests/approvals --dry-run",
"dev:expected-output": "tsx src/cli.ts build tests/expected-output --dry-run",
"dev:testing": "tsx src/cli.ts build tests/approvals/native-arrays.algo.ts tests/approvals/mutable-arrays.algo.ts --output-awst --output-awst-json --output-ssa-ir --log-level=info --out-dir out/unoptimized/[name] --optimization-level=1",
"dev:testing": "tsx src/cli.ts build tests/approvals/arc4-method-selector.algo.ts --output-awst --output-awst-json --output-ssa-ir --log-level=info --out-dir out/unoptimized/[name] --optimization-level=1",
"audit": "better-npm-audit audit",
"format": "prettier --write .",
"lint": "eslint \"src/**/*.ts\"",
Expand Down
2 changes: 1 addition & 1 deletion packages/algo-ts/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@algorandfoundation/algorand-typescript",
"version": "1.0.0-beta.18",
"version": "1.0.0-beta.19",
"description": "This package contains definitions for the types which comprise Algorand TypeScript which can be compiled to run on the Algorand Virtual Machine using the Puya compiler.",
"private": false,
"main": "index.js",
Expand Down
7 changes: 6 additions & 1 deletion packages/algo-ts/src/arc4/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ export function baremethod<TContract extends Contract>(config?: BareMethodConfig
* @param methodSignature An ARC4 method signature. Eg. `hello(string)string`. Must be a compile time constant.
* @returns The ARC4 method selector. Eg. `02BECE11`
*/
export function methodSelector(methodSignature: string): bytes {
export function methodSelector<
TMethod extends (this: TContract, ...args: TArgs) => TReturn,
TContract extends Contract,
TArgs extends DeliberateAny[],
TReturn,
>(methodSignature: string | TMethod): bytes {
throw new NoImplementation()
}
1 change: 1 addition & 0 deletions src/awst/nodes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,7 @@ export class ARC4ABIMethodConfig {
export type LValue = VarExpression | FieldExpression | IndexExpression | TupleExpression | AppStateExpression | AppAccountStateExpression
export type Constant = IntegerConstant | BoolConstant | BytesConstant | StringConstant
export type AWST = Contract | LogicSignature | Subroutine
export type ARC4MethodConfig = ARC4BareMethodConfig | ARC4ABIMethodConfig
export const concreteNodes = {
expressionStatement: ExpressionStatement,
block: Block,
Expand Down
9 changes: 6 additions & 3 deletions src/awst_build/ast-visitors/base-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
} from '../../visitor/syntax-names'
import type { Visitor } from '../../visitor/visitor'
import { accept } from '../../visitor/visitor'
import type { AwstBuildContext } from '../context/awst-build-context'
import { AwstBuildContext } from '../context/awst-build-context'
import { InstanceBuilder, NodeBuilder } from '../eb'
import { BooleanExpressionBuilder } from '../eb/boolean-expression-builder'
import { ArrayLiteralExpressionBuilder } from '../eb/literal/array-literal-expression-builder'
Expand Down Expand Up @@ -57,9 +57,12 @@ import { TextVisitor } from './text-visitor'
export abstract class BaseVisitor implements Visitor<Expressions, NodeBuilder> {
private baseAccept = <TNode extends ts.Node>(node: TNode) => accept<BaseVisitor, TNode>(this, node)
readonly textVisitor: TextVisitor
get context() {
return AwstBuildContext.current
}

protected constructor(public context: AwstBuildContext) {
this.textVisitor = new TextVisitor(context)
protected constructor() {
this.textVisitor = new TextVisitor()
}

logNotSupported(node: ts.Node | undefined, message: string) {
Expand Down
42 changes: 19 additions & 23 deletions src/awst_build/ast-visitors/constructor-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,50 @@ import ts from 'typescript'
import type { ContractReference } from '../../awst/models'
import { nodeFactory } from '../../awst/node-factory'
import * as awst from '../../awst/nodes'
import { AwstBuildFailureError } from '../../errors'
import { logger } from '../../logger'
import { codeInvariant, invariant } from '../../util'
import type { AwstBuildContext } from '../context/awst-build-context'
import type { ContractClassPType } from '../ptypes'
import { voidPType } from '../ptypes'
import { ContractMethodBaseVisitor } from './contract-method-visitor'
import { visitInChildContext } from './util'

export interface ConstructorInfo {
propertyInitializerStatements: awst.Statement[]
cref: ContractReference
}

export class ConstructorVisitor extends ContractMethodBaseVisitor {
private readonly _result: awst.ContractMethod
private _foundSuperCall = false
private readonly _propertyInitializerStatements: awst.Statement[]
constructor(ctx: AwstBuildContext, node: ts.ConstructorDeclaration, contractType: ContractClassPType, contractInfo: ConstructorInfo) {
super(ctx, node, contractType)
this._propertyInitializerStatements = contractInfo.propertyInitializerStatements
const sourceLocation = this.sourceLocation(node)

const { args, body, documentation } = this.buildFunctionAwst(node)
constructor(
node: ts.ConstructorDeclaration,
contractType: ContractClassPType,
private readonly contractInfo: ConstructorInfo,
) {
super(node, contractType)
}

this._result = new awst.ContractMethod({
get result() {
const sourceLocation = this.sourceLocation(this.node)
const { args, body, documentation } = this.buildFunctionAwst()
return new awst.ContractMethod({
arc4MethodConfig: null,
memberName: this._functionType.name,
sourceLocation,
args,
returnType: voidPType.wtype,
body,
cref: contractInfo.cref,
cref: this.contractInfo.cref,
documentation,
inline: null,
})
}

get result() {
return this._result
}

public static buildConstructor(
parentCtx: AwstBuildContext,
node: ts.ConstructorDeclaration,
contractType: ContractClassPType,
constructorMethodInfo: ConstructorInfo,
) {
const result = new ConstructorVisitor(parentCtx.createChildContext(), node, contractType, constructorMethodInfo).result
invariant(result instanceof awst.ContractMethod, "result must be ContractMethod'")
return result
return visitInChildContext(this, node, contractType, constructorMethodInfo)
}

visitBlock(node: ts.Block): awst.Block {
Expand All @@ -70,13 +65,14 @@ export class ConstructorVisitor extends ContractMethodBaseVisitor {
sourceLocation: this.sourceLocation(s),
},
...(Array.isArray(statement) ? statement : [statement]),
...this._propertyInitializerStatements,
...this.contractInfo.propertyInitializerStatements,
)
}
return statement
} catch (e) {
if (e instanceof AwstBuildFailureError) return []
throw e
invariant(e instanceof Error, 'Only errors should be thrown')
logger.error(e)
return []
}
}),
)
Expand Down
66 changes: 39 additions & 27 deletions src/awst_build/ast-visitors/contract-method-visitor.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import type ts from 'typescript'
import { ContractReference, OnCompletionAction } from '../../awst/models'
import { nodeFactory } from '../../awst/node-factory'
import type { ABIMethodArgConstantDefault, ABIMethodArgMemberDefault } from '../../awst/nodes'
import type { ABIMethodArgConstantDefault, ABIMethodArgMemberDefault, ARC4MethodConfig } from '../../awst/nodes'
import * as awst from '../../awst/nodes'
import { ARC4ABIMethodConfig, ARC4BareMethodConfig, ARC4CreateOption } from '../../awst/nodes'
import type { SourceLocation } from '../../awst/source-location'
import { Constants } from '../../constants'
import { CodeError } from '../../errors'
import { logger } from '../../logger'
import { codeInvariant, invariant, isIn } from '../../util'
import type { AwstBuildContext } from '../context/awst-build-context'
import type { NodeBuilder } from '../eb'
import { ContractSuperBuilder, ContractThisBuilder } from '../eb/contract-builder'
import { requireExpressionOfType } from '../eb/util'
Expand All @@ -18,11 +17,12 @@ import type { ContractClassPType, FunctionPType } from '../ptypes'
import { GlobalStateType, LocalStateType } from '../ptypes'
import { DecoratorVisitor } from './decorator-visitor'
import { FunctionVisitor } from './function-visitor'
import { visitInChildContext } from './util'

export class ContractMethodBaseVisitor extends FunctionVisitor {
protected readonly _contractType: ContractClassPType
constructor(ctx: AwstBuildContext, node: ts.MethodDeclaration | ts.ConstructorDeclaration, contractType: ContractClassPType) {
super(ctx, node)
constructor(node: ts.MethodDeclaration | ts.ConstructorDeclaration, contractType: ContractClassPType) {
super(node)
this._contractType = contractType
}
visitSuperKeyword(node: ts.SuperExpression): NodeBuilder {
Expand All @@ -31,25 +31,28 @@ export class ContractMethodBaseVisitor extends FunctionVisitor {
// Only the polytype clustered class should have more than one base type, and it shouldn't have
// any user code with super calls
invariant(this._contractType.baseTypes.length === 1, 'Super keyword only valid if contract has a single base type')
return new ContractSuperBuilder(this._contractType.baseTypes[0], sourceLocation, this.context)
return new ContractSuperBuilder(this._contractType.baseTypes[0], sourceLocation)
}

visitThisKeyword(node: ts.ThisExpression): NodeBuilder {
const sourceLocation = this.sourceLocation(node)
return new ContractThisBuilder(this._contractType, sourceLocation, this.context)
return new ContractThisBuilder(this._contractType, sourceLocation)
}
}

export class ContractMethodVisitor extends ContractMethodBaseVisitor {
private readonly _result: awst.ContractMethod
private readonly metaData: {
cref: ContractReference
arc4MethodConfig: ARC4MethodConfig | null
sourceLocation: SourceLocation
}

constructor(ctx: AwstBuildContext, node: ts.MethodDeclaration, contractType: ContractClassPType) {
super(ctx, node, contractType)
constructor(node: ts.MethodDeclaration, contractType: ContractClassPType) {
super(node, contractType)
const sourceLocation = this.sourceLocation(node)
const { args, body, documentation } = this.buildFunctionAwst(node)
const cref = ContractReference.fromPType(this._contractType)

const decorator = DecoratorVisitor.buildContractMethodData(ctx, node)
const decorator = DecoratorVisitor.buildContractMethodData(node)
const cref = ContractReference.fromPType(this._contractType)

const modifiers = this.parseMemberModifiers(node)

Expand All @@ -60,29 +63,38 @@ export class ContractMethodVisitor extends ContractMethodBaseVisitor {
methodLocation: sourceLocation,
})

this._result = new awst.ContractMethod({
arc4MethodConfig: arc4MethodConfig ?? null,
memberName: this._functionType.name,
if (arc4MethodConfig)
this.context.addArc4Config({
contractReference: cref,
sourceLocation,
arc4MethodConfig,
memberName: this._functionType.name,
})
this.metaData = {
arc4MethodConfig,
cref,
sourceLocation,
}
}

get result() {
const { args, body, documentation } = this.buildFunctionAwst()

return new awst.ContractMethod({
arc4MethodConfig: this.metaData.arc4MethodConfig,
memberName: this._functionType.name,
sourceLocation: this.metaData.sourceLocation,
args,
returnType: this._functionType.returnType.wtypeOrThrow,
body,
cref,
cref: this.metaData.cref,
documentation,
inline: null,
})
}

get result() {
return this._result
}

public static buildContractMethod(
parentCtx: AwstBuildContext,
node: ts.MethodDeclaration,
contractType: ContractClassPType,
): awst.ContractMethod {
return new ContractMethodVisitor(parentCtx.createChildContext(), node, contractType).result
public static buildContractMethod(node: ts.MethodDeclaration, contractType: ContractClassPType): () => awst.ContractMethod {
return visitInChildContext(this, node, contractType)
}

private buildArc4Config({
Expand All @@ -95,7 +107,7 @@ export class ContractMethodVisitor extends ContractMethodBaseVisitor {
decorator: DecoratorData | undefined
modifiers: { isPublic: boolean; isStatic: boolean }
methodLocation: SourceLocation
}): awst.ContractMethod['arc4MethodConfig'] | null {
}): awst.ARC4MethodConfig | null {
const isProgramMethod = isIn(functionType.name, [Constants.approvalProgramMethodName, Constants.clearStateProgramMethodName])

if (decorator && isIn(decorator.type, [Constants.arc4BareDecoratorName, Constants.arc4AbiDecoratorName])) {
Expand Down
Loading

0 comments on commit 763db73

Please sign in to comment.