[TaskLocals] set task local value in synchronous function

This commit is contained in:
Konrad `ktoso` Malawski
2021-04-22 16:37:54 +09:00
parent f0781b1f8b
commit df5ff42d79
2 changed files with 90 additions and 1 deletions

View File

@@ -103,6 +103,29 @@ extension TaskLocal {
}
}
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
extension UnsafeCurrentTask {
/// Allows for executing a synchronous `body` while binding a task-local value
/// in the current task.
///
/// This function MUST NOT be invoked by any other task than the current task
/// represented by this object.
@discardableResult
public func withTaskLocal<Value: Sendable, R>(
_ access: TaskLocal<Value>.Access, boundTo valueDuringBody: Value,
do body: () throws -> R,
file: String = #file, line: UInt = #line) rethrows -> R {
// check if we're not trying to bind a value from an illegal context; this may crash
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
_taskLocalValuePush(self._task, key: access.key, value: valueDuringBody)
defer { _taskLocalValuePop(_task) }
return try body()
}
}
// ==== ------------------------------------------------------------------------
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)

View File

@@ -0,0 +1,66 @@
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-concurrency -parse-as-library %import-libdispatch) | %FileCheck %s
// REQUIRES: executable_test
// REQUIRES: concurrency
// REQUIRES: libdispatch
// rdar://76038845
// UNSUPPORTED: use_os_stdlib
// UNSUPPORTED: back_deployment_runtime
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
enum TL {
@TaskLocal(default: 0)
static var number
}
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
@discardableResult
func printTaskLocal<V>(
_ key: TaskLocal<V>.Access,
_ expected: V? = nil,
file: String = #file, line: UInt = #line
) -> V? {
let value = key.get()
print("\(key) (\(value)) at \(file):\(line)")
if let expected = expected {
assert("\(expected)" == "\(value)",
"Expected [\(expected)] but found: \(value), at \(file):\(line)")
}
return expected
}
// ==== ------------------------------------------------------------------------
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
func synchronous_bind() async {
func synchronous() {
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (1111)
withUnsafeCurrentTask { task in
guard let task = task else {
fatalError()
}
task.withTaskLocal(TL.number, boundTo: 2222) {
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (2222)
}
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (1111)
}
printTaskLocal(TL.number) // CHECK: TaskLocal<Int>.Access (1111)
}
await TL.number.withValue(1111) {
synchronous()
}
}
@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *)
@main struct Main {
static func main() async {
await synchronous_bind()
}
}