diff --git a/.idea/kotlinScripting.xml b/.idea/kotlinScripting.xml deleted file mode 100644 index bc444dead9..0000000000 --- a/.idea/kotlinScripting.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - \ No newline at end of file diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/CircuitCodeGenerationBackend.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/CircuitCodeGenerationBackend.kt index 2d675dc6d3..b503398f6a 100644 --- a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/CircuitCodeGenerationBackend.kt +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/CircuitCodeGenerationBackend.kt @@ -2,6 +2,7 @@ package io.github.aplcornell.viaduct.backends import io.github.aplcornell.viaduct.backends.aby.ABYBackend import io.github.aplcornell.viaduct.backends.cleartext.CleartextBackend +import io.github.aplcornell.viaduct.backends.commitment.CommitmentBackend /** Combines all back ends that support circuit code generation. */ -object CircuitCodeGenerationBackend : Backend by listOf(CleartextBackend, ABYBackend).unions() +object CircuitCodeGenerationBackend : Backend by listOf(CleartextBackend, ABYBackend, CommitmentBackend).unions() diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/cleartext/CleartextCircuitCodeGenerator.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/cleartext/CleartextCircuitCodeGenerator.kt index 02ae0f2880..26eb66167f 100644 --- a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/cleartext/CleartextCircuitCodeGenerator.kt +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/cleartext/CleartextCircuitCodeGenerator.kt @@ -2,12 +2,19 @@ package io.github.aplcornell.viaduct.backends.cleartext import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.MemberName +import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy +import com.squareup.kotlinpoet.asClassName +import com.squareup.kotlinpoet.asTypeName import io.github.aplcornell.viaduct.circuitcodegeneration.AbstractCodeGenerator import io.github.aplcornell.viaduct.circuitcodegeneration.Argument import io.github.aplcornell.viaduct.circuitcodegeneration.CodeGeneratorContext import io.github.aplcornell.viaduct.circuitcodegeneration.UnsupportedCommunicationException +import io.github.aplcornell.viaduct.circuitcodegeneration.kotlinType import io.github.aplcornell.viaduct.circuitcodegeneration.receiveExpected import io.github.aplcornell.viaduct.circuitcodegeneration.receiveReplicated +import io.github.aplcornell.viaduct.circuitcodegeneration.typeTranslator +import io.github.aplcornell.viaduct.runtime.commitment.Commitment +import io.github.aplcornell.viaduct.runtime.commitment.Committed import io.github.aplcornell.viaduct.syntax.BinaryOperator import io.github.aplcornell.viaduct.syntax.Host import io.github.aplcornell.viaduct.syntax.Protocol @@ -15,8 +22,10 @@ import io.github.aplcornell.viaduct.syntax.UnaryOperator import io.github.aplcornell.viaduct.syntax.circuit.OperatorNode import io.github.aplcornell.viaduct.syntax.operators.Maximum import io.github.aplcornell.viaduct.syntax.operators.Minimum +import io.github.aplcornell.viaduct.backends.commitment.Commitment as CommitmentProtocol class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCodeGenerator(context) { + override fun operatorApplication(protocol: Protocol, op: OperatorNode, arguments: List): CodeBlock = when (op.operator) { Minimum -> @@ -112,6 +121,128 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod } } + private fun createCommitment( + source: Protocol, + target: Protocol, + argument: Argument, + builder: CodeBlock.Builder, + ): CodeBlock { + require(context.host in source.hosts + target.hosts) + if (source !is Local) { + throw UnsupportedCommunicationException(source, target, argument.sourceLocation) + } + require(source.hosts.size == 1 && source.host in source.hosts) + require(target is CommitmentProtocol) + if (target.cleartextHost != source.host || target.cleartextHost in target.hashHosts) { + throw UnsupportedCommunicationException(source, target, argument.sourceLocation) + } + + val argType = kotlinType(argument.type.shape, typeTranslator(argument.type.elementType.value)) + val sendingHost = target.cleartextHost + val receivingHosts = target.hashHosts + return when (context.host) { + sendingHost -> { + val tempName1 = context.newTemporary("CommitTemp") + val tempName2 = context.newTemporary("CommitTemp") + builder.addStatement( + "val %N = %T(%L)", + tempName1, + (Committed::class).asTypeName().parameterizedBy(argType), + argument.value, + ) + builder.addStatement( + "val %N = %N.%M()", + tempName2, + tempName1, + MemberName(Committed.Companion::class.asClassName(), "commitment"), + ) + receivingHosts.forEach { + builder.addStatement("%L", context.send(CodeBlock.of("%N", tempName2), it)) + } + CodeBlock.of("%N", tempName1) + } + + in receivingHosts -> { + val tempName3 = context.newTemporary("CommitTemp") + builder.addStatement( + "val %N = %L", + tempName3, + context.receive((Commitment::class).asTypeName().parameterizedBy(argType), source.host), + ) + CodeBlock.of("%N", tempName3) + } + + else -> throw IllegalStateException() + } + } + + private fun openCommitment( + source: Protocol, + target: Protocol, + argument: Argument, + builder: CodeBlock.Builder, + ): CodeBlock { + require(source is CommitmentProtocol) + if (target !is Cleartext) { + throw UnsupportedCommunicationException(source, target, argument.sourceLocation) + } + require(context.host in source.hosts + target.hosts) + if (source.hashHosts != target.hosts || source.cleartextHost in source.hashHosts) { + throw UnsupportedCommunicationException(source, target, argument.sourceLocation) + } + + val argType = kotlinType(argument.type.shape, typeTranslator(argument.type.elementType.value)) + val sendingHost = source.cleartextHost + val receivingHosts = target.hosts + return when (context.host) { + sendingHost -> { + receivingHosts.forEach { + builder.addStatement("%L", context.send(argument.value, it)) + } + CodeBlock.of("%L.value", argument.value) + } + in receivingHosts -> { + val tempName1 = context.newTemporary("CommitTemp") + builder.addStatement( + "val %N = %L", + tempName1, + context.receive((Committed::class).asTypeName().parameterizedBy(argType), source.cleartextHost), + ) + val tempName2 = context.newTemporary("CommitTemp") + builder.addStatement( + "val %N = %L", + tempName2, + argument.value, + ) + val tempName3 = context.newTemporary("CommitTemp") + builder.addStatement( + "val %N = %N.%N(%N)", + tempName3, + tempName2, + "open", + tempName1, + ) + + val peers = receivingHosts.filter { it != context.host } + if (peers.isNotEmpty()) { + for (host in peers) builder.addStatement("%L", context.send(CodeBlock.of(tempName3), host)) + builder.addStatement( + "%L", + receiveExpected( + CodeBlock.of(tempName3), + context.host, + argType, + peers, + context, + ), + ) + } + CodeBlock.of("%N", tempName3) + } + else -> throw IllegalStateException() + } + } + override fun import( protocol: Protocol, arguments: List, @@ -127,6 +258,13 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod CodeBlock.of("") } } + is CommitmentProtocol -> { + if (context.host in protocol.hosts + arg.protocol.hosts) { + openCommitment(arg.protocol, protocol, arg, builder) + } else { + CodeBlock.of("") + } + } else -> throw UnsupportedCommunicationException(arg.protocol, protocol, arg.sourceLocation) } @@ -149,6 +287,13 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod CodeBlock.of("") } } + is CommitmentProtocol -> { + if (context.host in protocol.hosts + arg.protocol.hosts) { + createCommitment(protocol, arg.protocol, arg, builder) + } else { + CodeBlock.of("") + } + } else -> throw UnsupportedCommunicationException(protocol, arg.protocol, arg.sourceLocation) } diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/commitment/CommitmentBackend.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/commitment/CommitmentBackend.kt index 02f17c451d..fedd6b5b2e 100644 --- a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/commitment/CommitmentBackend.kt +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/commitment/CommitmentBackend.kt @@ -26,5 +26,5 @@ object CommitmentBackend : Backend { override fun codeGenerator(context: CodeGeneratorContext): CodeGenerator = CommitmentDispatchCodeGenerator(context) - override fun circuitCodeGenerator(context: CircuitCodeGeneratorContext): CircuitCodeGenerator = TODO() + override fun circuitCodeGenerator(context: CircuitCodeGeneratorContext): CircuitCodeGenerator = CommitmentCircuitCodeGenerator(context) } diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/commitment/CommitmentCircuitCodeGenerator.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/commitment/CommitmentCircuitCodeGenerator.kt new file mode 100644 index 0000000000..7cc61e184b --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/backends/commitment/CommitmentCircuitCodeGenerator.kt @@ -0,0 +1,44 @@ +package io.github.aplcornell.viaduct.backends.commitment + +import com.squareup.kotlinpoet.CodeBlock +import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy +import com.squareup.kotlinpoet.TypeName +import com.squareup.kotlinpoet.asTypeName +import io.github.aplcornell.viaduct.circuitcodegeneration.AbstractCodeGenerator +import io.github.aplcornell.viaduct.circuitcodegeneration.Argument +import io.github.aplcornell.viaduct.circuitcodegeneration.CodeGeneratorContext +import io.github.aplcornell.viaduct.circuitcodegeneration.UnsupportedCommunicationException +import io.github.aplcornell.viaduct.circuitcodegeneration.typeTranslator +import io.github.aplcornell.viaduct.runtime.commitment.Committed +import io.github.aplcornell.viaduct.syntax.Protocol +import io.github.aplcornell.viaduct.syntax.types.ValueType +import io.github.aplcornell.viaduct.runtime.commitment.Commitment as CommitmentValue + +/** + * Backend code generator for the commitment protocol for the circuit IR. + * + * Throws an UnsupportedCommunicationException when used in an input program as a computation protocol. + * This is because the commitment protocol is only a storage format and not a computation protocol. + */ +class CommitmentCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCodeGenerator(context) { + override fun paramType(protocol: Protocol, sourceType: ValueType): TypeName { + require(protocol is Commitment) + return when (context.host) { + protocol.cleartextHost -> (Committed::class).asTypeName().parameterizedBy(typeTranslator(sourceType)) + in protocol.hashHosts -> (CommitmentValue::class).asTypeName().parameterizedBy(typeTranslator(sourceType)) + else -> throw IllegalStateException() + } + } + + override fun storageType(protocol: Protocol, sourceType: ValueType): TypeName { + return super.storageType(protocol, sourceType) + } + + override fun import(protocol: Protocol, arguments: List): Pair> { + throw UnsupportedCommunicationException(arguments.first().protocol, protocol, arguments.first().sourceLocation) + } + + override fun export(protocol: Protocol, arguments: List): Pair> { + throw UnsupportedCommunicationException(arguments.first().protocol, protocol, arguments.first().sourceLocation) + } +} diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/ArrayTypeNode.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/ArrayTypeNode.kt new file mode 100644 index 0000000000..d66e48a18a --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/ArrayTypeNode.kt @@ -0,0 +1,19 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.prettyprinting.Document +import io.github.aplcornell.viaduct.prettyprinting.bracketed +import io.github.aplcornell.viaduct.prettyprinting.plus +import io.github.aplcornell.viaduct.syntax.Arguments +import io.github.aplcornell.viaduct.syntax.SourceLocation +import io.github.aplcornell.viaduct.syntax.ValueTypeNode + +class ArrayTypeNode( + val elementType: ValueTypeNode, + val shape: Arguments, + override val sourceLocation: SourceLocation, +) : Node() { + override val children: Iterable + get() = shape + + override fun toDocument(): Document = elementType + shape.bracketed() +} diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Expressions.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Expressions.kt new file mode 100644 index 0000000000..ea1ab80285 --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Expressions.kt @@ -0,0 +1,89 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.prettyprinting.Document +import io.github.aplcornell.viaduct.prettyprinting.bracketed +import io.github.aplcornell.viaduct.prettyprinting.plus +import io.github.aplcornell.viaduct.prettyprinting.times +import io.github.aplcornell.viaduct.prettyprinting.tupled +import io.github.aplcornell.viaduct.syntax.Arguments +import io.github.aplcornell.viaduct.syntax.Operator +import io.github.aplcornell.viaduct.syntax.SourceLocation +import io.github.aplcornell.viaduct.syntax.surface.keyword +import io.github.aplcornell.viaduct.syntax.values.Value + +/** A computation that produces a result. */ +sealed class ExpressionNode : Node() +sealed class IndexExpressionNode : ExpressionNode() + +/** A literal constant. */ +class LiteralNode( + val value: Value, + override val sourceLocation: SourceLocation, +) : IndexExpressionNode() { + override val children: Iterable + get() = listOf() + + override fun toDocument(): Document = value.toDocument() +} + +class ReferenceNode( + val name: VariableNode, + override val sourceLocation: SourceLocation, +) : IndexExpressionNode() { + override val children: Iterable + get() = listOf() + + override fun toDocument(): Document = name.toDocument() +} + +class LookupNode( + val variable: VariableNode, + val indices: Arguments, + override val sourceLocation: SourceLocation, +) : ExpressionNode() { + override val children: Iterable + get() = indices + + override fun toDocument(): Document = variable + indices.bracketed() +} + +/** An n-ary operator applied to n arguments. */ +class OperatorApplicationNode( + val operator: OperatorNode, + val arguments: Arguments, + override val sourceLocation: SourceLocation, +) : ExpressionNode() { + override val children: Iterable + get() = listOf(operator) + arguments + + override fun toDocument(): Document = Document("(") + operator.operator.toDocument(arguments) + ")" +} + +class OperatorNode( + val operator: Operator, + override val sourceLocation: SourceLocation, +) : Node() { + override val children: Iterable + get() = listOf() + + override fun toDocument(): Document = Document("::$operator") +} + +/** + * @param defaultValue to be used when the list is empty + * @param operator must be associative + */ +class ReduceNode( + val operator: OperatorNode, + val defaultValue: ExpressionNode, + val indices: IndexParameterNode, + val body: ExpressionNode, + override val sourceLocation: SourceLocation, +) : ExpressionNode() { + override val children: Iterable + get() = listOf(operator, defaultValue, indices, body) + + override fun toDocument(): Document { + return keyword("reduce") + listOf(operator, defaultValue).tupled() * "{" * indices * "->" * body * " }" + } +} diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/IndexParameterNode.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/IndexParameterNode.kt new file mode 100644 index 0000000000..299bf17dae --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/IndexParameterNode.kt @@ -0,0 +1,16 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.prettyprinting.Document +import io.github.aplcornell.viaduct.prettyprinting.times +import io.github.aplcornell.viaduct.syntax.SourceLocation + +class IndexParameterNode( + override val name: VariableNode, + val bound: IndexExpressionNode, + override val sourceLocation: SourceLocation, +) : Node(), VariableDeclarationNode { + override val children: Iterable + get() = listOf(bound) + + override fun toDocument(): Document = name.toDocument() * "<" * bound +} diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Node.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Node.kt new file mode 100644 index 0000000000..25e3d468fa --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Node.kt @@ -0,0 +1,7 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.attributes.TreeNode +import io.github.aplcornell.viaduct.prettyprinting.PrettyPrintable +import io.github.aplcornell.viaduct.syntax.HasSourceLocation + +sealed class Node : TreeNode, HasSourceLocation, PrettyPrintable diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Parsing.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Parsing.kt new file mode 100644 index 0000000000..203fe3da95 --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Parsing.kt @@ -0,0 +1,27 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.parsing.ProtocolParser +import io.github.aplcornell.viaduct.parsing.SourceFile +import io.github.aplcornell.viaduct.parsing.defaultProtocolParsers +import io.github.aplcornell.viaduct.syntax.Protocol +import io.github.aplcornell.viaduct.syntax.ProtocolName +import java_cup.runtime.ComplexSymbolFactory + +/** Parses [this] string and returns the AST. */ +fun String.parse( + path: String = "", + protocolParsers: Map> = defaultProtocolParsers, +): ProgramNode { + return SourceFile.from(path, this).parse(protocolParsers) +} + +/** Parses [this] source file to IR and returns the IR. */ +fun SourceFile.parse( + protocolParsers: Map> = defaultProtocolParsers, +): ProgramNode { + val symbolFactory = ComplexSymbolFactory() + val scanner = SourceLexer(this, symbolFactory) + val parser = SourceParser(scanner, symbolFactory) + parser.protocolParsers = protocolParsers + return parser.parseProgram() +} diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/ProgramNode.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/ProgramNode.kt new file mode 100644 index 0000000000..ffdec5516b --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/ProgramNode.kt @@ -0,0 +1,55 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.attributes.Attribute +import io.github.aplcornell.viaduct.attributes.Tree +import io.github.aplcornell.viaduct.attributes.attribute +import io.github.aplcornell.viaduct.prettyprinting.Document +import io.github.aplcornell.viaduct.prettyprinting.concatenated +import io.github.aplcornell.viaduct.prettyprinting.plus +import io.github.aplcornell.viaduct.syntax.Host +import io.github.aplcornell.viaduct.syntax.SourceLocation +import kotlinx.collections.immutable.PersistentList +import kotlinx.collections.immutable.toPersistentList + +/** + * The source representation of a program. + */ +class ProgramNode( + val declarations: PersistentList, + override val sourceLocation: SourceLocation, +) : Node(), List by declarations { + constructor(declarations: List, sourceLocation: SourceLocation) : + this(declarations.toPersistentList(), sourceLocation) + + override val children: Iterable + get() = declarations + + /** A lazily constructed [Tree] instance for the program. */ + val tree: Tree by lazy { Tree(this) } + + private val functionCache: Attribute<(ProgramNode) -> Any?, Any?> = attribute { + this.invoke(this@ProgramNode) + } + + /** + * Applies [function] to this program and returns the results. + * The result is cached, so future calls with the same function do not evaluate [function]. + */ + @Suppress("UNCHECKED_CAST") + fun cached(function: (ProgramNode) -> T): T = + functionCache(function) as T + + val hostDeclarations: Iterable = + declarations.filterIsInstance() + + val hosts: Set = hostDeclarations.map { it.name.value }.toSet() + +// val circuits: Iterable = +// declarations.filterIsInstance() + + val functions: Iterable = + declarations.filterIsInstance() + + override fun toDocument(): Document = + declarations.concatenated(Document.forcedLineBreak + Document.forcedLineBreak) +} diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Statements.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Statements.kt new file mode 100644 index 0000000000..78c78720a1 --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Statements.kt @@ -0,0 +1,146 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.prettyprinting.Document +import io.github.aplcornell.viaduct.prettyprinting.braced +import io.github.aplcornell.viaduct.prettyprinting.bracketed +import io.github.aplcornell.viaduct.prettyprinting.concatenated +import io.github.aplcornell.viaduct.prettyprinting.joined +import io.github.aplcornell.viaduct.prettyprinting.nested +import io.github.aplcornell.viaduct.prettyprinting.plus +import io.github.aplcornell.viaduct.prettyprinting.times +import io.github.aplcornell.viaduct.prettyprinting.tupled +import io.github.aplcornell.viaduct.syntax.Arguments +import io.github.aplcornell.viaduct.syntax.FunctionNameNode +import io.github.aplcornell.viaduct.syntax.HostNode +import io.github.aplcornell.viaduct.syntax.ProtocolNode +import io.github.aplcornell.viaduct.syntax.SourceLocation +import io.github.aplcornell.viaduct.syntax.surface.keyword +import kotlinx.collections.immutable.PersistentList +import kotlinx.collections.immutable.toPersistentList + +/** A computation with side effects. */ +sealed class StatementNode : Node() + +//sealed class CircuitStatementNode : StatementNode() + +sealed class CommandNode : Node() + +/** A sequence of statements. */ +class BlockNode +private constructor( + val statements: PersistentList, + val returnStatement: ReturnNode, + override val sourceLocation: SourceLocation, +) : Node(), List by statements { + constructor(statements: List, returnStatement: ReturnNode, sourceLocation: SourceLocation) : + this(statements.toPersistentList(), returnStatement, sourceLocation) + + override val children: Iterable + get() = statements + listOf(returnStatement) + + override fun toDocument(): Document { + val statements: MutableList = (statements.map { it.toDocument() } as MutableList) + statements.add(returnStatement.toDocument()) + val body: Document = statements.concatenated(separator = Document.forcedLineBreak) + return listOf((Document.forcedLineBreak + body).nested() + Document.forcedLineBreak).braced() + } +} + +class ReturnNode( + val values: Arguments, + override val sourceLocation: SourceLocation, +) : StatementNode() { + override val children: Iterable + get() = values + + override fun toDocument(): Document = keyword("return") * values.joined() +} + +class LetNode( + val bindings: Arguments, + val command: CommandNode, + override val sourceLocation: SourceLocation, +) : StatementNode() { + override val children: Iterable + get() = bindings + listOf(command) + + override fun toDocument(): Document = + keyword("val") * bindings.joined() * "=" * command +} + +class VariableBindingNode( + override val name: VariableNode, + val protocol: ProtocolNode, + val type: ArrayTypeNode, + override val sourceLocation: SourceLocation, +) : Node(), VariableDeclarationNode { + override val children: Iterable + get() = listOf(type) + + override fun toDocument(): Document = + name + Document("@") + protocol.value.toDocument() + Document(":") * type.toDocument() +} + +class FunctionCallNode( + val name: FunctionNameNode, + val bounds: Arguments, + val inputs: Arguments, + override val sourceLocation: SourceLocation, +) : CommandNode() { + override val children: Iterable + get() = bounds + inputs + + override fun toDocument(): Document { + return name + bounds.joined( + prefix = Document("<"), + postfix = Document(">"), + ) + inputs.tupled() + } +} + +/** + * An external input. + * @param type Type of the value to receive. + */ +class InputNode( + val type: ArrayTypeNode, + val host: HostNode, + override val sourceLocation: SourceLocation, +) : CommandNode() { + override val children: Iterable + get() = listOf(type) + + override fun toDocument(): Document = host + "." + keyword("input") + "<" + type + ">()" +} + +/** An external output. */ +class OutputNode( + val type: ArrayTypeNode, + val message: ReferenceNode, + val host: HostNode, + override val sourceLocation: SourceLocation, +) : CommandNode() { + override val children: Iterable + get() = listOf(message) + + override fun toDocument(): Document = host + "." + keyword("output") + listOf(message).tupled() +} + +/** A command to create a new array on a computation protocol + * e.g. @Local(host = Alice) for i < 10, j < 20: a[i, j] + b[i, j] + * @param protocol The computation protocol on which the array is created + * @param indices The index variables and expressions to be used to create the array + * @param value The value to be assigned to each element of the array + * */ +class ArrayCreationNode( + val protocol: ProtocolNode, + val indices: Arguments, + val value: ExpressionNode, + override val sourceLocation: SourceLocation, +) : CommandNode() { + override val children: Iterable + get() = indices + listOf(value) + + override fun toDocument(): Document = + (protocol.value.toDocument() * keyword("for") * indices.joined() + ":") * value +} diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/TopLevelDeclarations.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/TopLevelDeclarations.kt new file mode 100644 index 0000000000..67c417819f --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/TopLevelDeclarations.kt @@ -0,0 +1,72 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.prettyprinting.Document +import io.github.aplcornell.viaduct.prettyprinting.joined +import io.github.aplcornell.viaduct.prettyprinting.plus +import io.github.aplcornell.viaduct.prettyprinting.times +import io.github.aplcornell.viaduct.prettyprinting.tupled +import io.github.aplcornell.viaduct.syntax.Arguments +import io.github.aplcornell.viaduct.syntax.FunctionNameNode +import io.github.aplcornell.viaduct.syntax.HostNode +import io.github.aplcornell.viaduct.syntax.ProtocolNode +import io.github.aplcornell.viaduct.syntax.SourceLocation +import io.github.aplcornell.viaduct.syntax.surface.keyword + +/** A declaration at the top level of a file. */ +sealed class TopLevelDeclarationNode : Node() + +/** + * Declaration of a participant and their authority. + * + * @param name Host name. + */ +class HostDeclarationNode( + val name: HostNode, + override val sourceLocation: SourceLocation, +) : TopLevelDeclarationNode() { + override val children: Iterable + get() = listOf() + + override fun toDocument(): Document = keyword("host") * name +} + +class SizeParameterNode( + override val name: VariableNode, + override val sourceLocation: SourceLocation, +) : Node(), VariableDeclarationNode { + override val children: Iterable + get() = listOf() + + override fun toDocument(): Document = name.toDocument() +} + +/** + * A parameter to a function declaration. + */ +class ParameterNode( + override val name: VariableNode, + val protocol: ProtocolNode, + val type: ArrayTypeNode, + override val sourceLocation: SourceLocation, +) : Node(), VariableDeclarationNode { + override val children: Iterable + get() = listOf(type) + + override fun toDocument(): Document = + name + Document("@") + protocol.value.toDocument() + Document(":") * type.toDocument() +} + +class FunctionDeclarationNode( + val name: FunctionNameNode, + val sizes: Arguments, + val inputs: Arguments, + val outputs: Arguments, + val body: BlockNode, + override val sourceLocation: SourceLocation, +) : TopLevelDeclarationNode() { + override val children: Iterable + get() = sizes + inputs + outputs + listOf(body) + + override fun toDocument(): Document = + (keyword("fun") * "<" + sizes.joined() + ">") * name + inputs.tupled() * "->" * outputs.joined() * body +} diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Variable.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Variable.kt new file mode 100644 index 0000000000..f7f24082b0 --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/Variable.kt @@ -0,0 +1,17 @@ +package io.github.aplcornell.viaduct.syntax.source + +import io.github.aplcornell.viaduct.prettyprinting.Document +import io.github.aplcornell.viaduct.prettyprinting.styled +import io.github.aplcornell.viaduct.syntax.Located +import io.github.aplcornell.viaduct.syntax.Name +import io.github.aplcornell.viaduct.syntax.VariableStyle + +/** A variable is a name that stands for a value or an object instance. */ +data class Variable(override val name: String) : Name { + override val nameCategory: String + get() = "variable" + + override fun toDocument(): Document = Document(name).styled(VariableStyle) +} + +typealias VariableNode = Located diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/VariableDeclarationNode.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/VariableDeclarationNode.kt new file mode 100644 index 0000000000..e77023227e --- /dev/null +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/syntax/source/VariableDeclarationNode.kt @@ -0,0 +1,7 @@ +package io.github.aplcornell.viaduct.syntax.source + +/** A node that declares a [Variable]. */ +sealed interface VariableDeclarationNode { + /** The variable declared by this node. */ + val name: VariableNode +} diff --git a/compiler/tests/should-pass/circuit/cleartext/Commitment1.circuit b/compiler/tests/should-pass/circuit/cleartext/Commitment1.circuit new file mode 100644 index 0000000000..a1f83781b9 --- /dev/null +++ b/compiler/tests/should-pass/circuit/cleartext/Commitment1.circuit @@ -0,0 +1,20 @@ +host alice +host bob +host chuck + +circuit fun <> move@Local(host = alice)(a: int[]) -> b: int[] { + return a +} + +circuit fun <> move2@Replication(hosts = {bob, chuck})(a: int[]) -> b: int[] { + return a +} + +fun <> main() -> { + val a@Local(host = alice) = alice.input() + val c@Commitment(sender = alice, receivers = {bob, chuck}) = move<>(a) + val d@Replication(hosts = {bob, chuck}) = move2<>(c) + val = bob.output(d) + val = chuck.output(d) + return +} diff --git a/compiler/tests/should-pass/circuit/cleartext/Commitment2.circuit b/compiler/tests/should-pass/circuit/cleartext/Commitment2.circuit new file mode 100644 index 0000000000..8ee265cd9b --- /dev/null +++ b/compiler/tests/should-pass/circuit/cleartext/Commitment2.circuit @@ -0,0 +1,18 @@ +host alice +host bob + +circuit fun <> move@Local(host = alice)(a: int[]) -> b: int[] { + return a +} + +circuit fun <> move2@Local(host = bob)(a: int[]) -> b: int[] { + return a +} + +fun <> main() -> { + val a@Local(host = alice) = alice.input() + val c@Commitment(sender = alice, receivers = {bob}) = move<>(a) + val d@Local(host = bob) = move2<>(c) + val = bob.output(d) + return +} diff --git a/examples/inputs/circuit/cleartext/Commitment1-alice.txt b/examples/inputs/circuit/cleartext/Commitment1-alice.txt new file mode 100644 index 0000000000..7ed6ff82de --- /dev/null +++ b/examples/inputs/circuit/cleartext/Commitment1-alice.txt @@ -0,0 +1 @@ +5 diff --git a/examples/inputs/circuit/cleartext/Commitment1-bob.txt b/examples/inputs/circuit/cleartext/Commitment1-bob.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/inputs/circuit/cleartext/Commitment1-chuck.txt b/examples/inputs/circuit/cleartext/Commitment1-chuck.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/inputs/circuit/cleartext/Commitment2-alice.txt b/examples/inputs/circuit/cleartext/Commitment2-alice.txt new file mode 100644 index 0000000000..7ed6ff82de --- /dev/null +++ b/examples/inputs/circuit/cleartext/Commitment2-alice.txt @@ -0,0 +1 @@ +5 diff --git a/examples/inputs/circuit/cleartext/Commitment2-bob.txt b/examples/inputs/circuit/cleartext/Commitment2-bob.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/outputs/circuit/cleartext/Commitment1-alice.txt b/examples/outputs/circuit/cleartext/Commitment1-alice.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/outputs/circuit/cleartext/Commitment1-bob.txt b/examples/outputs/circuit/cleartext/Commitment1-bob.txt new file mode 100644 index 0000000000..7ed6ff82de --- /dev/null +++ b/examples/outputs/circuit/cleartext/Commitment1-bob.txt @@ -0,0 +1 @@ +5 diff --git a/examples/outputs/circuit/cleartext/Commitment1-chuck.txt b/examples/outputs/circuit/cleartext/Commitment1-chuck.txt new file mode 100644 index 0000000000..7ed6ff82de --- /dev/null +++ b/examples/outputs/circuit/cleartext/Commitment1-chuck.txt @@ -0,0 +1 @@ +5 diff --git a/examples/outputs/circuit/cleartext/Commitment2-alice.txt b/examples/outputs/circuit/cleartext/Commitment2-alice.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/outputs/circuit/cleartext/Commitment2-bob.txt b/examples/outputs/circuit/cleartext/Commitment2-bob.txt new file mode 100644 index 0000000000..7ed6ff82de --- /dev/null +++ b/examples/outputs/circuit/cleartext/Commitment2-bob.txt @@ -0,0 +1 @@ +5