diff --git a/csharp/src/Apache.Arrow/Schema.cs b/csharp/src/Apache.Arrow/Schema.cs index 32615e5d67bd8..4d0dc39987b7f 100644 --- a/csharp/src/Apache.Arrow/Schema.cs +++ b/csharp/src/Apache.Arrow/Schema.cs @@ -31,6 +31,7 @@ public partial class Schema : IRecordType private readonly List _fieldsList; public ILookup FieldsLookup { get; } + private readonly ILookup _fieldsIndexLookup; public IReadOnlyDictionary Metadata { get; } @@ -43,17 +44,11 @@ public partial class Schema : IRecordType public Schema( IEnumerable fields, IEnumerable> metadata) + : this( + fields.ToList(), + metadata?.ToDictionary(kv => kv.Key, kv => kv.Value), + false) { - if (fields is null) - { - throw new ArgumentNullException(nameof(fields)); - } - - _fieldsList = fields.ToList(); - FieldsLookup = _fieldsList.ToLookup(f => f.Name); - _fieldsDictionary = FieldsLookup.ToDictionary(g => g.Key, g => g.First()); - - Metadata = metadata?.ToDictionary(kv => kv.Key, kv => kv.Value); } internal Schema(List fieldsList, IReadOnlyDictionary metadata, bool copyCollections) @@ -66,6 +61,10 @@ internal Schema(List fieldsList, IReadOnlyDictionary meta _fieldsDictionary = FieldsLookup.ToDictionary(g => g.Key, g => g.First()); Metadata = metadata; + + _fieldsIndexLookup = _fieldsList + .Select((x, idx) => (Name: x.Name, Index: idx)) + .ToLookup(x => x.Name, x => x.Index, StringComparer.CurrentCulture); } public Field GetFieldByIndex(int i) => _fieldsList[i]; @@ -80,15 +79,20 @@ public int GetFieldIndex(string name, StringComparer comparer) public int GetFieldIndex(string name, IEqualityComparer comparer = default) { - comparer ??= StringComparer.CurrentCulture; + if (comparer == null || comparer.Equals(StringComparer.CurrentCulture)) + { + return _fieldsIndexLookup[name].First(); + } - for (int i = 0; i < _fieldsList.Count; i++) + for (var i = 0; i < _fieldsList.Count; ++i) { if (comparer.Equals(_fieldsList[i].Name, name)) + { return i; + } } - return -1; + throw new InvalidOperationException(); } public Schema RemoveField(int fieldIndex) diff --git a/csharp/test/Apache.Arrow.Tests/SchemaTests.cs b/csharp/test/Apache.Arrow.Tests/SchemaTests.cs new file mode 100644 index 0000000000000..9558d95719d46 --- /dev/null +++ b/csharp/test/Apache.Arrow.Tests/SchemaTests.cs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow; +using Apache.Arrow.Types; +using System; +using System.Collections.Generic; +using Xunit; + +namespace Apache.Arrow.Tests; + +public class SchemaTests +{ + [Fact] + public void ThrowsWhenFieldsAreNull() + { + Assert.Throws(() => new Schema(null, null)); + } + + [Theory] + [MemberData(nameof(StringComparers))] + public void CanRetrieveFieldIndexByName(StringComparer comparer) + { + var field0 = new Field("f0", Int32Type.Default, true); + var field1 = new Field("f1", Int64Type.Default, true); + var schema = new Schema([field0, field1], null); + + Assert.Equal(0, schema.GetFieldIndex("f0", comparer)); + Assert.Equal(1, schema.GetFieldIndex("f1", comparer)); + Assert.Throws(() => schema.GetFieldIndex("nonexistent", comparer)); + } + + [Theory] + [MemberData(nameof(StringComparers))] + public void CanRetrieveFieldIndexByNonUniqueName(StringComparer comparer) + { + var field0 = new Field("f0", Int32Type.Default, true); + var field1 = new Field("f1", Int64Type.Default, true); + + // Repeat fields in the list + var schema = new Schema([field0, field1, field0, field1], null); + + Assert.Equal(0, schema.GetFieldIndex("f0", comparer)); + Assert.Equal(1, schema.GetFieldIndex("f1", comparer)); + Assert.Throws(() => schema.GetFieldIndex("nonexistent", comparer)); + } + + public static IEnumerable StringComparers() => + new List + { + new object[] {null}, + new object[] {StringComparer.Ordinal}, + new object[] {StringComparer.OrdinalIgnoreCase}, + new object[] {StringComparer.CurrentCulture}, + }; +}