Add validation to CollectionDifference decoder (#76876)

The `CollectionDifference` type has a few different invariants
that were not being validated when initializing using the type's
`Decodable` conformance, since the type was using the
autogenerated `Codable` implementation. This change provides
manual implementations of the `Encodable` and `Decodable`
requirements, and adds tests that validate the failure when trying
to decode invalid JSON for CollectionDifference (and a few other
types).
This commit is contained in:
Nate Cook
2024-10-11 19:23:55 -05:00
committed by GitHub
parent 3050916efc
commit 9265743adc
2 changed files with 179 additions and 10 deletions

View File

@@ -63,6 +63,12 @@ public struct CollectionDifference<ChangeElement> {
}
}
}
internal var _isRemoval: Bool {
switch self {
case .insert: false
case .remove: true
}
}
}
/// The insertions contained by this difference, from lowest offset to
@@ -404,13 +410,7 @@ extension CollectionDifference.Change: Codable where ChangeElement: Codable {
public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: _CodingKeys.self)
switch self {
case .remove(_, _, _):
try container.encode(true, forKey: .isRemove)
case .insert(_, _, _):
try container.encode(false, forKey: .isRemove)
}
try container.encode(_isRemoval, forKey: .isRemove)
try container.encode(_offset, forKey: .offset)
try container.encode(_element, forKey: .element)
try container.encode(_associatedOffset, forKey: .associatedOffset)
@@ -418,7 +418,37 @@ extension CollectionDifference.Change: Codable where ChangeElement: Codable {
}
@available(SwiftStdlib 5.1, *)
extension CollectionDifference: Codable where ChangeElement: Codable {}
extension CollectionDifference: Codable where ChangeElement: Codable {
private enum _CodingKeys: String, CodingKey {
case insertions
case removals
}
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: _CodingKeys.self)
var changes = try container.decode([Change].self, forKey: .removals)
let removalCount = changes.count
try changes.append(contentsOf: container.decode([Change].self, forKey: .insertions))
guard changes[..<removalCount].allSatisfy({ $0._isRemoval }),
changes[removalCount...].allSatisfy({ !$0._isRemoval }),
Self._validateChanges(changes)
else {
throw DecodingError.dataCorrupted(
DecodingError.Context(
codingPath: decoder.codingPath,
debugDescription: "Cannot decode an invalid collection difference"))
}
self.init(_validatedChanges: changes)
}
public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: _CodingKeys.self)
try container.encode(insertions, forKey: .insertions)
try container.encode(removals, forKey: .removals)
}
}
@available(SwiftStdlib 5.1, *)
extension CollectionDifference: Sendable where ChangeElement: Sendable { }

View File

