//===--- BindingInferenceTests.cpp ----------------------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// #include "SemaFixture.h" #include "swift/AST/Expr.h" #include "swift/Sema/ConstraintSystem.h" #include "llvm/ADT/SmallPtrSet.h" using namespace swift; using namespace swift::unittest; using namespace swift::constraints; TEST_F(SemaTest, TestIntLiteralBindingInference) { ConstraintSystemOptions options; options |= ConstraintSystemFlags::AllowUnresolvedTypeVariables; ConstraintSystem cs(DC, options); auto *intLiteral = IntegerLiteralExpr::createFromUnsigned(Context, 42); auto *literalTy = cs.createTypeVariable(cs.getConstraintLocator(intLiteral), /*options=*/0); cs.addConstraint( ConstraintKind::LiteralConformsTo, literalTy, Context.getProtocol(KnownProtocolKind::ExpressibleByIntegerLiteral) ->getDeclaredInterfaceType(), cs.getConstraintLocator(intLiteral)); auto bindings = cs.inferBindingsFor(literalTy); ASSERT_EQ(bindings.Bindings.size(), (unsigned)1); const auto &binding = bindings.Bindings.front(); ASSERT_TRUE(binding.BindingType->isEqual(getStdlibType("Int"))); ASSERT_TRUE(binding.hasDefaultedLiteralProtocol()); } // Given a set of inferred protocol requirements, make sure that // all of the expected types are present. static void verifyProtocolInferenceResults( const llvm::SmallPtrSetImpl &protocols, ArrayRef expectedTypes) { ASSERT_TRUE(protocols.size() >= expectedTypes.size()); llvm::SmallPtrSet inferredProtocolTypes; for (auto *protocol : protocols) inferredProtocolTypes.insert(protocol->getSecondType()); for (auto expectedTy : expectedTypes) { ASSERT_TRUE(inferredProtocolTypes.count(expectedTy)); } } TEST_F(SemaTest, TestTransitiveProtocolInference) { ConstraintSystemOptions options; ConstraintSystem cs(DC, options); auto *protocolTy1 = createProtocol("P1"); auto *protocolTy2 = createProtocol("P2"); auto *GPT1 = cs.createTypeVariable(cs.getConstraintLocator({}), /*options=*/TVO_CanBindToNoEscape); auto *GPT2 = cs.createTypeVariable(cs.getConstraintLocator({}), /*options=*/TVO_CanBindToNoEscape); cs.addConstraint( ConstraintKind::ConformsTo, GPT1, protocolTy1, cs.getConstraintLocator({}, LocatorPathElt::TypeParameterRequirement( 0, RequirementKind::Conformance))); cs.addConstraint( ConstraintKind::ConformsTo, GPT2, protocolTy2, cs.getConstraintLocator({}, LocatorPathElt::TypeParameterRequirement( 0, RequirementKind::Conformance))); // First, let's try inferring through a single conversion // relationship. { auto *typeVar = cs.createTypeVariable(cs.getConstraintLocator({}), /*options=*/0); cs.addConstraint( ConstraintKind::Conversion, typeVar, GPT1, cs.getConstraintLocator({}, LocatorPathElt::ContextualType())); auto bindings = inferBindings(cs, typeVar); ASSERT_TRUE(bindings.Protocols.empty()); ASSERT_TRUE(bool(bindings.TransitiveProtocols)); verifyProtocolInferenceResults(*bindings.TransitiveProtocols, {protocolTy1}); } // Now, let's make sure that protocol requirements could be propagated // down conversion/equality chains through multiple hops. { // GPT1 is a subtype of GPT2 and GPT2 is convertible to a target type // variable, target should get both protocols inferred - P1 & P2. auto *typeVar = cs.createTypeVariable(cs.getConstraintLocator({}), /*options=*/0); cs.addConstraint(ConstraintKind::Subtype, GPT1, GPT2, cs.getConstraintLocator({})); cs.addConstraint(ConstraintKind::Conversion, typeVar, GPT1, cs.getConstraintLocator({})); auto bindings = inferBindings(cs, typeVar); ASSERT_TRUE(bindings.Protocols.empty()); ASSERT_TRUE(bool(bindings.TransitiveProtocols)); verifyProtocolInferenceResults(*bindings.TransitiveProtocols, {protocolTy1, protocolTy2}); } } /// Let's try a more complicated situation where there protocols /// are inferred from multiple sources on different levels of /// convertion chain. /// /// (P1) T0 T4 (T3) T6 (P4) /// \ / / /// T3 = T1 (P2) = T5 /// \ / /// T2 TEST_F(SemaTest, TestComplexTransitiveProtocolInference) { ConstraintSystemOptions options; ConstraintSystem cs(DC, options); auto *protocolTy1 = createProtocol("P1"); auto *protocolTy2 = createProtocol("P2"); auto *protocolTy3 = createProtocol("P3"); auto *protocolTy4 = createProtocol("P4"); auto *nilLocator = cs.getConstraintLocator({}); auto typeVar0 = cs.createTypeVariable(nilLocator, /*options=*/0); auto typeVar1 = cs.createTypeVariable(nilLocator, /*options=*/0); auto typeVar2 = cs.createTypeVariable(nilLocator, /*options=*/0); // Allow this type variable to be bound to l-value type to prevent // it from being merged with the rest of the type variables. auto typeVar3 = cs.createTypeVariable(nilLocator, /*options=*/TVO_CanBindToLValue); auto typeVar4 = cs.createTypeVariable(nilLocator, /*options=*/0); auto typeVar5 = cs.createTypeVariable(nilLocator, /*options=*/TVO_CanBindToLValue); auto typeVar6 = cs.createTypeVariable(nilLocator, /*options=*/0); cs.addConstraint(ConstraintKind::ConformsTo, typeVar0, protocolTy1, nilLocator); cs.addConstraint(ConstraintKind::ConformsTo, typeVar1, protocolTy2, nilLocator); cs.addConstraint(ConstraintKind::ConformsTo, typeVar4, protocolTy3, nilLocator); cs.addConstraint(ConstraintKind::ConformsTo, typeVar6, protocolTy4, nilLocator); // T3 <: T0, T3 <: T4 cs.addConstraint(ConstraintKind::Conversion, typeVar3, typeVar0, nilLocator); cs.addConstraint(ConstraintKind::Conversion, typeVar3, typeVar4, nilLocator); // T2 <: T3, T2 <: T1, T3 == T1 cs.addConstraint(ConstraintKind::Subtype, typeVar2, typeVar3, nilLocator); cs.addConstraint(ConstraintKind::Conversion, typeVar2, typeVar1, nilLocator); cs.addConstraint(ConstraintKind::Equal, typeVar3, typeVar1, nilLocator); // T1 == T5, T <: T6 cs.addConstraint(ConstraintKind::Equal, typeVar1, typeVar5, nilLocator); cs.addConstraint(ConstraintKind::Conversion, typeVar5, typeVar6, nilLocator); auto bindingsForT1 = inferBindings(cs, typeVar1); auto bindingsForT2 = inferBindings(cs, typeVar2); auto bindingsForT3 = inferBindings(cs, typeVar3); auto bindingsForT5 = inferBindings(cs, typeVar5); ASSERT_TRUE(bool(bindingsForT1.TransitiveProtocols)); verifyProtocolInferenceResults(*bindingsForT1.TransitiveProtocols, {protocolTy1, protocolTy3, protocolTy4}); ASSERT_TRUE(bool(bindingsForT2.TransitiveProtocols)); verifyProtocolInferenceResults( *bindingsForT2.TransitiveProtocols, {protocolTy1, protocolTy2, protocolTy3, protocolTy4}); ASSERT_TRUE(bool(bindingsForT3.TransitiveProtocols)); verifyProtocolInferenceResults( *bindingsForT3.TransitiveProtocols, {protocolTy1, protocolTy2, protocolTy3, protocolTy4}); ASSERT_TRUE(bool(bindingsForT5.TransitiveProtocols)); verifyProtocolInferenceResults( *bindingsForT5.TransitiveProtocols, {protocolTy1, protocolTy2, protocolTy3, protocolTy4}); }