Skip to content

Commit

Permalink
Add GetFstResults
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmaxwell3 committed Feb 10, 2025
1 parent 2d211b1 commit 2e29512
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ DeterministicFsaTraversalInstance<TData, TOffset> ni in Advance(
ReleaseInstance(inst);
}

var newResults = new List<FstResult<TData, TOffset>>();
GetFstResults(newResults, curResults);
return curResults;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ DeterministicFstTraversalInstance<TData, TOffset> ni in Advance(
ReleaseInstance(inst);
}

var newResults = new List<FstResult<TData, TOffset>>();
GetFstResults(newResults, curResults);
return curResults;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ NondeterministicFsaTraversalInstance<TData, TOffset> newInst in Advance(
ReleaseInstance(inst);
}

var newResults = new List<FstResult<TData, TOffset>>();
GetFstResults(newResults, curResults);
return curResults;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ NondeterministicFstTraversalInstance<TData, TOffset> newInst in Advance(
ReleaseInstance(inst);
}

var newResults = new List<FstResult<TData, TOffset>>();
GetFstResults(newResults, curResults);
return curResults;
}

Expand Down
63 changes: 54 additions & 9 deletions src/SIL.Machine/FiniteState/TraversalMethodBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ internal abstract class TraversalMethodBase<TData, TOffset, TInst> : ITraversalM
private readonly Queue<TInst> _cachedInstances;
private readonly IDictionary<TInst, IList<CommandUpdate>> _commandUpdates;
private readonly IDictionary<TInst, IList<TraverseOutput>> _outputs;
private readonly TInst _finalInst;

protected TraversalMethodBase(
Fst<TData, TOffset> fst,
Expand Down Expand Up @@ -67,6 +68,7 @@ bool useDefaults
_cachedInstances = new Queue<TInst>();
_commandUpdates = new Dictionary<TInst, IList<CommandUpdate>>();
_outputs = new Dictionary<TInst, IList<TraverseOutput>>();
_finalInst = CreateInstance();
}

private int CompareAnnotations(Annotation<TOffset> x, Annotation<TOffset> y)
Expand Down Expand Up @@ -173,6 +175,34 @@ protected bool CheckInputMatch(Arc<TData, TOffset> arc, int annIndex, VariableBi
);
}

private void RecordAccepting(
TInst inst,
Arc<TData, TOffset> arc
)
{
if (arc.Target.IsAccepting && (!_endAnchor || inst.AnnotationIndex == _annotations.Count))
{
RecordCommands(inst, arc, arc.Target.Finishers, new Register<TOffset>(), new Register<TOffset>(), _finalInst);
}
}

public void GetFstResults(ICollection<FstResult<TData, TOffset>> curResults, ICollection<FstResult<TData, TOffset>> oldResults)
{
foreach (TraverseOutput output in GetOutputs(_finalInst))
{
CheckAccepting(
output.Instance.AnnotationIndex,
output.Registers,
output.Output,
output.Instance.VariableBindings,
output.Arc,
curResults,
output.Instance.Priorities,
output.Instance);
}
Debug.Assert(curResults.Count.Equals(oldResults.Count), "results didn't match");
}

