[Strict memory safety] Fix "unsafe" checking for the for..in loop

The `$generator` variable we create for the async for..in loop is
`nonisolated(unsafe)`, so ensure that we generate an `unsafe`
expression when we use it. This uncovered some inconsistencies in how
we do `unsafe` checking for for..in loops, so fix those.

Fixes rdar://154775389.
This commit is contained in:
Doug Gregor
2025-07-09 16:19:14 -07:00
parent 1e87656e40
commit 35628cb503
4 changed files with 23 additions and 10 deletions

View File

@@ -4728,10 +4728,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
} }
// Wrap the 'next' call in 'unsafe', if the for..in loop has that // Wrap the 'next' call in 'unsafe', if the for..in loop has that
// effect. // effect or if the loop is async (in which case the iterator variable
if (stmt->getUnsafeLoc().isValid()) { // is nonisolated(unsafe).
nextCall = new (ctx) UnsafeExpr( if (stmt->getUnsafeLoc().isValid() ||
stmt->getUnsafeLoc(), nextCall, Type(), /*implicit=*/true); (isAsync &&
ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) {
SourceLoc loc = stmt->getUnsafeLoc();
if (loc.isInvalid())
loc = stmt->getForLoc();
nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), /*implicit=*/true);
} }
// The iterator type must conform to IteratorProtocol. // The iterator type must conform to IteratorProtocol.

View File

@@ -2457,7 +2457,7 @@ private:
return ShouldRecurse; return ShouldRecurse;
} }
ShouldRecurse_t checkUnsafe(UnsafeExpr *E) { ShouldRecurse_t checkUnsafe(UnsafeExpr *E) {
return E->isImplicit() ? ShouldRecurse : ShouldNotRecurse; return ShouldNotRecurse;
} }
ShouldRecurse_t checkTry(TryExpr *E) { ShouldRecurse_t checkTry(TryExpr *E) {
return ShouldRecurse; return ShouldRecurse;
@@ -4626,10 +4626,6 @@ private:
diagnoseUnsafeUse(unsafeUse); diagnoseUnsafeUse(unsafeUse);
} }
} }
} else if (S->getUnsafeLoc().isValid()) {
// Extraneous "unsafe" on the sequence.
Ctx.Diags.diagnose(S->getUnsafeLoc(), diag::no_unsafe_in_unsafe_for)
.fixItRemove(S->getUnsafeLoc());
} }
return ShouldRecurse; return ShouldRecurse;
@@ -4689,7 +4685,10 @@ private:
return; return;
} }
Ctx.Diags.diagnose(E->getUnsafeLoc(), diag::no_unsafe_in_unsafe) Ctx.Diags.diagnose(E->getUnsafeLoc(),
forEachNextCallExprs.contains(E)
? diag::no_unsafe_in_unsafe_for
: diag::no_unsafe_in_unsafe)
.fixItRemove(E->getUnsafeLoc()); .fixItRemove(E->getUnsafeLoc());
} }

View File

@@ -98,6 +98,8 @@ func testUnsafeAsSequenceForEach() {
for _ in unsafe uas { } // expected-warning{{for-in loop uses unsafe constructs but is not marked with 'unsafe'}}{{documentation-file=strict-memory-safety}}{{7-7=unsafe }} for _ in unsafe uas { } // expected-warning{{for-in loop uses unsafe constructs but is not marked with 'unsafe'}}{{documentation-file=strict-memory-safety}}{{7-7=unsafe }}
for unsafe _ in unsafe uas { } // okay for unsafe _ in unsafe uas { } // okay
for unsafe _ in [1, 2, 3] { } // expected-warning{{no unsafe operations occur within 'unsafe' for-in loop}}
} }
func testForInUnsafeAmbiguity(_ integers: [Int]) { func testForInUnsafeAmbiguity(_ integers: [Int]) {

View File

@@ -55,3 +55,10 @@ open class SyntaxVisitor {
open class SyntaxAnyVisitor: SyntaxVisitor { open class SyntaxAnyVisitor: SyntaxVisitor {
override open func visit(_ token: TokenSyntax) { } override open func visit(_ token: TokenSyntax) { }
} }
@available(SwiftStdlib 5.1, *)
func testMemorySafetyWithForLoop() async {
let (stream, continuation) = AsyncStream<Int>.makeStream()
for await _ in stream {}
_ = continuation
}