From 988dec4be33ccb7372e6f4ac747e769f342a88de Mon Sep 17 00:00:00 2001 From: Jeremy Powell Date: Thu, 5 Dec 2024 23:00:46 +1300 Subject: [PATCH] Handle range lock sector --- OpenMcdf.Tests/RootStorageTests.cs | 24 ++++++++++++++++++++ OpenMcdf/Fat.cs | 36 ++++++++++++++++++++++++------ OpenMcdf/FatEntry.cs | 2 -- OpenMcdf/FatEnumerator.cs | 2 +- OpenMcdf/FatStream.cs | 5 ++++- OpenMcdf/MiniFat.cs | 6 +++-- OpenMcdf/RootContext.cs | 16 +++++++++---- OpenMcdf/RootStorage.cs | 7 +++--- 8 files changed, 78 insertions(+), 20 deletions(-) diff --git a/OpenMcdf.Tests/RootStorageTests.cs b/OpenMcdf.Tests/RootStorageTests.cs index f6350ed..55ce2e3 100644 --- a/OpenMcdf.Tests/RootStorageTests.cs +++ b/OpenMcdf.Tests/RootStorageTests.cs @@ -252,4 +252,28 @@ public void V3ThrowsIOExceptionAt2GB() Assert.ThrowsException(() => stream.Write(buffer, 0, buffer.Length)); } + + [TestMethod] + [DoNotParallelize] // High memory usage + public void ValidateRangeLockSector() + { + RecyclableMemoryStreamManager manager = new(); + using RecyclableMemoryStream baseStream = new(manager); + baseStream.Capacity64 = RootContext.RangeLockSectorOffset; + + using var rootStorage = RootStorage.Create(baseStream, Version.V4); + using (CfbStream stream = rootStorage.CreateStream("Test")) + { + byte[] buffer = TestData.CreateByteArray(4096); + while (baseStream.Length <= RootContext.RangeLockSectorOffset) + stream.Write(buffer, 0, buffer.Length); + } + + Assert.IsTrue(rootStorage.Validate()); + + rootStorage.Delete("Test"); + rootStorage.Flush(); + + Assert.IsTrue(rootStorage.Validate()); + } } diff --git a/OpenMcdf/Fat.cs b/OpenMcdf/Fat.cs index 92c8a94..2d71b32 100644 --- a/OpenMcdf/Fat.cs +++ b/OpenMcdf/Fat.cs @@ -15,11 +15,20 @@ internal sealed class Fat : ContextBase, IEnumerable, IDisposable Sector cachedSector = Sector.EndOfChain; private bool isDirty; + public Func IsUsed { get; } + public Fat(RootContextSite rootContextSite) : base(rootContextSite) { fatSectorEnumerator = new(rootContextSite); cachedSectorBuffer = new byte[Context.SectorSize]; + + if (Context.Version == Version.V3) + IsUsed = entry => entry.Value is not SectorType.Free; + else if (Context.Version == Version.V4) + IsUsed = entry => entry.Value is not SectorType.Free && entry.Index is not RootContext.RangeLockSectorId; + else + throw new NotSupportedException($"Unsupported major version: {Context.Version}."); } public void Dispose() @@ -143,14 +152,14 @@ public uint Add(FatEnumerator fatEnumerator, uint startIndex) public Sector GetLastUsedSector() { - uint lastUsedSectorIndex = uint.MaxValue; + FatEntry lastUsedSectorIndex = new(uint.MaxValue, uint.MaxValue); foreach (FatEntry entry in this) { - if (!entry.IsFree) - lastUsedSectorIndex = entry.Index; + if (IsUsed(entry)) + lastUsedSectorIndex = entry; } - return new(lastUsedSectorIndex, Context.SectorSize); + return new(lastUsedSectorIndex.Index, Context.SectorSize); } public IEnumerator GetEnumerator() => new FatEnumerator(Context.Fat); @@ -172,7 +181,7 @@ internal void WriteTrace(TextWriter writer) foreach (FatEntry entry in this) { Sector sector = new(entry.Index, Context.SectorSize); - if (entry.IsFree) + if (entry.Value is SectorType.Free) { freeCount++; writer.WriteLine($"{entry}"); @@ -194,7 +203,7 @@ internal void WriteTrace(TextWriter writer) } [ExcludeFromCodeCoverage] - internal void Validate() + internal bool Validate() { long fatSectorCount = 0; long difatSectorCount = 0; @@ -213,8 +222,21 @@ internal void Validate() throw new FileFormatException($"FAT sector count mismatch. Expected: {Context.Header.FatSectorCount} Actual: {fatSectorCount}."); if (Context.Header.DifatSectorCount != difatSectorCount) throw new FileFormatException($"DIFAT sector count mismatch: Expected: {Context.Header.DifatSectorCount} Actual: {difatSectorCount}."); + + if (Context.Length < RootContext.RangeLockSectorOffset) + { + if (this.TryGetValue(RootContext.RangeLockSectorId, out uint value) && value != SectorType.Free) + throw new FileFormatException($"Range lock FAT entry is not free."); + } + else + { + if (this[RootContext.RangeLockSectorId] != SectorType.EndOfChain) + throw new FileFormatException($"Range lock sector is not at the end of the chain."); + } + + return true; } [ExcludeFromCodeCoverage] - internal long GetFreeSectorCount() => this.Count(entry => entry.IsFree); + internal long GetFreeSectorCount() => this.Count(entry => entry.Value == SectorType.Free); } diff --git a/OpenMcdf/FatEntry.cs b/OpenMcdf/FatEntry.cs index d6e430c..890fc9e 100644 --- a/OpenMcdf/FatEntry.cs +++ b/OpenMcdf/FatEntry.cs @@ -7,8 +7,6 @@ namespace OpenMcdf; /// internal record struct FatEntry(uint Index, uint Value) { - public readonly bool IsFree => Value == SectorType.Free; - [ExcludeFromCodeCoverage] public override readonly string ToString() => $"#{Index}: {Value}"; } diff --git a/OpenMcdf/FatEnumerator.cs b/OpenMcdf/FatEnumerator.cs index cdc429f..0645fbe 100644 --- a/OpenMcdf/FatEnumerator.cs +++ b/OpenMcdf/FatEnumerator.cs @@ -77,7 +77,7 @@ public bool MoveNextFreeEntry() { while (MoveNext()) { - if (value == SectorType.Free) + if (value is SectorType.Free) return true; } diff --git a/OpenMcdf/FatStream.cs b/OpenMcdf/FatStream.cs index 8c61a65..b68bd50 100644 --- a/OpenMcdf/FatStream.cs +++ b/OpenMcdf/FatStream.cs @@ -1,4 +1,6 @@ -namespace OpenMcdf; +using System.Diagnostics; + +namespace OpenMcdf; /// /// Provides a for a stream object in a compound file./> @@ -185,6 +187,7 @@ public override void Write(byte[] buffer, int offset, int count) long writeLength = Math.Min(remaining, sector.Length - sectorOffset); writer.Write(buffer, localOffset, (int)writeLength); Context.ExtendStreamLength(sector.EndPosition); + Debug.Assert(Context.Length >= Context.Stream.Length); position += writeLength; writeCount += (int)writeLength; sectorOffset = 0; diff --git a/OpenMcdf/MiniFat.cs b/OpenMcdf/MiniFat.cs index 6d42998..aa233ba 100644 --- a/OpenMcdf/MiniFat.cs +++ b/OpenMcdf/MiniFat.cs @@ -133,7 +133,7 @@ public uint Add(MiniFatEnumerator miniFatEnumerator, uint startIndex) FatEntry entry = miniFatEnumerator.Current; this[entry.Index] = SectorType.EndOfChain; - Debug.Assert(entry.IsFree); + Debug.Assert(entry.Value is SectorType.Free); MiniSector miniSector = new(entry.Index, Context.MiniSectorSize); if (Context.MiniStream.Length < miniSector.EndPosition) Context.MiniStream.SetLength(miniSector.EndPosition); @@ -153,7 +153,7 @@ internal void WriteTrace(TextWriter writer) } [ExcludeFromCodeCoverage] - internal void Validate() + internal bool Validate() { using MiniFatEnumerator miniFatEnumerator = new(ContextSite); @@ -165,5 +165,7 @@ internal void Validate() throw new FileFormatException($"Mini FAT entry {current} is beyond the end of the mini stream."); } } + + return true; } } diff --git a/OpenMcdf/RootContext.cs b/OpenMcdf/RootContext.cs index 0898897..cb6f94e 100644 --- a/OpenMcdf/RootContext.cs +++ b/OpenMcdf/RootContext.cs @@ -16,7 +16,9 @@ enum IOContextFlags /// internal sealed class RootContext : ContextBase, IDisposable { - const long MaximumV3StreamLength = 2147483648; + internal const long MaximumV3StreamLength = 2147483648; + internal const uint RangeLockSectorOffset = 0x7FFFFF00; + internal const uint RangeLockSectorId = RangeLockSectorOffset / (1 << Header.SectorShiftV4) - 1; readonly IOContextFlags contextFlags; readonly CfbBinaryWriter? writer; @@ -187,11 +189,14 @@ public void Flush() public void ExtendStreamLength(long length) { + if (Length >= length) + return; + if (Version is Version.V3 && length > MaximumV3StreamLength) throw new IOException("V3 compound files are limited to 2 GB."); - - if (Length < length) - Length = length; + else if (Version is Version.V4 && Length < RangeLockSectorOffset && length >= RangeLockSectorOffset) + Fat[RangeLockSectorId] = SectorType.EndOfChain; + Length = length; } void TrimBaseStream() @@ -200,6 +205,9 @@ void TrimBaseStream() if (!lastUsedSector.IsValid) throw new FileFormatException("Last used sector is invalid"); + if (Version is Version.V4 && lastUsedSector.EndPosition < RangeLockSectorOffset) + Fat.TrySetValue(RangeLockSectorId, SectorType.Free); + Length = lastUsedSector.EndPosition; BaseStream.SetLength(Length); } diff --git a/OpenMcdf/RootStorage.cs b/OpenMcdf/RootStorage.cs index e567036..8b3cce8 100644 --- a/OpenMcdf/RootStorage.cs +++ b/OpenMcdf/RootStorage.cs @@ -244,10 +244,11 @@ internal void Trace(TextWriter writer) Context.MiniFat.WriteTrace(writer); } + // TODO: Move checks to Tests project as Asserts [ExcludeFromCodeCoverage] - internal void Validate() + internal bool Validate() { - Context.Fat.Validate(); - Context.MiniFat.Validate(); + return Context.Fat.Validate() + && Context.MiniFat.Validate(); } }