diff --git a/ios/MullvadVPN/View controllers/SelectLocation/AllLocationDataSource.swift b/ios/MullvadVPN/View controllers/SelectLocation/AllLocationDataSource.swift index 2cde9b5b6965..a6e9e1bab06d 100644 --- a/ios/MullvadVPN/View controllers/SelectLocation/AllLocationDataSource.swift +++ b/ios/MullvadVPN/View controllers/SelectLocation/AllLocationDataSource.swift @@ -44,11 +44,11 @@ class AllLocationDataSource: LocationDataSourceProtocol { return switch location { case let .country(countryCode): - rootNode.descendantNodeFor(code: countryCode) - case let .city(_, cityCode): - rootNode.descendantNodeFor(code: cityCode) + rootNode.descendantNodeFor(codes: [countryCode]) + case let .city(countryCode, cityCode): + rootNode.descendantNodeFor(codes: [countryCode, cityCode]) case let .hostname(_, _, hostCode): - rootNode.descendantNodeFor(code: hostCode) + rootNode.descendantNodeFor(codes: [hostCode]) } } @@ -62,7 +62,7 @@ class AllLocationDataSource: LocationDataSourceProtocol { case let .country(countryCode): let countryNode = CountryLocationNode( name: serverLocation.country, - code: countryCode, + code: LocationNode.combineNodeCodes([countryCode]), locations: [location] ) @@ -72,7 +72,11 @@ class AllLocationDataSource: LocationDataSourceProtocol { } case let .city(countryCode, cityCode): - let cityNode = CityLocationNode(name: serverLocation.city, code: cityCode, locations: [location]) + let cityNode = CityLocationNode( + name: serverLocation.city, + code: LocationNode.combineNodeCodes([countryCode, cityCode]), + locations: [location] + ) if let countryNode = rootNode.countryFor(code: countryCode), !countryNode.children.contains(cityNode) { @@ -82,10 +86,14 @@ class AllLocationDataSource: LocationDataSourceProtocol { } case let .hostname(countryCode, cityCode, hostCode): - let hostNode = HostLocationNode(name: relay.hostname, code: hostCode, locations: [location]) + let hostNode = HostLocationNode( + name: relay.hostname, + code: LocationNode.combineNodeCodes([hostCode]), + locations: [location] + ) if let countryNode = rootNode.countryFor(code: countryCode), - let cityNode = countryNode.cityFor(code: cityCode), + let cityNode = countryNode.cityFor(codes: [countryCode, cityCode]), !cityNode.children.contains(hostNode) { hostNode.parent = cityNode cityNode.children.append(hostNode) diff --git a/ios/MullvadVPN/View controllers/SelectLocation/CustomListsDataSource.swift b/ios/MullvadVPN/View controllers/SelectLocation/CustomListsDataSource.swift index 51a26401ea71..a41f0eb9e56b 100644 --- a/ios/MullvadVPN/View controllers/SelectLocation/CustomListsDataSource.swift +++ b/ios/MullvadVPN/View controllers/SelectLocation/CustomListsDataSource.swift @@ -43,7 +43,7 @@ class CustomListsDataSource: LocationDataSourceProtocol { // Since LocationCellViewModel partly depends on LocationNode.code for // equality, each node code needs to be prefixed with the code of its // parent custom list to uphold this. - node.code = "\(listNode.code)-\(node.code)" + node.code = LocationNode.combineNodeCodes([listNode.code, node.code]) } return listNode @@ -51,21 +51,21 @@ class CustomListsDataSource: LocationDataSourceProtocol { } func node(by locations: [RelayLocation], for customList: CustomList) -> LocationNode? { - guard let customListNode = nodes.first(where: { $0.name == customList.name }) + guard let listNode = nodes.first(where: { $0.name == customList.name }) else { return nil } if locations.count > 1 { - return customListNode + return listNode } else { // Each search for descendant nodes needs the parent custom list node code to be // prefixed in order to get a match. See comment in reload() above. return switch locations.first { case let .country(countryCode): - customListNode.descendantNodeFor(code: "\(customListNode.code)-\(countryCode)") - case let .city(_, cityCode): - customListNode.descendantNodeFor(code: "\(customListNode.code)-\(cityCode)") + listNode.descendantNodeFor(codes: [listNode.code, countryCode]) + case let .city(countryCode, cityCode): + listNode.descendantNodeFor(codes: [listNode.code, countryCode, cityCode]) case let .hostname(_, _, hostCode): - customListNode.descendantNodeFor(code: "\(customListNode.code)-\(hostCode)") + listNode.descendantNodeFor(codes: [listNode.code, hostCode]) case .none: nil } @@ -91,12 +91,12 @@ class CustomListsDataSource: LocationDataSourceProtocol { case let .city(countryCode, cityCode): rootNode .countryFor(code: countryCode)?.copy(withParent: parentNode) - .cityFor(code: cityCode) + .cityFor(codes: [countryCode, cityCode]) case let .hostname(countryCode, cityCode, hostCode): rootNode .countryFor(code: countryCode)?.copy(withParent: parentNode) - .cityFor(code: cityCode)? + .cityFor(codes: [countryCode, cityCode])? .hostFor(code: hostCode) } } diff --git a/ios/MullvadVPN/View controllers/SelectLocation/LocationCellFactory.swift b/ios/MullvadVPN/View controllers/SelectLocation/LocationCellFactory.swift index 81bd14b052c3..1d0c1f9742c9 100644 --- a/ios/MullvadVPN/View controllers/SelectLocation/LocationCellFactory.swift +++ b/ios/MullvadVPN/View controllers/SelectLocation/LocationCellFactory.swift @@ -40,7 +40,7 @@ final class LocationCellFactory: CellFactoryProtocol { func configureCell(_ cell: UITableViewCell, item: LocationCellViewModel, indexPath: IndexPath) { guard let cell = cell as? LocationCell else { return } - cell.accessibilityIdentifier = item.node.name + cell.accessibilityIdentifier = item.node.code cell.locationLabel.text = item.node.name cell.showsCollapseControl = !item.node.children.isEmpty cell.isExpanded = item.node.showsChildren diff --git a/ios/MullvadVPN/View controllers/SelectLocation/LocationCellViewModel.swift b/ios/MullvadVPN/View controllers/SelectLocation/LocationCellViewModel.swift index 711f6a8a1279..2425413fdd7d 100644 --- a/ios/MullvadVPN/View controllers/SelectLocation/LocationCellViewModel.swift +++ b/ios/MullvadVPN/View controllers/SelectLocation/LocationCellViewModel.swift @@ -13,6 +13,11 @@ struct LocationCellViewModel: Hashable { let node: LocationNode var indentationLevel = 0 + func hash(into hasher: inout Hasher) { + hasher.combine(section) + hasher.combine(node) + } + static func == (lhs: Self, rhs: Self) -> Bool { lhs.node == rhs.node && lhs.section == rhs.section diff --git a/ios/MullvadVPN/View controllers/SelectLocation/LocationNode.swift b/ios/MullvadVPN/View controllers/SelectLocation/LocationNode.swift index 8b123c55ee90..38e6197fcd36 100644 --- a/ios/MullvadVPN/View controllers/SelectLocation/LocationNode.swift +++ b/ios/MullvadVPN/View controllers/SelectLocation/LocationNode.swift @@ -46,16 +46,18 @@ extension LocationNode { self.code == code ? self : children.first(where: { $0.code == code }) } - func cityFor(code: String) -> LocationNode? { - self.code == code ? self : children.first(where: { $0.code == code }) + func cityFor(codes: [String]) -> LocationNode? { + let combinedCode = Self.combineNodeCodes(codes) + return self.code == combinedCode ? self : children.first(where: { $0.code == combinedCode }) } func hostFor(code: String) -> LocationNode? { self.code == code ? self : children.first(where: { $0.code == code }) } - func descendantNodeFor(code: String) -> LocationNode? { - self.code == code ? self : children.compactMap { $0.descendantNodeFor(code: code) }.first + func descendantNodeFor(codes: [String]) -> LocationNode? { + let combinedCode = Self.combineNodeCodes(codes) + return self.code == combinedCode ? self : children.compactMap { $0.descendantNodeFor(codes: codes) }.first } func forEachDescendant(do callback: (LocationNode) -> Void) { @@ -71,6 +73,10 @@ extension LocationNode { parent.forEachAncestor(do: callback) } } + + static func combineNodeCodes(_ codes: [String]) -> String { + codes.joined(separator: "-") + } } extension LocationNode { diff --git a/ios/MullvadVPNTests/Location/AllLocationsDataSourceTests.swift b/ios/MullvadVPNTests/Location/AllLocationsDataSourceTests.swift index bc343a3db765..5149fa1261ad 100644 --- a/ios/MullvadVPNTests/Location/AllLocationsDataSourceTests.swift +++ b/ios/MullvadVPNTests/Location/AllLocationsDataSourceTests.swift @@ -21,18 +21,18 @@ class AllLocationsDataSourceTests: XCTestCase { let rootNode = RootLocationNode(children: dataSource.nodes) // Testing a selection. - XCTAssertNotNil(rootNode.descendantNodeFor(code: "se")) - XCTAssertNotNil(rootNode.descendantNodeFor(code: "dal")) - XCTAssertNotNil(rootNode.descendantNodeFor(code: "es1-wireguard")) - XCTAssertNotNil(rootNode.descendantNodeFor(code: "se2-wireguard")) + XCTAssertNotNil(rootNode.descendantNodeFor(codes: ["se"])) + XCTAssertNotNil(rootNode.descendantNodeFor(codes: ["us", "dal"])) + XCTAssertNotNil(rootNode.descendantNodeFor(codes: ["es1-wireguard"])) + XCTAssertNotNil(rootNode.descendantNodeFor(codes: ["se2-wireguard"])) } func testSearch() throws { let nodes = dataSource.search(by: "got") let rootNode = RootLocationNode(children: nodes) - XCTAssertTrue(rootNode.descendantNodeFor(code: "got")?.isHiddenFromSearch == false) - XCTAssertTrue(rootNode.descendantNodeFor(code: "sto")?.isHiddenFromSearch == true) + XCTAssertTrue(rootNode.descendantNodeFor(codes: ["se", "got"])?.isHiddenFromSearch == false) + XCTAssertTrue(rootNode.descendantNodeFor(codes: ["se", "sto"])?.isHiddenFromSearch == true) } func testSearchWithEmptyText() throws { @@ -42,15 +42,15 @@ class AllLocationsDataSourceTests: XCTestCase { func testNodeByLocation() throws { var nodeByLocation = dataSource.node(by: .country("es")) - var nodeByCode = dataSource.nodes.first?.descendantNodeFor(code: "es") + var nodeByCode = dataSource.nodes.first?.descendantNodeFor(codes: ["es"]) XCTAssertEqual(nodeByLocation, nodeByCode) nodeByLocation = dataSource.node(by: .city("es", "mad")) - nodeByCode = dataSource.nodes.first?.descendantNodeFor(code: "mad") + nodeByCode = dataSource.nodes.first?.descendantNodeFor(codes: ["es", "mad"]) XCTAssertEqual(nodeByLocation, nodeByCode) nodeByLocation = dataSource.node(by: .hostname("es", "mad", "es1-wireguard")) - nodeByCode = dataSource.nodes.first?.descendantNodeFor(code: "es1-wireguard") + nodeByCode = dataSource.nodes.first?.descendantNodeFor(codes: ["es1-wireguard"]) XCTAssertEqual(nodeByLocation, nodeByCode) } } diff --git a/ios/MullvadVPNTests/Location/CustomListsDataSourceTests.swift b/ios/MullvadVPNTests/Location/CustomListsDataSourceTests.swift index 9085ec65d4ef..0120322702e9 100644 --- a/ios/MullvadVPNTests/Location/CustomListsDataSourceTests.swift +++ b/ios/MullvadVPNTests/Location/CustomListsDataSourceTests.swift @@ -22,21 +22,21 @@ class CustomListsDataSourceTests: XCTestCase { let nodes = dataSource.nodes let netflixNode = try XCTUnwrap(nodes.first(where: { $0.name == "Netflix" })) - XCTAssertNotNil(netflixNode.descendantNodeFor(code: "netflix-es1-wireguard")) - XCTAssertNotNil(netflixNode.descendantNodeFor(code: "netflix-se")) - XCTAssertNotNil(netflixNode.descendantNodeFor(code: "netflix-dal")) + XCTAssertNotNil(netflixNode.descendantNodeFor(codes: ["netflix", "es1-wireguard"])) + XCTAssertNotNil(netflixNode.descendantNodeFor(codes: ["netflix", "se"])) + XCTAssertNotNil(netflixNode.descendantNodeFor(codes: ["netflix", "us", "dal"])) let youtubeNode = try XCTUnwrap(nodes.first(where: { $0.name == "Youtube" })) - XCTAssertNotNil(youtubeNode.descendantNodeFor(code: "youtube-se2-wireguard")) - XCTAssertNotNil(youtubeNode.descendantNodeFor(code: "youtube-dal")) + XCTAssertNotNil(youtubeNode.descendantNodeFor(codes: ["youtube", "se2-wireguard"])) + XCTAssertNotNil(youtubeNode.descendantNodeFor(codes: ["youtube", "us", "dal"])) } func testSearch() throws { let nodes = dataSource.search(by: "got") let rootNode = RootLocationNode(children: nodes) - XCTAssertTrue(rootNode.descendantNodeFor(code: "netflix-got")?.isHiddenFromSearch == false) - XCTAssertTrue(rootNode.descendantNodeFor(code: "netflix-sto")?.isHiddenFromSearch == true) + XCTAssertTrue(rootNode.descendantNodeFor(codes: ["netflix", "se", "got"])?.isHiddenFromSearch == false) + XCTAssertTrue(rootNode.descendantNodeFor(codes: ["netflix", "se", "sto"])?.isHiddenFromSearch == true) } func testSearchWithEmptyText() throws { @@ -51,7 +51,7 @@ class CustomListsDataSourceTests: XCTestCase { func testNodeByLocations() throws { let nodeByLocations = dataSource.node(by: [.hostname("es", "mad", "es1-wireguard")], for: customLists.first!) - let nodeByCode = dataSource.nodes.first?.descendantNodeFor(code: "netflix-es1-wireguard") + let nodeByCode = dataSource.nodes.first?.descendantNodeFor(codes: ["netflix", "es1-wireguard"]) XCTAssertEqual(nodeByLocations, nodeByCode) } diff --git a/ios/MullvadVPNTests/Location/LocationNodeTests.swift b/ios/MullvadVPNTests/Location/LocationNodeTests.swift index b2775a7fb214..9495129d0088 100644 --- a/ios/MullvadVPNTests/Location/LocationNodeTests.swift +++ b/ios/MullvadVPNTests/Location/LocationNodeTests.swift @@ -81,7 +81,7 @@ class LocationNodeTests: XCTestCase { } func testFindByCityCode() { - XCTAssertTrue(countryNode.cityFor(code: cityNode.code) == cityNode) + XCTAssertTrue(countryNode.cityFor(codes: [cityNode.code]) == cityNode) } func testFindByHostCode() { @@ -89,7 +89,7 @@ class LocationNodeTests: XCTestCase { } func testFindDescendantByNodeCode() { - XCTAssertTrue(listNode.descendantNodeFor(code: hostNode.code) == hostNode) + XCTAssertTrue(listNode.descendantNodeFor(codes: [hostNode.code]) == hostNode) } }