[Stdlib] Improves sort and sorted to accept throwing clousre

This commit resolves https://bugs.swift.org/browse/SR-715
This commit is contained in:
codestergit
2017-02-24 13:00:52 +05:30
parent 3019898e06
commit aa9e9edc8a
7 changed files with 324 additions and 42 deletions

View File

@@ -26,6 +26,8 @@ public struct PartitionExhaustiveTest {
}
}
enum SillyError : Error { case JazzHands }
public let partitionExhaustiveTests = [
PartitionExhaustiveTest([]),
PartitionExhaustiveTest([ 10 ]),
@@ -48,6 +50,20 @@ public let partitionExhaustiveTests = [
PartitionExhaustiveTest([ 10, 20, 30, 40, 50, 60 ]),
]
//Random collection of 30 elements
public let largeElementSortTests = [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30],
[30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18,
17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
[30, 29, 28, 27, 26, 25, 20, 19, 18, 5, 4, 3, 2, 1,
15, 14, 13, 12, 11, 10, 9, 24, 23, 22, 21, 8, 17, 16, 7, 6],
[30, 29, 25, 20, 19, 18, 5, 4, 3, 2, 1, 28, 27, 26,
15, 14, 13, 12, 24, 23, 22, 21, 8, 17, 16, 7, 6, 11, 10, 9],
[3, 2, 1, 20, 19, 18, 5, 4, 28, 27, 26, 11, 10, 9,
15, 14, 13, 12, 24, 23, 22, 21, 8, 17, 16, 7, 6, 30, 29, 25],
]
public func withInvalidOrderings(_ body: (@escaping (Int, Int) -> Bool) -> Void) {
// Test some ordering predicates that don't create strict weak orderings
body { (_,_) in true }
@@ -463,6 +479,61 @@ self.test("\(testNamePrefix)._withUnsafeMutableBufferPointerIfSupported()/semant
// sort()
//===----------------------------------------------------------------------===//
func checkSortedPredicateThrow(
sequence: [Int],
lessImpl: ((Int, Int) -> Bool),
throwIndex: Int
) {
let extract = extractValue
let throwElement = sequence[throwIndex]
var thrown = false
let elements: [OpaqueValue<Int>] =
zip(sequence, 0..<sequence.count).map {
OpaqueValue($0, identity: $1)
}
var result: [C.Iterator.Element] = []
let c = makeWrappedCollection(elements)
let closureLifetimeTracker = LifetimeTracked(0)
do {
result = try c.sorted {
(lhs, rhs) throws -> Bool in
_blackHole(closureLifetimeTracker)
if throwElement == extractValue(rhs).value {
thrown = true
throw SillyError.JazzHands
}
return lessImpl(extractValue(lhs).value, extractValue(rhs).value)
}
} catch {}
// Check that the original collection is unchanged.
expectEqualSequence(
elements.map { $0.value },
c.map { extract($0).value })
// If `sorted` throws then result will be empty else
// returned result must be sorted.
if thrown {
expectEqual(0, result.count)
} else {
// Check that the elements are sorted.
let extractedResult = result.map(extract)
for i in extractedResult.indices {
if i != extractedResult.index(before: extractedResult.endIndex) {
let first = extractedResult[i].value
let second = extractedResult[extractedResult.index(after: i)].value
let result = lessImpl(second, first)
if result == true {
print("yep ** Result should be true \(result)")
} else {
print("yep ** Test passed reult is false")
}
expectFalse(result)
}
}
}
}
% for predicate in [False, True]:
self.test("\(testNamePrefix).sorted/DispatchesThrough_withUnsafeMutableBufferPointerIfSupported/${'Predicate' if predicate else 'WhereElementIsComparable'}") {
@@ -587,6 +658,30 @@ self.test("\(testNamePrefix).sorted/${'Predicate' if predicate else 'WhereElemen
}
}
self.test("\(testNamePrefix).sorted/ThrowingPredicate") {
for test in partitionExhaustiveTests {
forAllPermutations(test.sequence) { (sequence) in
for i in 0..<sequence.count {
checkSortedPredicateThrow(
sequence: sequence,
lessImpl: { $0 < $1 },
throwIndex: i)
}
}
}
}
self.test("\(testNamePrefix).sorted/ThrowingPredicateWithLargeNumberElememts") {
for sequence in largeElementSortTests {
for i in 0..<sequence.count {
checkSortedPredicateThrow(
sequence: sequence,
lessImpl: { $0 < $1 },
throwIndex: i)
}
}
}
% end
//===----------------------------------------------------------------------===//
@@ -921,6 +1016,36 @@ self.test("\(testNamePrefix).partition/DispatchesThrough_withUnsafeMutableBuffer
// sort()
//===----------------------------------------------------------------------===//
func checkSortPredicateThrow(
sequence: [Int],
lessImpl: ((Int, Int) -> Bool),
throwIndex: Int
) {
let extract = extractValue
let throwElement = sequence[throwIndex]
let elements: [OpaqueValue<Int>] =
zip(sequence, 0..<sequence.count).map {
OpaqueValue($0, identity: $1)
}
var c = makeWrappedCollection(elements)
let closureLifetimeTracker = LifetimeTracked(0)
do {
try c.sort {
(lhs, rhs) throws -> Bool in
_blackHole(closureLifetimeTracker)
if throwElement == extractValue(rhs).value {
throw SillyError.JazzHands
}
return lessImpl(extractValue(lhs).value, extractValue(rhs).value)
}
} catch {}
//Check no element should lost and added
expectEqualsUnordered(
sequence,
c.map { extract($0).value })
}
% for predicate in [False, True]:
func checkSortInPlace_${'Predicate' if predicate else 'WhereElementIsComparable'}(
@@ -1005,6 +1130,30 @@ self.test("\(testNamePrefix).sort/${'Predicate' if predicate else 'WhereElementI
}
}
self.test("\(testNamePrefix).sort/ThrowingPredicate") {
for test in partitionExhaustiveTests {
forAllPermutations(test.sequence) { (sequence) in
for i in 0..<sequence.count {
checkSortPredicateThrow(
sequence: sequence,
lessImpl: { $0 < $1 },
throwIndex: i)
}
}
}
}
self.test("\(testNamePrefix).sort/ThrowingPredicateWithLargeNumberElements") {
for sequence in largeElementSortTests {
for i in 0..<sequence.count {
checkSortPredicateThrow(
sequence: sequence,
lessImpl: { $0 < $1 },
throwIndex: i)
}
}
}
% end
//===----------------------------------------------------------------------===//

View File

@@ -339,10 +339,10 @@ ${orderingExplanation}
@_inlineable
public func sorted(
by areInIncreasingOrder:
(${IElement}, ${IElement}) -> Bool
) -> [Iterator.Element] {
(${IElement}, ${IElement}) throws -> Bool
) rethrows -> [Iterator.Element] {
var result = ContiguousArray(self)
result.sort(by: areInIncreasingOrder)
try result.sort(by: areInIncreasingOrder)
return Array(result)
}
}
@@ -398,6 +398,9 @@ extension MutableCollection where Self : RandomAccessCollection {
/// Sorts the collection in place, using the given predicate as the
/// comparison between elements.
///
/// This method can take throwing closure. If closure throws error while sorting,
/// order of elements may change. No elements will be lost.
///
/// When you want to sort a collection of elements that doesn't conform to
/// the `Comparable` protocol, pass a closure to this method that returns
/// `true` when the first element passed should be ordered before the
@@ -452,19 +455,19 @@ ${orderingExplanation}
@_inlineable
public mutating func sort(
by areInIncreasingOrder:
(${IElement}, ${IElement}) -> Bool
) {
(${IElement}, ${IElement}) throws -> Bool
) rethrows {
let didSortUnsafeBuffer: Void? =
_withUnsafeMutableBufferPointerIfSupported {
try _withUnsafeMutableBufferPointerIfSupported {
(baseAddress, count) -> Void in
var bufferPointer =
UnsafeMutableBufferPointer(start: baseAddress, count: count)
bufferPointer.sort(by: areInIncreasingOrder)
try bufferPointer.sort(by: areInIncreasingOrder)
return ()
}
if didSortUnsafeBuffer == nil {
_introSort(
try _introSort(
&self,
subRange: startIndex..<endIndex,
by: areInIncreasingOrder)

View File

@@ -13,7 +13,7 @@
%{
def cmp(a, b, p):
if p:
return "areInIncreasingOrder(" + a + ", " + b + ")"
return "(try areInIncreasingOrder(" + a + ", " + b + "))"
else:
return "(" + a + " < " + b + ")"
@@ -24,14 +24,23 @@ def cmp(a, b, p):
// need such a predicate.
% preds = [True, False]
% for p in preds:
%{
if p:
rethrows_ = "rethrows"
try_ = "try"
else:
rethrows_ = ""
try_ = ""
}%
@_inlineable
@_versioned
func _insertionSort<C>(
_ elements: inout C,
subRange range: Range<C.Index>
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
) where
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) throws -> Bool" if p else ""}
) ${rethrows_}
where
C : MutableCollection & BidirectionalCollection
${"" if p else ", C.Iterator.Element : Comparable"} {
@@ -55,10 +64,23 @@ func _insertionSort<C>(
repeat {
let predecessor: C.Iterator.Element = elements[elements.index(before: i)]
% if p:
// If clouser throws the error, We catch the error put the element at right
// place and rethrow the error.
do {
// if x doesn't belong before y, we've found its position
if !${cmp("x", "predecessor", p)} {
break
}
} catch {
elements[i] = x
throw error
}
% else:
if !${cmp("x", "predecessor", p)} {
break
}
% end
// Move y forward
elements[i] = predecessor
@@ -87,8 +109,8 @@ public // @testable
func _sort3<C>(
_ elements: inout C,
_ a: C.Index, _ b: C.Index, _ c: C.Index
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
)
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) throws -> Bool" if p else ""}
) ${rethrows_}
where
C : MutableCollection & RandomAccessCollection
${"" if p else ", C.Iterator.Element : Comparable"}
@@ -159,8 +181,8 @@ func _sort3<C>(
func _partition<C>(
_ elements: inout C,
subRange range: Range<C.Index>
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
) -> C.Index
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) throws -> Bool" if p else ""}
) ${rethrows_} -> C.Index
where
C : MutableCollection & RandomAccessCollection
${"" if p else ", C.Iterator.Element : Comparable"}
@@ -172,7 +194,7 @@ func _partition<C>(
// as the pivot for the partition.
let half = numericCast(elements.distance(from: lo, to: hi)) as UInt / 2
let mid = elements.index(lo, offsetBy: numericCast(half))
_sort3(&elements, lo, mid, hi
${try_} _sort3(&elements, lo, mid, hi
${", by: areInIncreasingOrder" if p else ""})
let pivot = elements[mid]
@@ -210,8 +232,9 @@ public // @testable
func _introSort<C>(
_ elements: inout C,
subRange range: Range<C.Index>
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
) where
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) throws -> Bool" if p else ""}
) ${rethrows_}
where
C : MutableCollection & RandomAccessCollection
${"" if p else ", C.Iterator.Element : Comparable"} {
@@ -223,7 +246,7 @@ func _introSort<C>(
// Set max recursion depth to 2*floor(log(N)), as suggested in the introsort
// paper: http://www.cs.rpi.edu/~musser/gp/introsort.ps
let depthLimit = 2 * _floorLog2(Int64(count))
_introSortImpl(
${try_} _introSortImpl(
&elements,
subRange: range,
${"by: areInIncreasingOrder," if p else ""}
@@ -235,22 +258,23 @@ func _introSort<C>(
func _introSortImpl<C>(
_ elements: inout C,
subRange range: Range<C.Index>
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""},
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) throws -> Bool" if p else ""},
depthLimit: Int
) where
) ${rethrows_}
where
C : MutableCollection & RandomAccessCollection
${"" if p else ", C.Iterator.Element : Comparable"} {
// Insertion sort is better at handling smaller regions.
if elements.distance(from: range.lowerBound, to: range.upperBound) < 20 {
_insertionSort(
${try_} _insertionSort(
&elements,
subRange: range
${", by: areInIncreasingOrder" if p else ""})
return
}
if depthLimit == 0 {
_heapSort(
${try_} _heapSort(
&elements,
subRange: range
${", by: areInIncreasingOrder" if p else ""})
@@ -260,16 +284,16 @@ func _introSortImpl<C>(
// Partition and sort.
// We don't check the depthLimit variable for underflow because this variable
// is always greater than zero (see check above).
let partIdx: C.Index = _partition(
let partIdx: C.Index = ${try_} _partition(
&elements,
subRange: range
${", by: areInIncreasingOrder" if p else ""})
_introSortImpl(
${try_} _introSortImpl(
&elements,
subRange: range.lowerBound..<partIdx,
${"by: areInIncreasingOrder, " if p else ""}
depthLimit: depthLimit &- 1)
_introSortImpl(
${try_} _introSortImpl(
&elements,
subRange: partIdx..<range.upperBound,
${"by: areInIncreasingOrder, " if p else ""}
@@ -282,8 +306,9 @@ func _siftDown<C>(
_ elements: inout C,
index: C.Index,
subRange range: Range<C.Index>
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
) where
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) throws -> Bool" if p else ""}
) ${rethrows_}
where
C : MutableCollection & RandomAccessCollection
${"" if p else ", C.Iterator.Element : Comparable"} {
@@ -310,7 +335,7 @@ func _siftDown<C>(
// down.
if largest != index {
swap(&elements[index], &elements[largest])
_siftDown(
${try_} _siftDown(
&elements,
index: largest,
subRange: range
@@ -323,8 +348,9 @@ func _siftDown<C>(
func _heapify<C>(
_ elements: inout C,
subRange range: Range<C.Index>
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
) where
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) throws -> Bool" if p else ""}
) ${rethrows_}
where
C : MutableCollection & RandomAccessCollection
${"" if p else ", C.Iterator.Element : Comparable"} {
// Here we build a heap starting from the lowest nodes and moving to the root.
@@ -341,7 +367,7 @@ func _heapify<C>(
while node != root {
elements.formIndex(before: &node)
_siftDown(
${try_} _siftDown(
&elements,
index: node,
subRange: range
@@ -354,20 +380,21 @@ func _heapify<C>(
func _heapSort<C>(
_ elements: inout C,
subRange range: Range<C.Index>
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) -> Bool" if p else ""}
) where
${", by areInIncreasingOrder: (C.Iterator.Element, C.Iterator.Element) throws -> Bool" if p else ""}
) ${rethrows_}
where
C : MutableCollection & RandomAccessCollection
${"" if p else ", C.Iterator.Element : Comparable"} {
var hi = range.upperBound
let lo = range.lowerBound
_heapify(
${try_} _heapify(
&elements,
subRange: range
${", by: areInIncreasingOrder" if p else ""})
elements.formIndex(before: &hi)
while hi != lo {
swap(&elements[lo], &elements[hi])
_siftDown(
${try_} _siftDown(
&elements,
index: lo,
subRange: lo..<hi

View File

@@ -39,4 +39,4 @@ demo()
// CHECK-O0-NOT: DW_OP_bit_piece
// CHECK-O0: !DILocalVariable(name: "b", arg: 2{{.*}} line: 17,
// CHECK-O0-NOT: DW_OP_bit_piece
// CHECK-O0: !DISubprogram(linkageName: "_T0S2SSbIxxxd_S2SSbIxiid_TR",
// CHECK-O0: !DISubprogram(linkageName: "_T0SSSSSbs5Error_pIxxxdzo_SSSSSbsAA_pIxiidzo_TR",

View File

@@ -37,9 +37,9 @@ func foo3(a: Float, b: Bool) {}
// CHECK-REPLACEMENT1: <Group>Collection/Array</Group>
// CHECK-REPLACEMENT1: <Declaration>{{.*}}func sorted() -&gt; [<Type usr="s:Si">Int</Type>]</Declaration>
// CHECK-REPLACEMENT1: RELATED BEGIN
// CHECK-REPLACEMENT1: sorted(by: (Int, Int) -&gt; Bool) -&gt; [Int]</RelatedName>
// CHECK-REPLACEMENT1: sorted(by: (Int, Int) throws -&gt; Bool) rethrows -&gt; [Int]</RelatedName>
// CHECK-REPLACEMENT1: sorted() -&gt; [Int]</RelatedName>
// CHECK-REPLACEMENT1: sorted(by: (Int, Int) -&gt; Bool) -&gt; [Int]</RelatedName>
// CHECK-REPLACEMENT1: sorted(by: (Int, Int) throws -&gt; Bool) rethrows -&gt; [Int]</RelatedName>
// CHECK-REPLACEMENT1: RELATED END
// RUN: %sourcekitd-test -req=cursor -pos=9:8 %s -- %s %mcp_opt %clang-importer-sdk | %FileCheck -check-prefix=CHECK-REPLACEMENT2 %s
@@ -51,7 +51,7 @@ func foo3(a: Float, b: Bool) {}
// CHECK-REPLACEMENT3: func sorted(by areInIncreasingOrder: (<Type usr="s:13cursor_stdlib2S1V">S1</Type>
// CHECK-REPLACEMENT3: sorted() -&gt; [S1]</RelatedName>
// CHECK-REPLACEMENT3: sorted() -&gt; [S1]</RelatedName>
// CHECK-REPLACEMENT3: sorted(by: (S1, S1) -&gt; Bool) -&gt; [S1]</RelatedName>
// CHECK-REPLACEMENT3: sorted(by: (S1, S1) throws -&gt; Bool) rethrows -&gt; [S1]</RelatedName>
// RUN: %sourcekitd-test -req=cursor -pos=18:8 %s -- %s %mcp_opt %clang-importer-sdk | %FileCheck -check-prefix=CHECK-REPLACEMENT4 %s
// CHECK-REPLACEMENT4: <Group>Collection/Array</Group>

View File

@@ -28409,6 +28409,7 @@
"usr": "s:FEsPs8Sequence6sortedFT2byFTWx8Iterator7Element_WxS0_S1___Sb_GSaWxS0_S1___",
"location": "",
"moduleName": "Swift",
"throwing": true,
"children": [
{
"kind": "TypeNominal",
@@ -51407,6 +51408,7 @@
"usr": "s:FEsPs17MutableCollection6sortedFT2byFTWx8Iterator7Element_WxS0_S1___Sb_GSaWxS0_S1___",
"location": "",
"moduleName": "Swift",
"throwing": true,
"children": [
{
"kind": "TypeNominal",
@@ -51479,6 +51481,7 @@
"usr": "s:FesRxs17MutableCollectionxs22RandomAccessCollectionrS_4sortFT2byFTWxPs10Collection8Iterator7Element_WxS2_S3___Sb_T_",
"location": "",
"moduleName": "Swift",
"throwing": true,
"mutating": true,
"children": [
{

View File

@@ -340,4 +340,104 @@ ErrorHandlingTests.test("ErrorHandling/Collection map") {
}
}
ErrorHandlingTests.test("ErrorHandling/sort") {
var collection = Array(1...5)
forAllPermutations(collection) { sequence in
for i in 0..<sequence.count {
var s = sequence
let throwElment = sequence[i]
do {
try s.sort { (a, b) throws -> Bool in
if b == throwElment {
throw SillyError.JazzHands
}
return a < b
}
} catch {}
//Check no element should lost and added
expectEqualsUnordered(collection, s)
}
}
}
ErrorHandlingTests.test("ErrorHandling/sorted") {
var collection = Array(1...5)
forAllPermutations(collection) { sequence in
for i in 0..<sequence.count {
var s = sequence
var thrown = false
let throwElment = sequence[i]
var result: [Int] = []
do {
result = try s.sorted { (a, b) throws -> Bool in
if b == throwElment {
thrown = true
throw SillyError.JazzHands
}
return a < b
}
} catch {}
//Check actual sequence should not mutate
expectEqualSequence(sequence, s)
if thrown {
//Check result should be empty when thrown
expectEqualSequence([], result)
} else {
//Check result should be sorted when not thrown
expectEqualSequence(collection, result)
}
}
}
}
ErrorHandlingTests.test("ErrorHandling/sort") {
var collection = Array(1...5)
forAllPermutations(collection) { sequence in
for i in 0..<sequence.count {
var s = sequence
let throwElment = sequence[i]
do {
try s.sort { (a, b) throws -> Bool in
if b == throwElment {
throw SillyError.JazzHands
}
return a < b
}
} catch {}
//Check no element should lost and added
expectEqualsUnordered(collection, s)
}
}
}
ErrorHandlingTests.test("ErrorHandling/sorted") {
var collection = Array(1...5)
forAllPermutations(collection) { sequence in
for i in 0..<sequence.count {
var s = sequence
var thrown = false
let throwElment = sequence[i]
var result: [Int] = []
do {
result = try s.sorted { (a, b) throws -> Bool in
if b == throwElment {
thrown = true
throw SillyError.JazzHands
}
return a < b
}
} catch {}
//Check actual sequence should not mutate
expectEqualSequence(sequence, s)
if thrown {
//Check result should be empty when thrown
expectEqualSequence([], result)
} else {
//Check result should be sorted when not thrown
expectEqualSequence(collection, result)
}
}
}
}
runAllTests()