From 60304cf3ff5ccf0c18a84efd508e0fd996fc721d Mon Sep 17 00:00:00 2001 From: Akira Hatanaka Date: Fri, 27 Sep 2024 07:04:47 -0700 Subject: [PATCH] [SILGen] Fix the type of closure thunks that are passed const reference structs (#75491) The thunk's parameter needs the @in_guaranteed convention if it's a const reference parameter. However, that convention wasn't being used because clang importer was removing the const reference from the type and SILGen was computing the type of the parameter based on the type without const reference. This commit fixes the bug by passing the clang function type to SILDeclRef so that it can be used to compute the correct thunk type. This fixes a crash when a closure is passed to a C function taking a pointer to a function that has a const reference struct parameter. rdar://131321096 --- include/swift/SIL/SILBridging.h | 6 +- include/swift/SIL/SILDeclRef.h | 16 ++-- lib/SIL/IR/SILDeclRef.cpp | 10 ++- lib/SIL/IR/SILFunctionType.cpp | 14 +-- lib/SILGen/SILGenBridging.cpp | 10 ++- lib/SILGen/SILGenExpr.cpp | 26 +++++- test/Interop/Cxx/class/Inputs/closure.h | 13 +++ .../Cxx/class/closure-thunk-macosx.swift | 85 +++++++++++++++++++ .../Interop/Cxx/stdlib/use-std-function.swift | 11 ++- 9 files changed, 164 insertions(+), 27 deletions(-) diff --git a/include/swift/SIL/SILBridging.h b/include/swift/SIL/SILBridging.h index c9be44e56d2d2..8fd948c7d73dd 100644 --- a/include/swift/SIL/SILBridging.h +++ b/include/swift/SIL/SILBridging.h @@ -1011,7 +1011,7 @@ struct BridgedSuccessorArray { }; struct BridgedDeclRef { - uint64_t storage[3]; + uint64_t storage[4]; #ifdef USED_IN_CPP_SOURCE BridgedDeclRef(swift::SILDeclRef declRef) { @@ -1029,7 +1029,7 @@ struct BridgedDeclRef { }; struct BridgedVTableEntry { - uint64_t storage[5]; + uint64_t storage[6]; enum class Kind { Normal, @@ -1077,7 +1077,7 @@ struct OptionalBridgedVTable { }; struct BridgedWitnessTableEntry { - uint64_t storage[5]; + uint64_t storage[6]; enum class Kind { invalid, diff --git a/include/swift/SIL/SILDeclRef.h b/include/swift/SIL/SILDeclRef.h index e24807b3bd8cc..df66dada68d34 100644 --- a/include/swift/SIL/SILDeclRef.h +++ b/include/swift/SIL/SILDeclRef.h @@ -32,6 +32,10 @@ namespace llvm { class raw_ostream; } +namespace clang { +class Type; +} + namespace swift { enum class EffectsKind : uint8_t; class AbstractFunctionDecl; @@ -204,6 +208,9 @@ struct SILDeclRef { const GenericSignatureImpl *, CustomAttr *> pointer; + // Type of closure thunk. + const clang::Type *thunkType = nullptr; + /// Returns the type of AST node location being stored by the SILDeclRef. LocKind getLocKind() const { if (loc.is()) @@ -257,11 +264,10 @@ struct SILDeclRef { /// for the containing ClassDecl. /// - If 'loc' is a global VarDecl, this returns its GlobalAccessor /// SILDeclRef. - explicit SILDeclRef( - Loc loc, - bool isForeign = false, - bool isDistributed = false, - bool isDistributedLocal = false); + explicit SILDeclRef(Loc loc, bool isForeign = false, + bool isDistributed = false, + bool isDistributedLocal = false, + const clang::Type *thunkType = nullptr); /// See above put produces a prespecialization according to the signature. explicit SILDeclRef(Loc loc, GenericSignature prespecializationSig); diff --git a/lib/SIL/IR/SILDeclRef.cpp b/lib/SIL/IR/SILDeclRef.cpp index 7cb41ee10607a..ff35b247b3219 100644 --- a/lib/SIL/IR/SILDeclRef.cpp +++ b/lib/SIL/IR/SILDeclRef.cpp @@ -135,11 +135,15 @@ SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind, bool isForeign, isAsyncLetClosure(0), pointer(derivativeId) {} SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign, - bool asDistributed, bool asDistributedKnownToBeLocal) + bool asDistributed, bool asDistributedKnownToBeLocal, + const clang::Type *thunkType) : isRuntimeAccessible(false), backDeploymentKind(SILDeclRef::BackDeploymentKind::None), defaultArgIndex(0), isAsyncLetClosure(0), - pointer((AutoDiffDerivativeFunctionIdentifier *)nullptr) { + pointer((AutoDiffDerivativeFunctionIdentifier *)nullptr), + thunkType(thunkType) { + assert((!thunkType || baseLoc.is()) && + "thunk type is needed only for closures"); if (auto *vd = baseLoc.dyn_cast()) { if (auto *fd = dyn_cast(vd)) { // Map FuncDecls directly to Func SILDeclRefs. @@ -169,6 +173,8 @@ SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign, llvm_unreachable("invalid loc decl for SILDeclRef!"); } } else if (auto *ACE = baseLoc.dyn_cast()) { + assert((!asForeign || thunkType) && + "thunk type needed for foreign type for closures"); loc = ACE; kind = Kind::Func; if (ACE->getASTContext().LangOpts.hasFeature( diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 4bad181e26b1b..6f2e96eb6a690 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -3965,12 +3965,10 @@ static CanSILFunctionType getUncachedSILFunctionTypeForConstant( // The type of the native-to-foreign thunk for a swift closure. if (constant.isForeign && constant.hasClosureExpr() && shouldStoreClangType(TC.getDeclRefRepresentation(constant))) { - auto clangType = TC.Context.getClangFunctionType( - origLoweredInterfaceType->getParams(), - origLoweredInterfaceType->getResult(), - FunctionTypeRepresentation::CFunctionPointer); - AbstractionPattern pattern = - AbstractionPattern(origLoweredInterfaceType, clangType); + assert(!extInfoBuilder.getClangTypeInfo().empty() && + "clang type not found"); + AbstractionPattern pattern = AbstractionPattern( + origLoweredInterfaceType, extInfoBuilder.getClangTypeInfo().getType()); return getSILFunctionTypeForAbstractCFunction( TC, pattern, origLoweredInterfaceType, extInfoBuilder, constant); } @@ -4476,9 +4474,13 @@ getAbstractionPatternForConstant(ASTContext &ctx, SILDeclRef constant, if (!constant.isForeign) return AbstractionPattern(fnType); + if (constant.thunkType) + return AbstractionPattern(fnType, constant.thunkType); + auto bridgedFn = getBridgedFunction(constant); if (!bridgedFn) return AbstractionPattern(fnType); + const clang::Decl *clangDecl = bridgedFn->getClangDecl(); if (!clangDecl) return AbstractionPattern(fnType); diff --git a/lib/SILGen/SILGenBridging.cpp b/lib/SILGen/SILGenBridging.cpp index f9f096afc50e9..63961b0fad313 100644 --- a/lib/SILGen/SILGenBridging.cpp +++ b/lib/SILGen/SILGenBridging.cpp @@ -1315,8 +1315,9 @@ static SILValue emitObjCUnconsumedArgument(SILGenFunction &SGF, SILLocation loc, SILValue arg) { auto &lowering = SGF.getTypeLowering(arg->getType()); - // If address-only, make a +1 copy and operate on that. - if (lowering.isAddressOnly() && SGF.useLoweredAddresses()) { + // If arg is non-trivial and has an address type, make a +1 copy and operate + // on that. + if (!lowering.isTrivial() && arg->getType().isAddress()) { auto tmp = SGF.emitTemporaryAllocation(loc, arg->getType().getObjectType()); SGF.B.createCopyAddr(loc, arg, tmp, IsNotTake, IsInitialization); return tmp; @@ -1448,6 +1449,11 @@ emitObjCThunkArguments(SILGenFunction &SGF, SILLocation loc, SILDeclRef thunk, auto buf = SGF.emitTemporaryAllocation(loc, native.getType()); native.forwardInto(SGF, loc, buf); native = SGF.emitManagedBufferWithCleanup(buf); + } else if (!fnConv.isSILIndirect(nativeInputs[i]) && + native.getType().isAddress()) { + // Load the value if the argument has an address type and the native + // function expects the argument to be passed directly. + native = SGF.emitManagedLoadCopy(loc, native.getValue()); } if (nativeInputs[i].isConsumedInCaller()) { diff --git a/lib/SILGen/SILGenExpr.cpp b/lib/SILGen/SILGenExpr.cpp index 69f22012c38ff..e47239a78748e 100644 --- a/lib/SILGen/SILGenExpr.cpp +++ b/lib/SILGen/SILGenExpr.cpp @@ -1723,7 +1723,19 @@ static ManagedValue convertCFunctionSignature(SILGenFunction &SGF, FunctionConversionExpr *e, SILType loweredResultTy, llvm::function_ref fnEmitter) { - SILType loweredDestTy = SGF.getLoweredType(e->getType()); + SILType loweredDestTy; + auto destTy = e->getType(); + auto clangInfo = + destTy->castTo()->getExtInfo().getClangTypeInfo(); + if (clangInfo.empty()) + loweredDestTy = SGF.getLoweredType(destTy); + else + // This won't be necessary after we stop dropping clang types when + // canonicalizing function types. + loweredDestTy = SGF.getLoweredType( + AbstractionPattern(destTy->getCanonicalType(), clangInfo.getType()), + destTy); + ManagedValue result; // We're converting between C function pointer types. They better be @@ -1794,7 +1806,9 @@ ManagedValue emitCFunctionPointer(SILGenFunction &SGF, #endif semanticExpr = conv->getSubExpr()->getSemanticsProvidingExpr(); } - + + const clang::Type *destFnType = nullptr; + if (auto declRef = dyn_cast(semanticExpr)) { setLocFromConcreteDeclRef(declRef->getDeclRef()); } else if (auto memberRef = dyn_cast(semanticExpr)) { @@ -1808,12 +1822,18 @@ ManagedValue emitCFunctionPointer(SILGenFunction &SGF, loc = closure; return ManagedValue(); }); + auto clangInfo = conversionExpr->getType() + ->castTo() + ->getExtInfo() + .getClangTypeInfo(); + if (!clangInfo.empty()) + destFnType = clangInfo.getType(); } else { llvm_unreachable("c function pointer converted from a non-concrete decl ref"); } // Produce a reference to the C-compatible entry point for the function. - SILDeclRef constant(loc, /*foreign*/ true); + SILDeclRef constant(loc, /*foreign*/ true, false, false, destFnType); SILConstantInfo constantInfo = SGF.getConstantInfo(SGF.getTypeExpansionContext(), constant); diff --git a/test/Interop/Cxx/class/Inputs/closure.h b/test/Interop/Cxx/class/Inputs/closure.h index 7081ff54c7b08..fa04d441f55df 100644 --- a/test/Interop/Cxx/class/Inputs/closure.h +++ b/test/Interop/Cxx/class/Inputs/closure.h @@ -10,6 +10,10 @@ struct NonTrivial { int *p; }; +struct Trivial { + int i; +}; + void cfunc(void (^ _Nonnull block)(NonTrivial)) noexcept { block(NonTrivial()); } @@ -45,4 +49,13 @@ void cfuncARCWeak(ARCWeak) noexcept; void (* _Nonnull getFnPtr() noexcept)(NonTrivial) noexcept; void (* _Nonnull getFnPtr2() noexcept)(ARCWeak) noexcept; +void cfuncConstRefNonTrivial(void (*_Nonnull)(const NonTrivial &)); +void cfuncConstRefTrivial(void (*_Nonnull)(const Trivial &)); +void blockConstRefNonTrivial(void (^_Nonnull)(const NonTrivial &)); +void blockConstRefTrivial(void (^_Nonnull)(const Trivial &)); +#if __OBJC__ +void cfuncConstRefStrong(void (*_Nonnull)(const ARCStrong &)); +void blockConstRefStrong(void (^_Nonnull)(const ARCStrong &)); +#endif + #endif // __CLOSURE__ diff --git a/test/Interop/Cxx/class/closure-thunk-macosx.swift b/test/Interop/Cxx/class/closure-thunk-macosx.swift index 3fedb04ca1328..f8ce0a681e00c 100644 --- a/test/Interop/Cxx/class/closure-thunk-macosx.swift +++ b/test/Interop/Cxx/class/closure-thunk-macosx.swift @@ -35,3 +35,88 @@ public func testClosureToFuncPtr() { public func testClosureToBlockReturnNonTrivial() { cfuncReturnNonTrivial({() -> NonTrivial in return NonTrivial() }) } + +// CHECK-LABEL: sil private [thunk] [ossa] @$s4main22testConstRefNonTrivialyyFySo0eF0VcfU_To : $@convention(c) (@in_guaranteed NonTrivial) -> () { +// CHECK: bb0(%[[V0:.*]] : $*NonTrivial): +// CHECK: %[[V1:.*]] = alloc_stack $NonTrivial +// CHECK: copy_addr %[[V0]] to [init] %[[V1]] : $*NonTrivial +// CHECK: %[[V3:.*]] = function_ref @$s4main22testConstRefNonTrivialyyFySo0eF0VcfU_ : $@convention(thin) (@in_guaranteed NonTrivial) -> () +// CHECK: %[[V4:.*]] = apply %[[V3]](%[[V1]]) : $@convention(thin) (@in_guaranteed NonTrivial) -> () +// CHECK: destroy_addr %[[V1]] : $*NonTrivial +// CHECK: dealloc_stack %[[V1]] : $*NonTrivial +// CHECK: return %[[V4]] : $() + +public func testConstRefNonTrivial() { + cfuncConstRefNonTrivial({S in }); +} + +// CHECK-LABEL: sil private [thunk] [ossa] @$s4main19testConstRefTrivialyyFySo0E0VcfU_To : $@convention(c) (@in_guaranteed Trivial) -> () { +// CHECK: bb0(%[[V0:.*]] : $*Trivial): +// CHECK: %[[V1:.*]] = load [trivial] %[[V0]] : $*Trivial +// CHECK: %[[V2:.*]] = function_ref @$s4main19testConstRefTrivialyyFySo0E0VcfU_ : $@convention(thin) (Trivial) -> () +// CHECK: %[[V3:.*]] = apply %[[V2]](%[[V1]]) : $@convention(thin) (Trivial) -> () +// CHECK: return %[[V3]] : $() + +public func testConstRefTrivial() { + cfuncConstRefTrivial({S in }); +} + +// CHECK-LABEL: sil private [thunk] [ossa] @$s4main18testConstRefStrongyyFySo9ARCStrongVcfU_To : $@convention(c) (@in_guaranteed ARCStrong) -> () { +// CHECK: bb0(%[[V0:.*]] : $*ARCStrong): +// CHECK: %[[V1:.*]] = alloc_stack $ARCStrong +// CHECK: copy_addr %[[V0]] to [init] %[[V1]] : $*ARCStrong +// CHECK: %[[V3:.*]] = load [copy] %[[V1]] : $*ARCStrong +// CHECK: %[[V4:.*]] = begin_borrow %[[V3]] : $ARCStrong +// CHECK: %[[V5:.*]] = function_ref @$s4main18testConstRefStrongyyFySo9ARCStrongVcfU_ : $@convention(thin) (@guaranteed ARCStrong) -> () +// CHECK: %[[V6:.*]] = apply %[[V5]](%[[V4]]) : $@convention(thin) (@guaranteed ARCStrong) -> () +// CHECK: end_borrow %[[V4]] : $ARCStrong +// CHECK: destroy_value %[[V3]] : $ARCStrong +// CHECK: destroy_addr %[[V1]] : $*ARCStrong +// CHECK: dealloc_stack %[[V1]] : $*ARCStrong +// CHECK: return %[[V6]] : $() + +public func testConstRefStrong() { + cfuncConstRefStrong({S in }); +} + +// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] [ossa] @$sSo10NonTrivialVIegn_ABIeyBn_TR : $@convention(c) (@inout_aliasable @block_storage @callee_guaranteed (@in_guaranteed NonTrivial) -> (), @in_guaranteed NonTrivial) -> () { +// CHECK: bb0(%[[V0:.*]] : $*@block_storage @callee_guaranteed (@in_guaranteed NonTrivial) -> (), %[[V1:.*]] : $*NonTrivial): +// CHECK: %[[V2:.*]] = project_block_storage %[[V0]] : $*@block_storage @callee_guaranteed (@in_guaranteed NonTrivial) -> () +// CHECK: %[[V3:.*]] = load [copy] %[[V2]] : $*@callee_guaranteed (@in_guaranteed NonTrivial) -> () +// CHECK: %[[V4:.*]] = begin_borrow %[[V3]] : $@callee_guaranteed (@in_guaranteed NonTrivial) -> () +// CHECK: apply %[[V4]](%[[V1]]) : $@callee_guaranteed (@in_guaranteed NonTrivial) -> () +// CHECK: end_borrow %[[V4]] : $@callee_guaranteed (@in_guaranteed NonTrivial) -> () +// CHECK: destroy_value %[[V3]] : $@callee_guaranteed (@in_guaranteed NonTrivial) -> () + +public func testBlockConstRefNonTrivial() { + blockConstRefNonTrivial({S in }); +} + +// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] [ossa] @$sSo7TrivialVIegy_ABIeyBn_TR : $@convention(c) (@inout_aliasable @block_storage @callee_guaranteed (Trivial) -> (), @in_guaranteed Trivial) -> () { +// CHECK: bb0(%[[V0:.*]] : $*@block_storage @callee_guaranteed (Trivial) -> (), %[[V1:.*]] : $*Trivial): +// CHECK: %[[V2:.*]] = project_block_storage %[[V0]] : $*@block_storage @callee_guaranteed (Trivial) -> () +// CHECK: %[[V3:.*]] = load [copy] %[[V2]] : $*@callee_guaranteed (Trivial) -> () +// CHECK: %[[V4:.*]] = load [trivial] %[[V1]] : $*Trivial +// CHECK: %[[V5:.*]] = begin_borrow %[[V3]] : $@callee_guaranteed (Trivial) -> () +// CHECK: apply %[[V5]](%[[V4]]) : $@callee_guaranteed (Trivial) -> () +// CHECK: end_borrow %[[V5]] : $@callee_guaranteed (Trivial) -> () +// CHECK: destroy_value %[[V3]] : $@callee_guaranteed (Trivial) -> () + +public func testBlockConstRefTrivial() { + blockConstRefTrivial({S in }); +} + +// CHECK-LABEL: sil shared [transparent] [serialized] [reabstraction_thunk] [ossa] @$sSo9ARCStrongVIegg_ABIeyBn_TR : $@convention(c) (@inout_aliasable @block_storage @callee_guaranteed (@guaranteed ARCStrong) -> (), @in_guaranteed ARCStrong) -> () { +// CHECK: bb0(%[[V0:.*]] : $*@block_storage @callee_guaranteed (@guaranteed ARCStrong) -> (), %[[V1:.*]] : $*ARCStrong): +// CHECK: %[[V2:.*]] = project_block_storage %[[V0]] : $*@block_storage @callee_guaranteed (@guaranteed ARCStrong) -> () +// CHECK: %[[V3:.*]] = load [copy] %[[V2]] : $*@callee_guaranteed (@guaranteed ARCStrong) -> () +// CHECK: %[[V4:.*]] = load_borrow %[[V1]] : $*ARCStrong +// CHECK: %[[V5:.*]] = begin_borrow %[[V3]] : $@callee_guaranteed (@guaranteed ARCStrong) -> () +// CHECK: apply %[[V5]](%[[V4]]) : $@callee_guaranteed (@guaranteed ARCStrong) -> () +// CHECK: end_borrow %[[V5]] : $@callee_guaranteed (@guaranteed ARCStrong) -> () +// CHECK: end_borrow %[[V4]] : $ARCStrong +// CHECK: destroy_value %[[V3]] : $@callee_guaranteed (@guaranteed ARCStrong) -> () + +public func testBlockConstRefStrong() { + blockConstRefStrong({S in }); +} diff --git a/test/Interop/Cxx/stdlib/use-std-function.swift b/test/Interop/Cxx/stdlib/use-std-function.swift index 9ad06c4e3d8e8..55df7f605cb41 100644 --- a/test/Interop/Cxx/stdlib/use-std-function.swift +++ b/test/Interop/Cxx/stdlib/use-std-function.swift @@ -64,12 +64,11 @@ StdFunctionTestSuite.test("FunctionStringToString init from closure and pass as expectEqual(std.string("prefixabcabc"), res) } -// FIXME: assertion for address-only closure params (rdar://124501345) -//StdFunctionTestSuite.test("FunctionStringToStringConstRef init from closure and pass as parameter") { -// let res = invokeFunctionTwiceConstRef(.init({ $0 + std.string("abc") }), -// std.string("prefix")) -// expectEqual(std.string("prefixabcabc"), res) -//} +StdFunctionTestSuite.test("FunctionStringToStringConstRef init from closure and pass as parameter") { + let res = invokeFunctionTwiceConstRef(.init({ $0 + std.string("abc") }), + std.string("prefix")) + expectEqual(std.string("prefixabcabc"), res) +} #endif runAllTests()