@@ -118,6 +118,23 @@ func expectRoundTripEqualityThroughPlist<T : Codable>(for value: T, lineNumber:
expectRoundTripEquality(of: value, encode: encode, decode: decode, lineNumber: lineNumber)
}
func expectDecodingErrorViaJSON<T : Codable>(
type: T.Type,
json: String,
errorKind: DecodingErrorKind,
lineNumber: Int = #line)
{
let data = json.data(using: .utf8)!
do {
let value = try JSONDecoder().decode(T.self, from: data)
expectUnreachable(":\(lineNumber): Successfully decoded invalid \(T.self) <\(debugDescription(value))>")
} catch let error as DecodingError {
expectEqual(error.errorKind, errorKind, "\(#file):\(lineNumber): Incorrect error kind <\(error.errorKind)> not equal to expected <\(errorKind)>")
} catch {
expectUnreachableCatch(error, ":\(lineNumber): Unexpected error type when decoding \(T.self)")
}
}
// MARK: - Helper Types
// A wrapper around a UUID that will allow it to be encoded at the top level of an encoder.
struct UUIDCodingWrapper : Codable, Equatable, Hashable, CodingKeyRepresentable {
@@ -141,6 +158,24 @@ struct UUIDCodingWrapper : Codable, Equatable, Hashable, CodingKeyRepresentable
}
}
enum DecodingErrorKind {
case dataCorrupted
case keyNotFound
case typeMismatch
case valueNotFound
}
extension DecodingError {
var errorKind: DecodingErrorKind {
switch self {
case .dataCorrupted: .dataCorrupted
case .keyNotFound: .keyNotFound
case .typeMismatch: .typeMismatch
case .valueNotFound: .valueNotFound
}
}
}
// MARK: - Tests
class TestCodable : TestCodableSuper {
// MARK: - AffineTransform
@@ -392,6 +427,90 @@ class TestCodable : TestCodableSuper {
expectEqual(value.upperBound, decoded.upperBound, "\(#file):\(#line): Decoded ClosedRange upperBound <\(debugDescription(decoded))> not equal to original <\(debugDescription(value))>")
expectEqual(value.lowerBound, decoded.lowerBound, "\(#file):\(#line): Decoded ClosedRange lowerBound <\(debugDescription(decoded))> not equal to original <\(debugDescription(value))>")
}
func test_ClosedRange_JSON_Errors() {
expectDecodingErrorViaJSON(
type: ClosedRange<Int>.self,
json: "[5,0]",
errorKind: .dataCorrupted)
expectDecodingErrorViaJSON(
type: ClosedRange<Int>.self,
json: "[5,]",
errorKind: .valueNotFound)
expectDecodingErrorViaJSON(
type: ClosedRange<Int>.self,
json: "[0,Hello]",
errorKind: .dataCorrupted)
}
// MARK: - CollectionDifference
lazy var collectionDifferenceValues: [Int : CollectionDifference<Int>] = [
#line : [1, 2, 3].difference(from: [1, 2, 3]),
#line : [1, 2, 3].difference(from: [1, 2]),
#line : [1, 2, 3].difference(from: [2, 3, 4]),
#line : [1, 2, 3].difference(from: [6, 7, 8]),
]
func test_CollectionDifference_JSON() {
for (testLine, difference) in collectionDifferenceValues {
expectRoundTripEqualityThroughJSON(for: difference, lineNumber: testLine)
}
}
func test_CollectionDifference_Plist() {
for (testLine, difference) in collectionDifferenceValues {
expectRoundTripEqualityThroughPlist(for: difference, lineNumber: testLine)
}
}
func test_CollectionDifference_JSON_Errors() {
// Valid serialization:
// {
// "insertions" : [ { "associatedOffset" : null, "element" : 1, "isRemove" : false, "offset" : 0 } ],
// "removals" : [ { "associatedOffset" : null, "element" : 4, "isRemove" : true, "offset" : 2 } ]
// }
// Removal in insertion
expectDecodingErrorViaJSON(
type: CollectionDifference<Int>.self,
json: #"""
{
"insertions" : [ { "associatedOffset" : null, "element" : 1, "isRemove" : true, "offset" : 0 } ],
"removals" : [ { "associatedOffset" : null, "element" : 4, "isRemove" : true, "offset" : 2 } ]
}
"""#,
errorKind: .dataCorrupted)
// Repeated offset
expectDecodingErrorViaJSON(
type: CollectionDifference<Int>.self,
json: #"""
{
"insertions" : [ { "associatedOffset" : null, "element" : 1, "isRemove" : true, "offset" : 2 } ],
"removals" : [ { "associatedOffset" : null, "element" : 4, "isRemove" : true, "offset" : 2 } ]
}
"""#,
errorKind: .dataCorrupted)
// Invalid offset
expectDecodingErrorViaJSON(
type: CollectionDifference<Int>.self,
json: #"""
{
"insertions" : [ { "associatedOffset" : null, "element" : 1, "isRemove" : true, "offset" : -2 } ],
"removals" : [ { "associatedOffset" : null, "element" : 4, "isRemove" : true, "offset" : 2 } ]
}
"""#,
errorKind: .dataCorrupted)
// Invalid associated offset
expectDecodingErrorViaJSON(
type: CollectionDifference<Int>.self,
json: #"""
{
"insertions" : [ { "associatedOffset" : 2, "element" : 1, "isRemove" : true, "offset" : 0 } ],
"removals" : [ { "associatedOffset" : null, "element" : 4, "isRemove" : true, "offset" : 2 } ]
}
"""#,
errorKind: .dataCorrupted)
}
// MARK: - ContiguousArray
lazy var contiguousArrayValues: [Int : ContiguousArray<String>] = [
@@ -789,6 +908,21 @@ class TestCodable : TestCodableSuper {
expectEqual(value.upperBound, decoded.upperBound, "\(#file):\(#line): Decoded Range upperBound<\(debugDescription(decoded))> not equal to original <\(debugDescription(value))>")
expectEqual(value.lowerBound, decoded.lowerBound, "\(#file):\(#line): Decoded Range lowerBound<\(debugDescription(decoded))> not equal to original <\(debugDescription(value))>")
}
func test_Range_JSON_Errors() {
expectDecodingErrorViaJSON(
type: Range<Int>.self,
json: "[5,0]",
errorKind: .dataCorrupted)
expectDecodingErrorViaJSON(
type: Range<Int>.self,
json: "[5,]",
errorKind: .valueNotFound)
expectDecodingErrorViaJSON(
type: Range<Int>.self,
json: "[0,Hello]",
errorKind: .dataCorrupted)
}
// MARK: - TimeZone
lazy var timeZoneValues: [Int : TimeZone] = [
@@ -808,7 +942,7 @@ class TestCodable : TestCodableSuper {
expectRoundTripEqualityThroughPlist(for: timeZone, lineNumber: testLine)
}
}
// MARK: - URL
lazy var urlValues: [Int : URL] = {
var values: [Int : URL] = [
@@ -845,7 +979,7 @@ class TestCodable : TestCodableSuper {
expectRoundTripEqualityThroughPlist(for: url, lineNumber: testLine)
}
}
// MARK: - URLComponents
lazy var urlComponentsValues: [Int : URLComponents] = [
#line : URLComponents(),
@@ -1016,6 +1150,10 @@ var tests = [
"test_CGVector_Plist" : TestCodable.test_CGVector_Plist,
"test_ClosedRange_JSON" : TestCodable.test_ClosedRange_JSON,
"test_ClosedRange_Plist" : TestCodable.test_ClosedRange_Plist,
"test_ClosedRange_JSON_Errors" : TestCodable.test_ClosedRange_JSON_Errors,
"test_CollectionDifference_JSON" : TestCodable.test_CollectionDifference_JSON,
"test_CollectionDifference_Plist" : TestCodable.test_CollectionDifference_Plist,
"test_CollectionDifference_JSON_Errors" : TestCodable.test_CollectionDifference_JSON_Errors,
"test_ContiguousArray_JSON" : TestCodable.test_ContiguousArray_JSON,
"test_ContiguousArray_Plist" : TestCodable.test_ContiguousArray_Plist,
"test_DateComponents_JSON" : TestCodable.test_DateComponents_JSON,
@@ -1038,6 +1176,7 @@ var tests = [
"test_PartialRangeUpTo_Plist" : TestCodable.test_PartialRangeUpTo_Plist,
"test_Range_JSON" : TestCodable.test_Range_JSON,
"test_Range_Plist" : TestCodable.test_Range_Plist,
"test_Range_JSON_Errors" : TestCodable.test_Range_JSON_Errors,
"test_TimeZone_JSON" : TestCodable.test_TimeZone_JSON,
"test_TimeZone_Plist" : TestCodable.test_TimeZone_Plist,
"test_URL_JSON" : TestCodable.test_URL_JSON,