Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ut for property #1233

Closed
wants to merge 10 commits into from
26 changes: 25 additions & 1 deletion src/Neo.Compiler.CSharp/CompilationEngine/CompilationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using Akka.Util.Internal;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using Akka.Util.Internal;

using Diagnostic = Microsoft.CodeAnalysis.Diagnostic;
using ECPoint = Neo.Cryptography.ECC.ECPoint;

Expand Down Expand Up @@ -72,6 +73,9 @@ public class CompilationContext
internal int StaticFieldCount => _staticFields.Count + _anonymousStaticFields.Count + _vtables.Count;
private byte[] Script => _script ??= GetInstructions().Select(p => p.ToArray()).SelectMany(p => p).ToArray();

// Define a tuple array to store both field symbols and their semantic models
internal (IFieldSymbol Field, SemanticModel Model)[] ContractFields = [];
internal SemanticModel ContractSemanticModel { get; set; }

/// <summary>
/// Specify the contract to be compiled.
Expand Down Expand Up @@ -372,6 +376,8 @@ private void ProcessClass(SemanticModel model, INamedTypeSymbol symbol)
return;
}

ContractSemanticModel = model;

foreach (var attribute in symbol.GetAttributesWithInherited())
{
if (attribute.AttributeClass!.IsSubclassOf(nameof(ManifestExtraAttribute)))
Expand Down Expand Up @@ -416,6 +422,25 @@ private void ProcessClass(SemanticModel model, INamedTypeSymbol symbol)
}
_className = symbol.Name;
}
// Get all fields and their corresponding semantic models
ContractFields = symbol.GetAllMembers()
.OfType<IFieldSymbol>()
.Select(field =>
{
// Try to get the syntax reference for the field
var syntaxRef = field.DeclaringSyntaxReferences.FirstOrDefault();
// If the field has a syntax reference, get its semantic model
// Otherwise, use the current model (for metadata fields)
var fieldModel = syntaxRef != null
? ((ISourceAssemblySymbol)field.ContainingAssembly).Compilation.GetSemanticModel(syntaxRef.SyntaxTree)
: model;
return (Field: field, Model: fieldModel);
})
.ToArray();

// Process each field using its symbol
ContractFields.ForEach(f => AddStaticField(f.Field));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ContractFields.ForEach(f => AddStaticField(f.Field));
foreach ((IFieldSymbol f, _) in ContractFields)
AddStaticField(f);

Copy link
Contributor

@Hecate2 Hecate2 Nov 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

array.ForEach is using Akka.Util.Internal, and is harder to debug


Dictionary<(string, int), IMethodSymbol> export = new();
// export methods `new`ed in child class, not those hidden in parent class
foreach (ISymbol member in symbol.GetAllMembers())
Expand Down Expand Up @@ -514,7 +539,6 @@ private void ProcessMethod(SemanticModel model, IMethodSymbol symbol, bool expor
throw new CompilationException(symbol, DiagnosticId.SyntaxNotSupported, $"Unsupported syntax: Can not set contract interface {symbol.Name} as inline.");
return;
}

