diff --git a/src/BaseContext.cs b/src/BaseContext.cs index 8c6cf51..3ab2934 100644 --- a/src/BaseContext.cs +++ b/src/BaseContext.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -64,6 +65,8 @@ public void UpdateLoopTime() CurrentLoopTime = $"{DateTime.Now:yyyyMMddHHmmss}"; } + public HashSet CollectedDomainSids { get; } = new(); + public async Task DoDelay() { if (Throttle == 0) diff --git a/src/Client/Context.cs b/src/Client/Context.cs index 4e52f96..7e24734 100644 --- a/src/Client/Context.cs +++ b/src/Client/Context.cs @@ -75,5 +75,6 @@ public interface IContext string ResolveFileName(string filename, string extension, bool addTimestamp); EnumerationDomain[] Domains { get; set; } void UpdateLoopTime(); + public HashSet CollectedDomainSids { get; } } } \ No newline at end of file diff --git a/src/Producers/LdapProducer.cs b/src/Producers/LdapProducer.cs index 82a456c..fb6027e 100644 --- a/src/Producers/LdapProducer.cs +++ b/src/Producers/LdapProducer.cs @@ -44,7 +44,7 @@ public override async Task Produce() foreach (var domain in Context.Domains) { - Context.Logger.LogInformation("Beginning LDAP search for {Domain}", domain); + Context.Logger.LogInformation("Beginning LDAP search for {Domain}", domain.Name); //Do a basic LDAP search and grab results var successfulConnect = false; try @@ -64,14 +64,7 @@ public override async Task Produce() continue; } - await OutputChannel.Writer.WriteAsync(new Domain - { - ObjectIdentifier = domain.DomainSid, - Properties = new Dictionary - { - { "collected", true }, - } - }); + Context.CollectedDomainSids.Add(domain.DomainSid); foreach (var searchResult in Context.LDAPUtils.QueryLDAP(ldapData.Filter.GetFilter(), SearchScope.Subtree, ldapData.Props.Distinct().ToArray(), cancellationToken, domain.Name, diff --git a/src/Runtime/LDAPConsumer.cs b/src/Runtime/LDAPConsumer.cs index c298d5b..6960bd5 100644 --- a/src/Runtime/LDAPConsumer.cs +++ b/src/Runtime/LDAPConsumer.cs @@ -37,6 +37,11 @@ internal static async Task ConsumeSearchResults(Channel inpu watch.Elapsed.TotalMilliseconds, res.DisplayName); if (processed == null) continue; + + if (processed is Domain d && context.CollectedDomainSids.Contains(d.ObjectIdentifier)) + { + d.Properties.Add("collected", true); + } await outputChannel.Writer.WriteAsync(processed); } catch (Exception e)