private void CheckAccepting(
int annIndex,
Register<TOffset>[,] registers,
Expand All @@ -190,17 +220,19 @@ TInst inst
annIndex < _annotations.Count ? _annotations[annIndex] : _data.Annotations.GetEnd(_fst.Direction);
var matchRegisters = (Register<TOffset>[,])registers.Clone();
ExecuteCommands(matchRegisters, arc.Target.Finishers, new Register<TOffset>(), new Register<TOffset>());
TInst finalInst = CreateInstance();
RecordCommands(inst, null, arc.Target.Finishers, new Register<TOffset>(), new Register<TOffset>(), finalInst);
if (true)
{
TInst finalInst = CreateInstance();
RecordCommands(inst, null, arc.Target.Finishers, new Register<TOffset>(), new Register<TOffset>(), finalInst);
var outputs = GetOutputs(finalInst);
Debug.Assert(_fst.RegistersEqualityComparer.Equals(outputs[0].Registers, matchRegisters), "registers didn't match");
if (output != null)
Debug.Assert(outputs[0].Output.ToString().Equals(output.ToString()), "output didn't match");
if (varBindings != null)
Debug.Assert(varBindings.ValueEquals(inst.VariableBindings), "varBindings didn't match");
Debug.Assert(annIndex.Equals(inst.AnnotationIndex), "annIndex didn't match");
if (priorities != null)
Debug.Assert(priorities.Equals(inst.Priorities), "priorities didn't match");
}
if (arc.Target.AcceptInfos.Count > 0)
{
Expand Down Expand Up @@ -251,6 +283,8 @@ TInst inst

protected class TraverseOutput
{
public TInst Instance;
public Arc<TData, TOffset> Arc;
public Register<TOffset>[,] Registers;
public TData Output;
public IDictionary<Annotation<TOffset>, Annotation<TOffset>> Mappings;
Expand Down Expand Up @@ -281,6 +315,8 @@ private IList<TraverseOutput> GetOutputs(TInst inst)
IList<CommandUpdate> updates = GetCommandUpdates(inst);
if (updates.Count == 0)
{
if (inst == _finalInst)
return outputs;
// We are at the beginning.
var registers = inst != null ? inst.Registers : new Register<TOffset>[Fst.RegisterCount, 2];
var dataOutput = ((ICloneable<TData>)Data).Clone();
Expand All @@ -296,12 +332,14 @@ private IList<TraverseOutput> GetOutputs(TInst inst)
}
foreach (CommandUpdate update in updates)
{
foreach(TraverseOutput output in GetOutputs(update.Source))
foreach (TraverseOutput output in GetOutputs(update.Source))
{
var newOutput = new TraverseOutput(output);
newOutput.Instance = update.Source;
newOutput.Arc = update.Arc;
if (update.Cmds != null)
ExecuteCommands(newOutput.Registers, update.Cmds, update.Start, update.End);
if (update.Arc != null)
if (update.Arc != null && inst != _finalInst)
{
for (int j = 0; j < update.Arc.Input.EnqueueCount; j++)
newOutput.Queue.Enqueue(Annotations[update.Source.AnnotationIndex]);
Expand Down Expand Up @@ -342,9 +380,15 @@ protected IEnumerable<TInst> Initialize(
ref int annIndex,
Register<TOffset>[,] registers,
IList<TagMapCommand> cmds,
ISet<int> initAnns
ISet<int> initAnns,
Boolean first = true
)
{
if (first)
{
_commandUpdates.Clear();
_outputs.Clear();
}
var insts = new List<TInst>();
TOffset offset = _annotations[annIndex].Range.GetStart(_fst.Direction);

Expand All @@ -362,7 +406,7 @@ ISet<int> initAnns
if (nextIndex != _annotations.Count)
{
insts.AddRange(
Initialize(ref nextIndex, (Register<TOffset>[,])registers.Clone(), cmds, initAnns)
Initialize(ref nextIndex, (Register<TOffset>[,])registers.Clone(), cmds, initAnns, false)
);
}
}
Expand Down Expand Up @@ -394,12 +438,10 @@ ISet<int> initAnns
inst.VariableBindings = _varBindings != null ? _varBindings.Clone() : new VariableBindings();
insts.Add(inst);
initAnns.Add(annIndex);
RecordCommands(startInst, null, cmds, new Register<TOffset>(offset, true), new Register<TOffset>(), inst);
}
}

foreach (var inst in insts)
RecordCommands(startInst, null, cmds, new Register<TOffset>(offset, true), new Register<TOffset>(), inst);

return insts;
}

Expand Down Expand Up @@ -480,6 +522,7 @@ protected IEnumerable<TInst> Advance(
inst.Priorities,
inst
);
RecordAccepting(inst, arc);
}

inst.State = arc.Target;
Expand Down Expand Up @@ -518,6 +561,7 @@ protected IEnumerable<TInst> Advance(
inst.AnnotationIndex = nextIndex;
inst.VariableBindings = varBindings;
CheckAccepting(nextIndex, inst.Registers, inst.Output, varBindings, arc, curResults, inst.Priorities, inst);
RecordAccepting(inst, arc);

yield return inst;
}
Expand Down Expand Up @@ -561,6 +605,7 @@ ICollection<FstResult<TData, TOffset>> curResults
inst.Priorities,
inst
);
RecordAccepting(inst, arc);

inst.State = arc.Target;
return inst;
Expand Down

0 comments on commit 2e29512

Please sign in to comment.