MethodConvert convert = ConvertMethod(model, symbol);
if (export && MethodConvert.NeedInstanceConstructor(symbol))
{
Expand Down
24 changes: 21 additions & 3 deletions src/Neo.Compiler.CSharp/MethodConvert/ConstructorConvert.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,23 @@ private void ProcessConstructorInitializer(SemanticModel model)

private void ProcessStaticFields(SemanticModel model)
{
foreach (INamedTypeSymbol @class in _context.StaticFieldSymbols.Select(p => p.ContainingType).Distinct<INamedTypeSymbol>(SymbolEqualityComparer.Default).ToArray())

foreach (var contractField in _context.ContractFields)
{
ProcessFieldInitializer(model, contractField.Field, null, () =>
{
byte index = _context.AddStaticField(contractField.Field);
AccessSlot(OpCode.STSFLD, index);
});
}

foreach (INamedTypeSymbol @class in _context.StaticFieldSymbols
.Select(p => p.ContainingType)
.Distinct<INamedTypeSymbol>(SymbolEqualityComparer.Default)
.ToArray())
{
foreach (IFieldSymbol field in @class.GetAllMembers().OfType<IFieldSymbol>())
foreach (IFieldSymbol field in @class.GetAllMembers().OfType<IFieldSymbol>().Where(p => !_context.ContractFields.Any(f =>
SymbolEqualityComparer.Default.Equals(f.Field, p))))
{
if (field.IsConst || !field.IsStatic) continue;
ProcessFieldInitializer(model, field, null, () =>
Expand All @@ -75,7 +89,11 @@ private void ProcessStaticFields(SemanticModel model)
}
foreach (var (fieldIndex, type) in _context.VTables)
{
IMethodSymbol[] virtualMethods = type.GetAllMembers().OfType<IMethodSymbol>().Where(p => p.IsVirtualMethod()).ToArray();
IMethodSymbol[] virtualMethods = type
.GetAllMembers()
.OfType<IMethodSymbol>()
.Where(p => p.IsVirtualMethod())
.ToArray();
for (int i = virtualMethods.Length - 1; i >= 0; i--)
{
IMethodSymbol method = virtualMethods[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ private void ConvertParameterIdentifierNamePostIncrementOrDecrementExpression(Sy

private void ConvertPropertyIdentifierNamePostIncrementOrDecrementExpression(SemanticModel model, SyntaxToken operatorToken, IPropertySymbol symbol)
{
if (symbol.IsStatic)
if (!NeedInstanceConstructor(symbol.GetMethod!))
{
CallMethodWithConvention(model, symbol.GetMethod!);
AddInstruction(OpCode.DUP);
Expand Down
39 changes: 36 additions & 3 deletions src/Neo.Compiler.CSharp/MethodConvert/MethodConvert.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
using System.Linq;
using System.Numerics;
using System.Runtime.InteropServices;
using Neo.VM.Types;
using Array = System.Array;

namespace Neo.Compiler
{
Expand Down Expand Up @@ -243,12 +245,43 @@ private void ProcessFieldInitializer(SemanticModel model, IFieldSymbol field, Ac
syntaxNode = syntax;
initializer = syntax.Initializer;
}
if (initializer is null) return;
model = model.Compilation.GetSemanticModel(syntaxNode.SyntaxTree);

if (initializer is null)
{
if(_context.ContractFields.Any(f =>
SymbolEqualityComparer.Default.Equals(f.Field, field)) &&
(field.Type.GetStackItemType() == StackItemType.Integer || field.Type.GetStackItemType() == StackItemType.Integer))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f.Field.Type😰?

{
var index = _context.AddStaticField(field);
PushDefault(field.Type);
AccessSlot(OpCode.STSFLD, index);
}
return;
}

using (InsertSequencePoint(syntaxNode))
{
preInitialize?.Invoke();
ConvertExpression(model, initializer.Value, syntaxNode);
// We must process contract fields separately, they may not belong to the current semantic model
// And they may also not belong to the semantic model of the contract, but the parent class semantic model
if (_context.ContractFields.Any(f =>
SymbolEqualityComparer.Default.Equals(f.Field, field)))
{
// Try to get the syntax reference for the field
var syntaxRef = field.DeclaringSyntaxReferences.FirstOrDefault();
// If the field has a syntax reference, get its semantic model
// Otherwise, use the current model (for metadata fields)
var fieldModel = syntaxRef != null
? ((ISourceAssemblySymbol)field.ContainingAssembly).Compilation.GetSemanticModel(syntaxRef.SyntaxTree)
: _context.ContractSemanticModel;

ConvertExpression( fieldModel.Compilation.GetSemanticModel(syntaxNode.SyntaxTree), initializer.Value, syntaxNode);
}
else
{
model = model.Compilation.GetSemanticModel(syntaxNode.SyntaxTree);
ConvertExpression(model, initializer.Value, syntaxNode);
}
postInitialize?.Invoke();
}
}
Expand Down
86 changes: 44 additions & 42 deletions src/Neo.Compiler.CSharp/MethodConvert/PropertyConvert.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,28 @@ private void ConvertNoBody(AccessorDeclarationSyntax syntax)
using (InsertSequencePoint(syntax))
{
_inline = attribute is null;
ConvertFieldBackedProperty(property);
// Here we handle them separately, if store backed, its store backed.
// we need to load value from store directly.
if (attribute is not null)
{
ConvertStorageBackedProperty(property, attribute);
}
else
{
ConvertFieldBackedProperty(property);
}
}
}

private void ConvertFieldBackedProperty(IPropertySymbol property)
{
IFieldSymbol[] fields = property.ContainingType.GetAllMembers().OfType<IFieldSymbol>().ToArray();
if (Symbol.IsStatic)
var backingField = Array.Find(fields, p => SymbolEqualityComparer.Default.Equals(p.AssociatedSymbol, property))!;
// We need to take care of contract fields as non-static.
if (Symbol.IsStatic || !NeedInstanceConstructor(Symbol) || _context.ContractFields.Any(f =>
SymbolEqualityComparer.Default.Equals(f.Field, backingField)))
{
IFieldSymbol backingField = Array.Find(fields, p => SymbolEqualityComparer.Default.Equals(p.AssociatedSymbol, property))!;
// IFieldSymbol backingField = Array.Find(fields, p => SymbolEqualityComparer.Default.Equals(p.AssociatedSymbol, property))!;
byte backingFieldIndex = _context.AddStaticField(backingField);
switch (Symbol.MethodKind)
{
Expand All @@ -61,8 +71,6 @@ private void ConvertFieldBackedProperty(IPropertySymbol property)
}
else
{
if (!NeedInstanceConstructor(Symbol))
return;
fields = fields.Where(p => !p.IsStatic).ToArray();
int backingFieldIndex = Array.FindIndex(fields, p => SymbolEqualityComparer.Default.Equals(p.AssociatedSymbol, property));
switch (Symbol.MethodKind)
Expand Down Expand Up @@ -122,24 +130,18 @@ private void ConvertStorageBackedProperty(IPropertySymbol property, AttributeDat
{
IFieldSymbol[] fields = property.ContainingType.GetAllMembers().OfType<IFieldSymbol>().ToArray();
byte[] key = GetStorageBackedKey(property, attribute);

IFieldSymbol backingField = Array.Find(fields, p => SymbolEqualityComparer.Default.Equals(p.AssociatedSymbol, property))!;

if (Symbol.MethodKind == MethodKind.PropertyGet)
{
JumpTarget endTarget = new();
if (Symbol.IsStatic)
{
// AddInstruction(OpCode.DUP);
AddInstruction(OpCode.ISNULL);
// Ensure that no object was sent
Jump(OpCode.JMPIFNOT_L, endTarget);
}
else if (NeedInstanceConstructor(Symbol))
{
// Check class
Jump(OpCode.JMPIF_L, endTarget);
}
// Step 1. Load the value from the store.
Push(key);
CallInteropMethod(ApplicationEngine.System_Storage_GetReadOnlyContext);
CallInteropMethod(ApplicationEngine.System_Storage_Get);

// Step 2. Check if the value is initialized.
// If not, load the default/assigned value to the backing field.
switch (property.Type.Name)
{
case "byte":
Expand All @@ -162,14 +164,28 @@ private void ConvertStorageBackedProperty(IPropertySymbol property, AttributeDat
case "Int64":
case "UInt64":
case "BigInteger":
// Replace NULL with 0
case "bool":
/// TODO: Default value for string
// Check Null
AddInstruction(OpCode.DUP);
AddInstruction(OpCode.ISNULL);
AddInstruction(OpCode.ISNULL); // null means these value are not initialized, then we should load them from backing field
JumpTarget ifFalse = new();
Jump(OpCode.JMPIFNOT_L, ifFalse);
AddInstruction(OpCode.DROP); // Drop the DUPed value
if (Symbol.IsStatic || !NeedInstanceConstructor(Symbol) || _context.ContractFields.Any(f =>
SymbolEqualityComparer.Default.Equals(f.Field, backingField)))
{
AddInstruction(OpCode.DROP);
AddInstruction(OpCode.PUSH0);
byte backingFieldIndex = _context.AddStaticField(backingField);
AccessSlot(OpCode.LDSFLD, backingFieldIndex);
}
else if (NeedInstanceConstructor(Symbol))
{
AddInstruction(OpCode.DUP);
fields = fields.Where(p => !p.IsStatic).ToArray();
int backingFieldIndex = Array.FindIndex(fields, p => SymbolEqualityComparer.Default.Equals(p.AssociatedSymbol, property));
if (!_inline) AccessSlot(OpCode.LDARG, 0);
Push(backingFieldIndex);
AddInstruction(OpCode.PICKITEM);
}
ifFalse.Instruction = AddInstruction(OpCode.NOP);
break;
Expand All @@ -178,33 +194,19 @@ private void ConvertStorageBackedProperty(IPropertySymbol property, AttributeDat
case "UInt160":
case "UInt256":
case "ECPoint":
// but for those whose default value is null, it is impossible to know if
// the value has being initialized or not.
// TODO: figure out a way to initialize storebacked fields when deploy.
break;
default:
CallContractMethod(NativeContract.StdLib.Hash, "deserialize", 1, true);
break;
}
if (Symbol.IsStatic)
{
AddInstruction(OpCode.DUP);
IFieldSymbol backingField = Array.Find(fields, p => SymbolEqualityComparer.Default.Equals(p.AssociatedSymbol, property))!;
byte backingFieldIndex = _context.AddStaticField(backingField);
AccessSlot(OpCode.STSFLD, backingFieldIndex);
}
else if (NeedInstanceConstructor(Symbol))
{
AddInstruction(OpCode.DUP);
fields = fields.Where(p => !p.IsStatic).ToArray();
int backingFieldIndex = Array.FindIndex(fields, p => SymbolEqualityComparer.Default.Equals(p.AssociatedSymbol, property));
AccessSlot(OpCode.LDARG, 0);
Push(backingFieldIndex);
AddInstruction(OpCode.ROT);
AddInstruction(OpCode.SETITEM);
}
endTarget.Instruction = AddInstruction(OpCode.NOP);
}
else
else if (Symbol.MethodKind == MethodKind.PropertySet) // explicitly use `else if` instead of `if` to improve readability.
{
if (Symbol.IsStatic || !NeedInstanceConstructor(Symbol))
if (Symbol.IsStatic || !NeedInstanceConstructor(Symbol) || _context.ContractFields.Any(f =>
SymbolEqualityComparer.Default.Equals(f.Field, backingField)))
AccessSlot(OpCode.LDARG, 0);
else
AccessSlot(OpCode.LDARG, 1);
Expand Down
51 changes: 32 additions & 19 deletions src/Neo.Compiler.CSharp/MethodConvert/SourceConvert.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Neo.VM;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

Expand Down Expand Up @@ -143,6 +144,8 @@ private static bool IsExpressionReturningValue(SemanticModel semanticModel, Meth
return false;
}

internal static Dictionary<IMethodSymbol, bool> _cacheNeedInstanceConstructor = new();

/// <summary>
/// non-static methods needs constructors to be executed
/// But non-static method in smart contract classes without explicit constructor
Expand All @@ -155,27 +158,37 @@ private static bool IsExpressionReturningValue(SemanticModel semanticModel, Meth
/// <returns></returns>
internal static bool NeedInstanceConstructor(IMethodSymbol symbol)
{
if (symbol.IsStatic || symbol.MethodKind == MethodKind.AnonymousFunction)
return false;
INamedTypeSymbol? containingClass = symbol.ContainingType;
if (containingClass == null) return false;
// non-static methods in class
if ((symbol.MethodKind == MethodKind.Constructor || symbol.MethodKind == MethodKind.SharedConstructor)
&& !CompilationEngine.IsDerivedFromSmartContract(containingClass))
// is constructor, and is not smart contract
// typically seen in framework methods
return true;
// is smart contract, or is normal non-static method (whether contract or not)
if (containingClass?.Constructors
.FirstOrDefault(p => p.Parameters.Length == 0 && !p.IsStatic)?
.DeclaringSyntaxReferences.Length == 0)
// No explicit non-static constructor in class
if (_cacheNeedInstanceConstructor.TryGetValue(symbol, out bool result))
return result;
static bool NeedInstanceConstructorInner(IMethodSymbol symbol)
{
if (s_pattern.IsMatch(containingClass.BaseType?.ToString() ?? string.Empty))
// class itself is directly inheriting smart contract; cannot have more base classes
if (symbol.IsStatic || symbol.MethodKind == MethodKind.AnonymousFunction)
return false;
INamedTypeSymbol? containingClass = symbol.ContainingType;
if (containingClass == null)
return false;
// non-static methods in class
if ((symbol.MethodKind == MethodKind.Constructor || symbol.MethodKind == MethodKind.SharedConstructor)
&& !CompilationEngine.IsDerivedFromSmartContract(containingClass))
// is constructor, and is not smart contract
// typically seen in framework methods
return true;
if (containingClass!.Constructors
.FirstOrDefault(p => p.Parameters.Length == 0 && !p.IsStatic)?
.DeclaringSyntaxReferences.Length > 0)
// has explicit constructor
return true;
// No explicit non-static constructor in class
// is smart contract, or is normal non-static method (whether contract or not)
if (!s_pattern.IsMatch(containingClass?.BaseType?.ToString() ?? string.Empty))
// class itself is not directly inheriting smart contract; can have more base classes
return true;
// is non-static method, directly inheriting smart contract
if (containingClass!.GetFields().Any((IFieldSymbol f) => !f.IsStatic))
// has non-static fields
return true;
return false;
}
// has explicit constructor
return true;
return _cacheNeedInstanceConstructor[symbol] = NeedInstanceConstructorInner(symbol);
}
}
Loading
Loading