diff --git a/csharp/src/Apache.Arrow/Schema.cs b/csharp/src/Apache.Arrow/Schema.cs index 4357e8b2ddd44..233988fdc5d30 100644 --- a/csharp/src/Apache.Arrow/Schema.cs +++ b/csharp/src/Apache.Arrow/Schema.cs @@ -31,6 +31,7 @@ public partial class Schema : IRecordType private readonly List _fieldsList; public ILookup FieldsLookup { get; } + private readonly ILookup _fieldsIndexLookup; public IReadOnlyDictionary Metadata { get; } @@ -43,17 +44,11 @@ public partial class Schema : IRecordType public Schema( IEnumerable fields, IEnumerable> metadata) + : this( + fields.ToList(), + metadata?.ToDictionary(kv => kv.Key, kv => kv.Value), + false) { - if (fields is null) - { - throw new ArgumentNullException(nameof(fields)); - } - - _fieldsList = fields.ToList(); - FieldsLookup = _fieldsList.ToLookup(f => f.Name); - _fieldsDictionary = FieldsLookup.ToDictionary(g => g.Key, g => g.First()); - - Metadata = metadata?.ToDictionary(kv => kv.Key, kv => kv.Value); } internal Schema(List fieldsList, IReadOnlyDictionary metadata, bool copyCollections) @@ -66,6 +61,10 @@ internal Schema(List fieldsList, IReadOnlyDictionary meta _fieldsDictionary = FieldsLookup.ToDictionary(g => g.Key, g => g.First()); Metadata = metadata; + + _fieldsIndexLookup = _fieldsList + .Select((x, idx) => (Name: x.Name, Index: idx)) + .ToLookup(x => x.Name, x => x.Index, StringComparer.CurrentCulture); } public Field GetFieldByIndex(int i) => _fieldsList[i]; @@ -80,7 +79,10 @@ public int GetFieldIndex(string name, StringComparer comparer) public int GetFieldIndex(string name, IEqualityComparer comparer = default) { - comparer ??= StringComparer.CurrentCulture; + if (comparer == null) + { + return _fieldsIndexLookup[name].First(); + } return _fieldsList.IndexOf(_fieldsList.First(x => comparer.Equals(x.Name, name))); } diff --git a/csharp/test/Apache.Arrow.Tests/SchemaTests.cs b/csharp/test/Apache.Arrow.Tests/SchemaTests.cs new file mode 100644 index 0000000000000..15c293e972b4b --- /dev/null +++ b/csharp/test/Apache.Arrow.Tests/SchemaTests.cs @@ -0,0 +1,64 @@ + +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow; +using Apache.Arrow.Types; +using System; +using Xunit; + +namespace Apache.Arrow.Tests; + +internal class SchemaTests +{ + public class Construct + { + [Fact] + public void ThrowsWhenFieldsAreNull() + { + Assert.Throws(() => new Schema(null, null)); + Assert.Throws(() => new Schema(null, null, copyCollections: false)); + } + } + + public class IndexRetrieval + { + [Fact] + public void CanRetrieveFieldIndexByName() + { + var field0 = new Field("f0", Int32Type.Default, true); + var field1 = new Field("f1", Int64Type.Default, true); + var schema = new Schema([field0, field1], null); + + Assert.Equal(0, schema.GetFieldIndex("f0")); + Assert.Equal(1, schema.GetFieldIndex("f1")); + Assert.Equal(-1, schema.GetFieldIndex("nonexistent")); + } + + [Fact] + public void CanRetrieveFieldIndexByNonUniqueName() + { + var field0 = new Field("f0", Int32Type.Default, true); + var field1 = new Field("f1", Int64Type.Default, true); + + // Repeat fields in the list + var schema = new Schema([field0, field1, field0, field1], null); + + Assert.Equal(0, schema.GetFieldIndex("f0")); + Assert.Equal(1, schema.GetFieldIndex("f1")); + Assert.Equal(-1, schema.GetFieldIndex("nonexistent")); + } + } +}