From c8a27cc7576e432287df23f2a97f7bf55727a13f Mon Sep 17 00:00:00 2001 From: Damien Daspit Date: Wed, 6 Sep 2023 17:51:12 -0500 Subject: [PATCH] Refactor ClearML NMT build job - add support for multiple build stages - add support for running build jobs on Hangfire or ClearML - add BuildJobService - categorize build jobs into CPU or GPU jobs - decouple build job runners from translation engines - fix issues with S3FileStorage - fix issues with ClearMLService --- .../Configuration/BuildJobOptions.cs | 9 + ...LNmtEngineOptions.cs => ClearMLOptions.cs} | 7 +- .../IMachineBuilderExtensions.cs | 121 ++++-- .../IServiceCollectionExtensions.cs | 8 +- .../Configuration/SmtTransferEngineOptions.cs | 2 - .../Configuration/ThotSmtModelOptions.cs | 2 +- src/SIL.Machine.AspNetCore/Models/Build.cs | 24 ++ .../Models/ClearMLMetricsEvent.cs | 12 + .../Models/ClearMLTask.cs | 1 + .../Models/TranslationEngine.cs | 12 +- .../SIL.Machine.AspNetCore.csproj | 14 +- .../Services/BuildJobService.cs | 223 +++++++++++ .../Services/ClearMLAuthenticationService.cs | 29 +- .../Services/ClearMLBuildJobRunner.cs | 68 ++++ .../Services/ClearMLHealthCheck.cs | 4 +- .../Services/ClearMLMonitorService.cs | 324 +++++++++++++++ .../Services/ClearMLNmtEngineBuildJob.cs | 374 ------------------ .../Services/ClearMLNmtEngineService.cs | 51 --- .../Services/ClearMLService.cs | 112 ++---- .../Services/FileStorage.cs | 43 -- .../Services/HangfireBuildJob.cs | 159 ++++++++ .../Services/HangfireBuildJobRunner.cs | 74 ++++ .../Services/IBuildJobRunner.cs | 24 ++ .../Services/IBuildJobService.cs | 56 +++ .../Services/IClearMLBuildJobFactory.cs | 14 + .../Services/IClearMLService.cs | 11 +- .../Services/IFileStorage.cs | 18 + .../Services/IHangfireBuildJobFactory.cs | 8 + .../Services/ISharedFileService.cs | 6 - .../Services/InMemoryStorage.cs | 58 ++- .../Services/LocalStorage.cs | 101 ++--- .../Services/NmtClearMLBuildJobFactory.cs | 71 ++++ .../Services/NmtEngineService.cs | 153 +++++++ .../Services/NmtHangfireBuildJobFactory.cs | 21 + .../Services/NmtPostprocessBuildJob.cs | 85 ++++ .../Services/NmtPreprocessBuildJob.cs | 157 ++++++++ .../Services/NmtTrainBuildJob.cs | 139 +++++++ .../Services/S3FileStorage.cs | 54 ++- .../Services/SharedFileService.cs | 28 +- .../Services/SmtTransferBuildJob.cs | 126 ++++++ .../Services/SmtTransferEngineBuildJob.cs | 215 ---------- .../SmtTransferEngineCommitService.cs | 49 ++- .../Services/SmtTransferEngineService.cs | 153 ++++--- .../Services/SmtTransferEngineState.cs | 20 +- .../Services/SmtTransferEngineStateService.cs | 21 +- .../SmtTransferHangfireBuildJobFactory.cs | 18 + .../Services/TranslationEngineServiceBase.cs | 228 ----------- src/SIL.Machine.AspNetCore/Usings.cs | 10 +- .../Utils/RecurrentTask.cs | 49 +++ .../Utils/SharedFileUtils.cs | 28 ++ .../Program.cs | 5 +- .../appsettings.Development.json | 13 +- .../appsettings.json | 14 + src/SIL.Machine.Serval.JobServer/Program.cs | 7 +- .../appsettings.Development.json | 13 +- .../appsettings.json | 14 + src/SIL.Machine/Utils/TempDirectory.cs | 6 +- .../SIL.Machine.AspNetCore.Tests.csproj | 2 +- .../Services/ClearMLNmtEngineServiceTests.cs | 192 --------- .../Services/ClearMLServiceTests.cs | 40 +- .../Services/FileStorageTests.cs | 178 --------- .../Services/InMemoryStorageTests.cs | 91 +++++ .../Services/LocalStorageTests.cs | 96 +++++ .../Services/NmtEngineServiceTests.cs | 272 +++++++++++++ .../Services/SmtTransferEngineServiceTests.cs | 186 +++++++-- tests/SIL.Machine.AspNetCore.Tests/Usings.cs | 7 +- 66 files changed, 2954 insertions(+), 1776 deletions(-) create mode 100644 src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs rename src/SIL.Machine.AspNetCore/Configuration/{ClearMLNmtEngineOptions.cs => ClearMLOptions.cs} (79%) create mode 100644 src/SIL.Machine.AspNetCore/Models/Build.cs create mode 100644 src/SIL.Machine.AspNetCore/Models/ClearMLMetricsEvent.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/BuildJobService.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/ClearMLBuildJobRunner.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs delete mode 100644 src/SIL.Machine.AspNetCore/Services/ClearMLNmtEngineBuildJob.cs delete mode 100644 src/SIL.Machine.AspNetCore/Services/ClearMLNmtEngineService.cs delete mode 100644 src/SIL.Machine.AspNetCore/Services/FileStorage.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/HangfireBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/HangfireBuildJobRunner.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/IBuildJobRunner.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/IBuildJobService.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/IClearMLBuildJobFactory.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/IFileStorage.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/IHangfireBuildJobFactory.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferBuildJob.cs delete mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferEngineBuildJob.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs delete mode 100644 src/SIL.Machine.AspNetCore/Services/TranslationEngineServiceBase.cs create mode 100644 src/SIL.Machine.AspNetCore/Utils/RecurrentTask.cs create mode 100644 src/SIL.Machine.AspNetCore/Utils/SharedFileUtils.cs delete mode 100644 tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLNmtEngineServiceTests.cs delete mode 100644 tests/SIL.Machine.AspNetCore.Tests/Services/FileStorageTests.cs create mode 100644 tests/SIL.Machine.AspNetCore.Tests/Services/InMemoryStorageTests.cs create mode 100644 tests/SIL.Machine.AspNetCore.Tests/Services/LocalStorageTests.cs create mode 100644 tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs diff --git a/src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs b/src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs new file mode 100644 index 000000000..d761ac4d0 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Configuration/BuildJobOptions.cs @@ -0,0 +1,9 @@ +namespace SIL.Machine.AspNetCore.Configuration; + +public class BuildJobOptions +{ + public const string Key = "BuildJob"; + + public Dictionary Runners { get; set; } = + new() { { BuildJobType.Cpu, BuildJobRunner.Hangfire }, { BuildJobType.Gpu, BuildJobRunner.ClearML } }; +} diff --git a/src/SIL.Machine.AspNetCore/Configuration/ClearMLNmtEngineOptions.cs b/src/SIL.Machine.AspNetCore/Configuration/ClearMLOptions.cs similarity index 79% rename from src/SIL.Machine.AspNetCore/Configuration/ClearMLNmtEngineOptions.cs rename to src/SIL.Machine.AspNetCore/Configuration/ClearMLOptions.cs index b34948d04..c53802cea 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/ClearMLNmtEngineOptions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/ClearMLOptions.cs @@ -1,14 +1,15 @@ namespace SIL.Machine.AspNetCore.Configuration; -public class ClearMLNmtEngineOptions +public class ClearMLOptions { - public const string Key = "ClearMLNmtEngine"; + public const string Key = "ClearML"; public string ApiServer { get; set; } = "http://localhost:8008"; public string Queue { get; set; } = "default"; public string AccessKey { get; set; } = ""; public string SecretKey { get; set; } = ""; - public TimeSpan BuildPollingTimeout { get; set; } = TimeSpan.FromSeconds(2); + public bool BuildPollingEnabled { get; set; } = false; + public TimeSpan BuildPollingTimeout { get; set; } = TimeSpan.FromSeconds(10); public string ModelType { get; set; } = "huggingface"; public int MaxSteps { get; set; } = 20_000; public string RootProject { get; set; } = "Machine"; diff --git a/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs b/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs index c4b95abef..6aa6bbd68 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs @@ -1,5 +1,4 @@ -using Microsoft.AspNetCore.Http; -using Serval.Translation.V1; +using Serval.Translation.V1; namespace Microsoft.Extensions.DependencyInjection; @@ -35,18 +34,18 @@ public static IMachineBuilder AddSmtTransferEngineOptions(this IMachineBuilder b return builder; } - public static IMachineBuilder AddClearMLNmtEngineOptions( + public static IMachineBuilder AddClearMLOptions( this IMachineBuilder builder, - Action configureOptions + Action configureOptions ) { builder.Services.Configure(configureOptions); return builder; } - public static IMachineBuilder AddClearMLNmtEngineOptions(this IMachineBuilder builder, IConfiguration config) + public static IMachineBuilder AddClearMLOptions(this IMachineBuilder builder, IConfiguration config) { - builder.Services.Configure(config); + builder.Services.Configure(config); return builder; } @@ -67,8 +66,10 @@ public static IMachineBuilder AddSharedFileOptions(this IMachineBuilder builder, public static IMachineBuilder AddThotSmtModel(this IMachineBuilder builder) { - builder.Services.AddSingleton(); - return builder; + if (builder.Configuration is null) + return builder.AddThotSmtModel(o => { }); + else + return builder.AddThotSmtModel(builder.Configuration.GetSection(ThotSmtModelOptions.Key)); } public static IMachineBuilder AddThotSmtModel( @@ -77,13 +78,15 @@ Action configureOptions ) { builder.Services.Configure(configureOptions); - return builder.AddThotSmtModel(); + builder.Services.AddSingleton(); + return builder; } public static IMachineBuilder AddThotSmtModel(this IMachineBuilder builder, IConfiguration config) { builder.Services.Configure(config); - return builder.AddThotSmtModel(); + builder.Services.AddSingleton(); + return builder; } public static IMachineBuilder AddTransferEngine(this IMachineBuilder builder) @@ -98,7 +101,7 @@ public static IMachineBuilder AddUnigramTruecaser(this IMachineBuilder builder) return builder; } - public static IMachineBuilder AddClearMLService(this IMachineBuilder builder) + private static IMachineBuilder AddClearMLBuildJobRunner(this IMachineBuilder builder) { builder.Services.AddSingleton(); //Add retry policy; fail after approx. 2 + 4 + 8 = 14 seconds @@ -111,20 +114,33 @@ public static IMachineBuilder AddClearMLService(this IMachineBuilder builder) // workaround register satisfying the interface and as a hosted service. builder.Services.AddSingleton(); builder.Services.AddHostedService(p => p.GetRequiredService()); - //Add retry policy; fail after approx. 2 + 4 + 8 = 14 seconds + // Add retry policy; fail after approx. 2 + 4 + 8 = 14 seconds builder.Services .AddHttpClient() .AddTransientHttpErrorPolicy( b => b.WaitAndRetryAsync(3, retryAttempt => TimeSpan.FromSeconds(Math.Pow(2, retryAttempt))) ); - builder.Services.AddSingleton(); + builder.Services.AddScoped(); + builder.Services.AddScoped(); + builder.Services.AddHostedService(); + builder.Services.AddHealthChecks().AddCheck("ClearML Health Check"); return builder; } - public static IMachineBuilder AddMongoBackgroundJobClient( + private static IMachineBuilder AddHangfireBuildJobRunner(this IMachineBuilder builder) + { + builder.Services.AddScoped(); + + builder.Services.AddScoped(); + builder.Services.AddScoped(); + + return builder; + } + + public static IMachineBuilder AddMongoHangfireJobClient( this IMachineBuilder builder, string? connectionString = null ) @@ -147,12 +163,13 @@ public static IMachineBuilder AddMongoBackgroundJobClient( CheckQueuedJobsStrategy = CheckQueuedJobsStrategy.TailNotificationsCollection, } ) + .UseFilter(new AutomaticRetryAttribute { Attempts = 0 }) ); builder.Services.AddHealthChecks().AddCheck(name: "Hangfire"); return builder; } - public static IMachineBuilder AddBackgroundJobServer( + public static IMachineBuilder AddHangfireJobServer( this IMachineBuilder builder, IEnumerable? engineTypes = null ) @@ -170,7 +187,6 @@ public static IMachineBuilder AddBackgroundJobServer( queues.Add("smt_transfer"); break; case TranslationEngineType.Nmt: - builder.AddClearMLService(); queues.Add("nmt"); break; } @@ -205,28 +221,24 @@ public static IMachineBuilder AddMongoDataAccess(this IMachineBuilder builder, s { o.AddRepository( "translation_engines", - init: c => - c.Indexes.CreateOrUpdateAsync( - new CreateIndexModel( - Builders.IndexKeys.Ascending(p => p.EngineId) - ) - ) - ); - o.AddRepository( - "locks", + mapSetup: m => m.SetIgnoreExtraElements(true), init: async c => { await c.Indexes.CreateOrUpdateAsync( - new CreateIndexModel(Builders.IndexKeys.Ascending("writerLock._id")) - ); - await c.Indexes.CreateOrUpdateAsync( - new CreateIndexModel(Builders.IndexKeys.Ascending("readerLocks._id")) + new CreateIndexModel( + Builders.IndexKeys + .Ascending(e => e.EngineId) + .Ascending("currentBuild._id") + ) ); await c.Indexes.CreateOrUpdateAsync( - new CreateIndexModel(Builders.IndexKeys.Ascending("writerQueue._id")) + new CreateIndexModel( + Builders.IndexKeys.Ascending(e => e.CurrentBuild!.JobRunner) + ) ); } ); + o.AddRepository("locks"); o.AddRepository( "train_segment_pairs", init: c => @@ -313,8 +325,7 @@ public static IMachineBuilder AddServalTranslationEngineService( builder.Services.AddScoped(); break; case TranslationEngineType.Nmt: - builder.AddClearMLService(); - builder.Services.AddScoped(); + builder.Services.AddScoped(); break; } } @@ -322,4 +333,50 @@ public static IMachineBuilder AddServalTranslationEngineService( return builder; } + + public static IMachineBuilder AddBuildJobService( + this IMachineBuilder builder, + Action configureOptions + ) + { + builder.Services.Configure(configureOptions); + var options = new BuildJobOptions(); + configureOptions(options); + return builder.AddBuildJobService(options); + } + + public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, IConfiguration config) + { + builder.Services.Configure(config); + var options = config.Get(); + return builder.AddBuildJobService(options); + } + + public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder) + { + if (builder.Configuration is null) + builder.AddBuildJobService(o => { }); + else + builder.AddBuildJobService(builder.Configuration.GetSection(BuildJobOptions.Key)); + return builder; + } + + private static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, BuildJobOptions options) + { + builder.Services.AddScoped(); + + foreach (BuildJobRunner runnerType in options.Runners.Values.Distinct()) + { + switch (runnerType) + { + case BuildJobRunner.ClearML: + builder.AddClearMLBuildJobRunner(); + break; + case BuildJobRunner.Hangfire: + builder.AddHangfireBuildJobRunner(); + break; + } + } + return builder; + } } diff --git a/src/SIL.Machine.AspNetCore/Configuration/IServiceCollectionExtensions.cs b/src/SIL.Machine.AspNetCore/Configuration/IServiceCollectionExtensions.cs index f659d3ca3..00301da47 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/IServiceCollectionExtensions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/IServiceCollectionExtensions.cs @@ -4,9 +4,13 @@ public static class IServiceCollectionExtensions { public static IMachineBuilder AddMachine(this IServiceCollection services, IConfiguration? configuration = null) { + if (!Sldr.IsInitialized) + Sldr.Initialize(); + services.AddSingleton(); services.AddSingleton(); services.AddHealthChecks().AddCheck("S3 Bucket"); + services.AddScoped(); services.AddSingleton(); services.AddStartupTask((sp, ct) => sp.GetRequiredService().InitAsync(ct)); @@ -17,14 +21,14 @@ public static IMachineBuilder AddMachine(this IServiceCollection services, IConf builder.AddServiceOptions(o => { }); builder.AddSharedFileOptions(o => { }); builder.AddSmtTransferEngineOptions(o => { }); - builder.AddClearMLNmtEngineOptions(o => { }); + builder.AddClearMLOptions(o => { }); } else { builder.AddServiceOptions(configuration.GetSection(ServiceOptions.Key)); builder.AddSharedFileOptions(configuration.GetSection(SharedFileOptions.Key)); builder.AddSmtTransferEngineOptions(configuration.GetSection(SmtTransferEngineOptions.Key)); - builder.AddClearMLNmtEngineOptions(configuration.GetSection(ClearMLNmtEngineOptions.Key)); + builder.AddClearMLOptions(configuration.GetSection(ClearMLOptions.Key)); } return builder; } diff --git a/src/SIL.Machine.AspNetCore/Configuration/SmtTransferEngineOptions.cs b/src/SIL.Machine.AspNetCore/Configuration/SmtTransferEngineOptions.cs index 416d3302b..67df3d1d5 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/SmtTransferEngineOptions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/SmtTransferEngineOptions.cs @@ -7,6 +7,4 @@ public class SmtTransferEngineOptions public string EnginesDir { get; set; } = "translation_engines"; public TimeSpan EngineCommitFrequency { get; set; } = TimeSpan.FromMinutes(5); public TimeSpan InactiveEngineTimeout { get; set; } = TimeSpan.FromMinutes(10); - public ISet Types { get; set; } = - new HashSet { TranslationEngineType.Nmt, TranslationEngineType.SmtTransfer }; } diff --git a/src/SIL.Machine.AspNetCore/Configuration/ThotSmtModelOptions.cs b/src/SIL.Machine.AspNetCore/Configuration/ThotSmtModelOptions.cs index e0c9b8f87..5941cac46 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/ThotSmtModelOptions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/ThotSmtModelOptions.cs @@ -2,7 +2,7 @@ public class ThotSmtModelOptions { - public const string ThotSmtModel = "ThotSmtModel"; + public const string Key = "ThotSmtModel"; public ThotSmtModelOptions() { diff --git a/src/SIL.Machine.AspNetCore/Models/Build.cs b/src/SIL.Machine.AspNetCore/Models/Build.cs new file mode 100644 index 000000000..a8e630b3f --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Models/Build.cs @@ -0,0 +1,24 @@ +namespace SIL.Machine.AspNetCore.Models; + +public enum BuildJobState +{ + None, + Pending, + Active, + Canceling +} + +public enum BuildJobRunner +{ + Hangfire, + ClearML +} + +public class Build +{ + public string BuildId { get; set; } = default!; + public BuildJobState JobState { get; set; } + public string JobId { get; set; } = default!; + public BuildJobRunner JobRunner { get; set; } + public string Stage { get; set; } = default!; +} diff --git a/src/SIL.Machine.AspNetCore/Models/ClearMLMetricsEvent.cs b/src/SIL.Machine.AspNetCore/Models/ClearMLMetricsEvent.cs new file mode 100644 index 000000000..8c1cc26b5 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Models/ClearMLMetricsEvent.cs @@ -0,0 +1,12 @@ +namespace SIL.Machine.AspNetCore.Models; + +public class ClearMLMetricsEvent +{ + public string Metric { get; set; } = default!; + public string Variant { get; set; } = default!; + public double Value { get; set; } + public double MinValue { get; set; } + public int MinValueIteration { get; set; } + public double MaxValue { get; set; } + public int MaxValueIteration { get; set; } +} diff --git a/src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs b/src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs index c6fefad9d..573e9a505 100644 --- a/src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs +++ b/src/SIL.Machine.AspNetCore/Models/ClearMLTask.cs @@ -24,4 +24,5 @@ public class ClearMLTask public string StatusMessage { get; set; } = default!; public int LastIteration { get; set; } public int ActiveDuration { get; set; } + public Dictionary> LastMetrics { get; set; } = default!; } diff --git a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs index c4d77d351..ffc639fc7 100644 --- a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs +++ b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs @@ -1,12 +1,5 @@ namespace SIL.Machine.AspNetCore.Models; -public enum BuildState -{ - None, - Pending, - Active -} - public class TranslationEngine : IEntity { public string Id { get; set; } = default!; @@ -14,9 +7,6 @@ public class TranslationEngine : IEntity public string EngineId { get; set; } = default!; public string SourceLanguage { get; set; } = default!; public string TargetLanguage { get; set; } = default!; - public BuildState BuildState { get; set; } = BuildState.None; - public bool IsCanceled { get; set; } - public string? BuildId { get; set; } public int BuildRevision { get; set; } - public string? JobId { get; set; } + public Build? CurrentBuild { get; set; } } diff --git a/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj b/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj index bdf0fbc42..9356f66c8 100644 --- a/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj +++ b/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj @@ -26,16 +26,16 @@ - - - - - - + + + + + + - + diff --git a/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs b/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs new file mode 100644 index 000000000..d2b00c795 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/BuildJobService.cs @@ -0,0 +1,223 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class BuildJobService : IBuildJobService +{ + private readonly Dictionary _runnersByJobType; + private readonly Dictionary _runners; + private readonly IRepository _engines; + + public BuildJobService( + IEnumerable runners, + IRepository engines, + IOptions options + ) + { + _runners = runners.ToDictionary(r => r.Type); + _runnersByJobType = new Dictionary(); + foreach (KeyValuePair kvp in options.Value.Runners) + _runnersByJobType.Add(kvp.Key, _runners[kvp.Value]); + _engines = engines; + } + + public Task IsEngineBuilding(string engineId, CancellationToken cancellationToken = default) + { + return _engines.ExistsAsync(e => e.EngineId == engineId && e.CurrentBuild != null, cancellationToken); + } + + public Task> GetBuildingEnginesAsync( + BuildJobRunner runner, + CancellationToken cancellationToken = default + ) + { + return _engines.GetAllAsync( + e => e.CurrentBuild != null && e.CurrentBuild.JobRunner == runner, + cancellationToken + ); + } + + public async Task GetBuildAsync( + string engineId, + string buildId, + CancellationToken cancellationToken = default + ) + { + TranslationEngine? engine = await _engines.GetAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, + cancellationToken + ); + return engine?.CurrentBuild; + } + + public async Task CreateEngineAsync( + IEnumerable jobTypes, + string engineId, + string? name = null, + CancellationToken cancellationToken = default + ) + { + foreach (BuildJobType jobType in jobTypes) + { + IBuildJobRunner runner = _runnersByJobType[jobType]; + await runner.CreateEngineAsync(engineId, name, cancellationToken); + } + } + + public async Task DeleteEngineAsync( + IEnumerable jobTypes, + string engineId, + CancellationToken cancellationToken = default + ) + { + foreach (BuildJobType jobType in jobTypes) + { + IBuildJobRunner runner = _runnersByJobType[jobType]; + await runner.DeleteEngineAsync(engineId, cancellationToken); + } + } + + public async Task StartBuildJobAsync( + BuildJobType jobType, + TranslationEngineType engineType, + string engineId, + string buildId, + string stage, + object? data = null, + CancellationToken cancellationToken = default + ) + { + if ( + !await _engines.ExistsAsync( + e => + e.EngineId == engineId + && (e.CurrentBuild == null || e.CurrentBuild.JobState != BuildJobState.Canceling), + cancellationToken + ) + ) + { + return false; + } + + IBuildJobRunner runner = _runnersByJobType[jobType]; + string jobId = await runner.CreateJobAsync(engineType, engineId, buildId, stage, data, cancellationToken); + try + { + await _engines.UpdateAsync( + e => e.EngineId == engineId, + u => + u.Set( + e => e.CurrentBuild, + new Build + { + BuildId = buildId, + JobId = jobId, + JobRunner = runner.Type, + Stage = stage, + JobState = BuildJobState.Pending + } + ), + cancellationToken: cancellationToken + ); + await runner.EnqueueJobAsync(jobId, cancellationToken); + return true; + } + catch + { + await runner.DeleteJobAsync(jobId, CancellationToken.None); + throw; + } + } + + public async Task<(string? BuildId, BuildJobState State)> CancelBuildJobAsync( + string engineId, + CancellationToken cancellationToken = default + ) + { + TranslationEngine? engine = await _engines.GetAsync( + e => e.EngineId == engineId && e.CurrentBuild != null, + cancellationToken + ); + if (engine is null || engine.CurrentBuild is null) + return (null, BuildJobState.None); + + IBuildJobRunner runner = _runners[engine.CurrentBuild.JobRunner]; + + if (engine.CurrentBuild.JobState is BuildJobState.Pending) + { + // cancel a job that hasn't started yet + engine = await _engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null, + u => u.Unset(b => b.CurrentBuild), + returnOriginal: true, + cancellationToken: cancellationToken + ); + if (engine is not null && engine.CurrentBuild is not null) + { + // job will be deleted from the queue + await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); + return (engine.CurrentBuild.BuildId, BuildJobState.None); + } + } + else if (engine.CurrentBuild.JobState is BuildJobState.Active) + { + // cancel a job that is already running + engine = await _engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null, + u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Canceling), + cancellationToken: cancellationToken + ); + if (engine is not null && engine.CurrentBuild is not null) + { + await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); + return (engine.CurrentBuild.BuildId, BuildJobState.Canceling); + } + } + + return (null, BuildJobState.None); + } + + public async Task BuildJobStartedAsync( + string engineId, + string buildId, + CancellationToken cancellationToken = default + ) + { + TranslationEngine? engine = await _engines.UpdateAsync( + e => + e.EngineId == engineId + && e.CurrentBuild != null + && e.CurrentBuild.BuildId == buildId + && e.CurrentBuild.JobState == BuildJobState.Pending, + u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Active), + cancellationToken: cancellationToken + ); + return engine is not null; + } + + public Task BuildJobFinishedAsync( + string engineId, + string buildId, + bool buildComplete, + CancellationToken cancellationToken = default + ) + { + return _engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, + u => + { + u.Unset(e => e.CurrentBuild); + if (buildComplete) + u.Inc(e => e.BuildRevision); + }, + cancellationToken: cancellationToken + ); + } + + public Task BuildJobRestartingAsync(string engineId, string buildId, CancellationToken cancellationToken = default) + { + return _engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, + u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Pending), + cancellationToken: cancellationToken + ); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLAuthenticationService.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLAuthenticationService.cs index 45aa3efb9..9581cf78e 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLAuthenticationService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLAuthenticationService.cs @@ -1,22 +1,24 @@ namespace SIL.Machine.AspNetCore.Services; -public class ClearMLAuthenticationService : BackgroundService, IClearMLAuthenticationService +public class ClearMLAuthenticationService : RecurrentTask, IClearMLAuthenticationService { private readonly HttpClient _httpClient; - private readonly IOptionsMonitor _options; + private readonly IOptionsMonitor _options; private readonly ILogger _logger; private readonly AsyncLock _lock = new(); // technically, the token should be good for 30 days, but let's refresh each hour // to know well ahead of time if something is wrong. - private const int RefreshPeriod = 3600; + private static readonly TimeSpan RefreshPeriod = TimeSpan.FromSeconds(3600); private string _authToken = ""; public ClearMLAuthenticationService( + IServiceProvider services, HttpClient httpClient, - IOptionsMonitor options, + IOptionsMonitor options, ILogger logger ) + : base("ClearML authentication service", services, RefreshPeriod, logger) { _httpClient = httpClient; _options = options; @@ -25,7 +27,7 @@ ILogger logger public async Task GetAuthTokenAsync(CancellationToken cancellationToken = default) { - using (await _lock.LockAsync()) + using (await _lock.LockAsync(cancellationToken)) { if (_authToken is "") { @@ -37,20 +39,17 @@ public async Task GetAuthTokenAsync(CancellationToken cancellationToken return _authToken; } - protected override async Task ExecuteAsync(CancellationToken stoppingToken) + protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken cancellationToken) { - _logger.LogInformation("ClearML Authentication Token Refresh service running - and has initial token."); try { - while (!stoppingToken.IsCancellationRequested) - { - await Task.Delay(TimeSpan.FromSeconds(RefreshPeriod), stoppingToken); - using (await _lock.LockAsync()) - await AuthorizeAsync(stoppingToken); - } + using (await _lock.LockAsync(cancellationToken)) + await AuthorizeAsync(cancellationToken); + } + catch (Exception e) + { + _logger.LogError(e, "Error occurred while refreshing ClearML authentication token."); } - catch (TaskCanceledException) { } - _logger.LogInformation("ClearML Authentication Token Refresh service successfully stopped"); } private async Task AuthorizeAsync(CancellationToken cancellationToken) diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLBuildJobRunner.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLBuildJobRunner.cs new file mode 100644 index 000000000..677966c19 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLBuildJobRunner.cs @@ -0,0 +1,68 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class ClearMLBuildJobRunner : IBuildJobRunner +{ + private readonly IClearMLService _clearMLService; + private readonly Dictionary _buildJobFactories; + + public ClearMLBuildJobRunner(IClearMLService clearMLService, IEnumerable buildJobFactories) + { + _clearMLService = clearMLService; + _buildJobFactories = buildJobFactories.ToDictionary(f => f.EngineType); + } + + public BuildJobRunner Type => BuildJobRunner.ClearML; + + public async Task CreateEngineAsync( + string engineId, + string? name = null, + CancellationToken cancellationToken = default + ) + { + await _clearMLService.CreateProjectAsync(engineId, name, cancellationToken); + } + + public async Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default) + { + string? projectId = await _clearMLService.GetProjectIdAsync(engineId, cancellationToken); + if (projectId is not null) + await _clearMLService.DeleteProjectAsync(projectId, cancellationToken); + } + + public async Task CreateJobAsync( + TranslationEngineType engineType, + string engineId, + string buildId, + string stage, + object? data = null, + CancellationToken cancellationToken = default + ) + { + string? projectId = await _clearMLService.GetProjectIdAsync(engineId, cancellationToken); + if (projectId is null) + throw new InvalidOperationException("The project does not exist."); + + ClearMLTask? task = await _clearMLService.GetTaskByNameAsync(buildId, cancellationToken); + if (task is not null) + return task.Id; + + IClearMLBuildJobFactory buildJobFactory = _buildJobFactories[engineType]; + string script = await buildJobFactory.CreateJobScriptAsync(engineId, buildId, stage, data, cancellationToken); + return await _clearMLService.CreateTaskAsync(buildId, projectId, script, cancellationToken); + } + + public Task DeleteJobAsync(string jobId, CancellationToken cancellationToken = default) + { + return _clearMLService.DeleteTaskAsync(jobId, cancellationToken); + } + + public Task EnqueueJobAsync(string jobId, CancellationToken cancellationToken = default) + { + return _clearMLService.EnqueueTaskAsync(jobId, cancellationToken); + } + + public Task StopJobAsync(string jobId, CancellationToken cancellationToken = default) + { + return _clearMLService.StopTaskAsync(jobId, cancellationToken); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLHealthCheck.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLHealthCheck.cs index ac08ae84d..82676906f 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLHealthCheck.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLHealthCheck.cs @@ -1,13 +1,13 @@ public class ClearMLHealthCheck : IHealthCheck { private readonly HttpClient _httpClient; - private readonly IOptionsMonitor _options; + private readonly IOptionsMonitor _options; private readonly IClearMLAuthenticationService _clearMLAuthenticationService; public ClearMLHealthCheck( IClearMLAuthenticationService clearMLAuthenticationService, HttpClient httpClient, - IOptionsMonitor options + IOptionsMonitor options ) { _httpClient = httpClient; diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs new file mode 100644 index 000000000..9f15e8ea8 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLMonitorService.cs @@ -0,0 +1,324 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class ClearMLMonitorService : RecurrentTask +{ + private static readonly string EvalMetric = CreateMD5("eval"); + private static readonly string BleuVariant = CreateMD5("bleu"); + + private static readonly string SummaryMetric = CreateMD5("Summary"); + private static readonly string CorpusSizeVariant = CreateMD5("corpus_size"); + + private readonly IClearMLService _clearMLService; + private readonly ISharedFileService _sharedFileService; + private readonly ILogger _logger; + private readonly Dictionary _curBuildStatus = new(); + + public ClearMLMonitorService( + IServiceProvider services, + IClearMLService clearMLService, + ISharedFileService sharedFileService, + IOptions options, + ILogger logger + ) + : base( + "ClearML monitor service", + services, + options.Value.BuildPollingTimeout, + logger, + options.Value.BuildPollingEnabled + ) + { + _clearMLService = clearMLService; + _sharedFileService = sharedFileService; + _logger = logger; + } + + protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken cancellationToken) + { + try + { + var buildJobService = scope.ServiceProvider.GetRequiredService(); + IReadOnlyList trainingEngines = await buildJobService.GetBuildingEnginesAsync( + BuildJobRunner.ClearML, + cancellationToken + ); + if (trainingEngines.Count == 0) + return; + + Dictionary tasks = ( + await _clearMLService.GetTasksByIdAsync( + trainingEngines.Select(e => e.CurrentBuild!.JobId), + cancellationToken + ) + ).ToDictionary(t => t.Id); + + var platformService = scope.ServiceProvider.GetRequiredService(); + var lockFactory = scope.ServiceProvider.GetRequiredService(); + foreach (TranslationEngine engine in trainingEngines) + { + if (engine.CurrentBuild is null || !tasks.TryGetValue(engine.CurrentBuild.JobId, out ClearMLTask? task)) + continue; + + if (engine.CurrentBuild.Stage == NmtBuildStages.Train) + { + if ( + engine.CurrentBuild.JobState is BuildJobState.Pending + && task.Status + is ClearMLTaskStatus.InProgress + or ClearMLTaskStatus.Stopped + or ClearMLTaskStatus.Failed + or ClearMLTaskStatus.Completed + ) + { + bool canceled = !await TrainJobStartedAsync( + lockFactory, + buildJobService, + platformService, + engine.EngineId, + engine.CurrentBuild.BuildId, + cancellationToken + ); + if (canceled) + continue; + } + + switch (task.Status) + { + case ClearMLTaskStatus.InProgress: + await UpdateTrainJobStatus( + platformService, + engine.CurrentBuild.BuildId, + new ProgressStatus(task.LastIteration), + cancellationToken + ); + break; + + case ClearMLTaskStatus.Completed: + await UpdateTrainJobStatus( + platformService, + engine.CurrentBuild.BuildId, + new ProgressStatus(task.LastIteration), + cancellationToken + ); + bool canceling = !await TrainJobCompletedAsync( + lockFactory, + buildJobService, + engine.EngineId, + engine.CurrentBuild.BuildId, + (int)GetMetric(task, SummaryMetric, CorpusSizeVariant), + GetMetric(task, EvalMetric, BleuVariant), + cancellationToken + ); + if (canceling) + { + await TrainJobCanceledAsync( + lockFactory, + buildJobService, + platformService, + engine.EngineId, + engine.CurrentBuild.BuildId, + cancellationToken + ); + } + break; + + case ClearMLTaskStatus.Stopped: + await TrainJobCanceledAsync( + lockFactory, + buildJobService, + platformService, + engine.EngineId, + engine.CurrentBuild.BuildId, + cancellationToken + ); + break; + + case ClearMLTaskStatus.Failed: + await TrainJobFaultedAsync( + lockFactory, + buildJobService, + platformService, + engine.EngineId, + engine.CurrentBuild.BuildId, + $"{task.StatusReason} : {task.StatusMessage}", + cancellationToken + ); + break; + } + } + } + } + catch (Exception e) + { + _logger.LogError(e, "Error occurred while monitoring ClearML tasks."); + } + } + + private async Task TrainJobStartedAsync( + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + IPlatformService platformService, + string engineId, + string buildId, + CancellationToken cancellationToken = default + ) + { + IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + if (!await buildJobService.BuildJobStartedAsync(engineId, buildId, cancellationToken)) + return false; + } + await platformService.BuildStartedAsync(buildId, CancellationToken.None); + _logger.LogInformation("Build started ({0})", buildId); + return true; + } + + private async Task TrainJobCompletedAsync( + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + string engineId, + string buildId, + int corpusSize, + double confidence, + CancellationToken cancellationToken + ) + { + try + { + IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + return await buildJobService.StartBuildJobAsync( + BuildJobType.Cpu, + TranslationEngineType.Nmt, + engineId, + buildId, + NmtBuildStages.Postprocess, + (corpusSize, confidence), + cancellationToken + ); + } + } + finally + { + _curBuildStatus.Remove(buildId); + } + } + + private async Task TrainJobFaultedAsync( + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + IPlatformService platformService, + string engineId, + string buildId, + string message, + CancellationToken cancellationToken + ) + { + try + { + IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + await platformService.BuildFaultedAsync(buildId, message, cancellationToken); + await buildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + } + _logger.LogError("Build faulted ({0}). Error: {1}", buildId, message); + } + finally + { + try + { + await _sharedFileService.DeleteAsync($"builds/{buildId}/", CancellationToken.None); + } + catch (Exception e) + { + _logger.LogWarning(e, "Unable to to delete job data for build {0}.", buildId); + } + _curBuildStatus.Remove(buildId); + } + } + + private async Task TrainJobCanceledAsync( + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + IPlatformService platformService, + string engineId, + string buildId, + CancellationToken cancellationToken + ) + { + try + { + IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + await platformService.BuildCanceledAsync(buildId, cancellationToken); + await buildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + } + _logger.LogInformation("Build canceled ({0})", buildId); + } + finally + { + try + { + await _sharedFileService.DeleteAsync($"builds/{buildId}/", CancellationToken.None); + } + catch (Exception e) + { + _logger.LogWarning(e, "Unable to to delete job data for build {0}.", buildId); + } + _curBuildStatus.Remove(buildId); + } + } + + private async Task UpdateTrainJobStatus( + IPlatformService platformService, + string buildId, + ProgressStatus progressStatus, + CancellationToken cancellationToken + ) + { + if ( + _curBuildStatus.TryGetValue(buildId, out ProgressStatus curProgressStatus) + && curProgressStatus.Equals(progressStatus) + ) + { + return; + } + await platformService.UpdateBuildStatusAsync(buildId, progressStatus, cancellationToken); + _curBuildStatus[buildId] = progressStatus; + } + + private static double GetMetric(ClearMLTask task, string metric, string variant) + { + if (!task.LastMetrics.TryGetValue(metric, out Dictionary? metricVariants)) + return 0; + + if (!metricVariants.TryGetValue(variant, out ClearMLMetricsEvent? metricEvent)) + return 0; + + return metricEvent.Value; + } + + private static string CreateMD5(string input) + { + using var md5 = MD5.Create(); + + byte[] inputBytes = Encoding.UTF8.GetBytes(input); + byte[] hashBytes = md5.ComputeHash(inputBytes); + + return Convert.ToHexString(hashBytes); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLNmtEngineBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLNmtEngineBuildJob.cs deleted file mode 100644 index 4bd36e326..000000000 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLNmtEngineBuildJob.cs +++ /dev/null @@ -1,374 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -public class ClearMLNmtEngineBuildJob -{ - private readonly IPlatformService _platformService; - private readonly IRepository _engines; - private readonly ILogger _logger; - private readonly IClearMLService _clearMLService; - private readonly ISharedFileService _sharedFileService; - private readonly IOptionsMonitor _options; - private readonly ICorpusService _corpusService; - - public ClearMLNmtEngineBuildJob( - IPlatformService platformService, - IRepository engines, - ILogger logger, - IClearMLService clearMLService, - ISharedFileService sharedFileService, - IOptionsMonitor options, - ICorpusService corpusService - ) - { - _platformService = platformService; - _engines = engines; - _logger = logger; - _clearMLService = clearMLService; - _sharedFileService = sharedFileService; - _options = options; - _corpusService = corpusService; - } - - [Queue("nmt")] - [AutomaticRetry(Attempts = 0)] - public async Task RunAsync( - string engineId, - string buildId, - IReadOnlyList corpora, - CancellationToken cancellationToken - ) - { - string? clearMLProjectId = await _clearMLService.GetProjectIdAsync(engineId, cancellationToken); - if (clearMLProjectId is null) - return; - - try - { - TranslationEngine? engine = await _engines.GetAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - cancellationToken: cancellationToken - ); - if (engine is null || engine.IsCanceled) - throw new OperationCanceledException(); - - int corpusSize; - if (engine.BuildState is BuildState.Pending) - corpusSize = await WriteDataFilesAsync(buildId, corpora, cancellationToken); - else - corpusSize = GetCorpusSize(corpora); - - string clearMLTaskId; - ClearMLTask? clearMLTask = await _clearMLService.GetTaskByNameAsync(buildId, cancellationToken); - if (clearMLTask is null) - { - clearMLTaskId = await _clearMLService.CreateTaskAsync( - buildId, - clearMLProjectId, - engineId, - engine.SourceLanguage, - engine.TargetLanguage, - _sharedFileService.GetBaseUri().ToString(), - cancellationToken - ); - await _clearMLService.EnqueueTaskAsync(clearMLTaskId, CancellationToken.None); - } - else - { - clearMLTaskId = clearMLTask.Id; - } - - int lastIteration = 0; - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - - clearMLTask = await _clearMLService.GetTaskByIdAsync(clearMLTaskId, cancellationToken); - if (clearMLTask is null) - throw new InvalidOperationException("The ClearML task does not exist."); - - if ( - engine.BuildState == BuildState.Pending - && clearMLTask.Status - is ClearMLTaskStatus.InProgress - or ClearMLTaskStatus.Stopped - or ClearMLTaskStatus.Failed - or ClearMLTaskStatus.Completed - ) - { - engine = await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId && !e.IsCanceled, - u => u.Set(e => e.BuildState, BuildState.Active), - cancellationToken: cancellationToken - ); - if (engine is null) - throw new OperationCanceledException(); - await _platformService.BuildStartedAsync(buildId, CancellationToken.None); - _logger.LogInformation("Build started ({0})", buildId); - } - - switch (clearMLTask.Status) - { - case ClearMLTaskStatus.InProgress: - case ClearMLTaskStatus.Completed: - if (lastIteration != clearMLTask.LastIteration) - { - await _platformService.UpdateBuildStatusAsync(buildId, clearMLTask.LastIteration); - lastIteration = clearMLTask.LastIteration; - } - break; - case ClearMLTaskStatus.Stopped: - // This could have been triggered from the ClearML UI, so set IsCanceled to true. - await _engines.UpdateAsync( - e => e.EngineId == engineId && !e.IsCanceled, - u => u.Set(e => e.IsCanceled, true), - cancellationToken: CancellationToken.None - ); - throw new OperationCanceledException(); - case ClearMLTaskStatus.Failed: - throw new InvalidOperationException( - $"{clearMLTask.StatusReason} : {clearMLTask.StatusMessage}" - ); - } - if (clearMLTask.Status is ClearMLTaskStatus.Completed) - break; - await Task.Delay(_options.CurrentValue.BuildPollingTimeout, cancellationToken); - } - - // The ClearML task has successfully completed, so insert the generated pretranslations into the database. - await InsertPretranslationsAsync(engineId, buildId, cancellationToken); - - IReadOnlyDictionary metrics = await _clearMLService.GetTaskMetricsAsync( - clearMLTaskId, - CancellationToken.None - ); - - try - { - //Don't fail the whole job if we can't delete the files. - await _sharedFileService.DeleteAsync($"builds/{buildId}/", CancellationToken.None); - } - catch (AmazonS3Exception e) - { - _logger.LogError(e, $"Could not delete build ({buildId}). Finishing up build anyway."); - } - - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - u => - u.Set(e => e.BuildState, BuildState.None) - .Set(e => e.IsCanceled, false) - .Inc(e => e.BuildRevision) - .Unset(e => e.JobId) - .Unset(e => e.BuildId), - cancellationToken: CancellationToken.None - ); - - if (!metrics.TryGetValue("bleu", out double confidence)) - confidence = 0; - - await _platformService.BuildCompletedAsync( - buildId, - corpusSize, - Math.Round(confidence, 2, MidpointRounding.AwayFromZero), - CancellationToken.None - ); - _logger.LogInformation("Build completed in {0}s ({1})", clearMLTask.ActiveDuration, buildId); - } - catch (OperationCanceledException) - { - // Check if the cancellation was initiated by an API call or a shutdown. - TranslationEngine? engine = await _engines.GetAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - CancellationToken.None - ); - if (engine is null || engine.IsCanceled) - { - // This is an actual cancellation triggered by an API call. - ClearMLTask? task = await _clearMLService.GetTaskByNameAsync(buildId, CancellationToken.None); - if (task is not null) - await _clearMLService.StopTaskAsync(task.Id, CancellationToken.None); - - await _sharedFileService.DeleteAsync($"builds/{buildId}/", CancellationToken.None); - - bool buildStarted = await _engines.ExistsAsync( - e => e.EngineId == engineId && e.BuildId == buildId && e.BuildState == BuildState.Active, - CancellationToken.None - ); - - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - u => - u.Set(e => e.BuildState, BuildState.None) - .Set(e => e.IsCanceled, false) - .Unset(e => e.JobId) - .Unset(e => e.BuildId), - cancellationToken: CancellationToken.None - ); - - if (buildStarted) - { - await _platformService.BuildCanceledAsync(buildId, CancellationToken.None); - _logger.LogInformation("Build canceled ({0})", buildId); - } - } - else if (engine is not null) - { - // the build was canceled, because of a server shutdown - // switch state back to pending - await _platformService.BuildRestartingAsync(buildId, CancellationToken.None); - } - - throw; - } - catch (Exception e) - { - _logger.LogError(0, e, $"Build faulted ({buildId}) because of exception {e.GetType().Name}:{e.Message}."); - - try - { - await _sharedFileService.DeleteAsync($"builds/{buildId}/", CancellationToken.None); - } - catch (Exception e2) - { - _logger.LogError( - $"Unable to access S3 bucket to delete clearml job {buildId} because it threw the exception {e2.GetType().Name}:{e2.Message}." - ); - } - - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - u => - u.Set(e => e.BuildState, BuildState.None) - .Set(e => e.IsCanceled, false) - .Unset(e => e.JobId) - .Unset(e => e.BuildId), - cancellationToken: CancellationToken.None - ); - - await _platformService.BuildFaultedAsync(buildId, e.Message, CancellationToken.None); - throw; - } - } - - private async Task WriteDataFilesAsync( - string buildId, - IReadOnlyList corpora, - CancellationToken cancellationToken - ) - { - await using var sourceTrainWriter = new StreamWriter( - await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken) - ); - await using var targetTrainWriter = new StreamWriter( - await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken) - ); - - int corpusSize = 0; - async IAsyncEnumerable ProcessRowsAsync() - { - foreach (Corpus corpus in corpora) - { - ITextCorpus sourceCorpus = _corpusService.CreateTextCorpus(corpus.SourceFiles); - ITextCorpus targetCorpus = _corpusService.CreateTextCorpus(corpus.TargetFiles); - - IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows( - targetCorpus, - allSourceRows: true, - allTargetRows: true - ); - - foreach (ParallelTextRow row in parallelCorpus) - { - await sourceTrainWriter.WriteAsync($"{row.SourceText}\n"); - await targetTrainWriter.WriteAsync($"{row.TargetText}\n"); - if ( - (corpus.PretranslateAll || corpus.PretranslateTextIds.Contains(row.TextId)) - && row.SourceSegment.Count > 0 - && row.TargetSegment.Count == 0 - ) - { - IReadOnlyList refs; - if (row.TargetRefs.Count == 0) - { - if (sourceCorpus is ScriptureTextCorpus sstc && targetCorpus is ScriptureTextCorpus tstc) - { - refs = row.SourceRefs - .Cast() - .Select(srcRef => - { - var trgRef = srcRef.Clone(); - trgRef.ChangeVersification(tstc.Versification); - return (object)trgRef; - }) - .ToList(); - } - else - { - refs = row.SourceRefs; - } - } - else - { - refs = row.TargetRefs; - } - yield return new Pretranslation - { - CorpusId = corpus.Id, - TextId = row.TextId, - Refs = refs.Select(r => r.ToString()!).ToList(), - Translation = row.SourceText - }; - } - if (!row.IsEmpty) - corpusSize++; - } - } - } - - await using var sourcePretranslateStream = await _sharedFileService.OpenWriteAsync( - $"builds/{buildId}/pretranslate.src.json", - cancellationToken - ); - - await JsonSerializer.SerializeAsync( - sourcePretranslateStream, - ProcessRowsAsync(), - new JsonSerializerOptions { WriteIndented = true, PropertyNamingPolicy = JsonNamingPolicy.CamelCase }, - cancellationToken: cancellationToken - ); - return corpusSize; - } - - private int GetCorpusSize(IReadOnlyList corpora) - { - int corpusSize = 0; - foreach (Corpus corpus in corpora) - { - ITextCorpus sourceCorpus = _corpusService.CreateTextCorpus(corpus.SourceFiles); - ITextCorpus targetCorpus = _corpusService.CreateTextCorpus(corpus.TargetFiles); - - IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows(targetCorpus); - - corpusSize += parallelCorpus.Count(includeEmpty: false); - } - return corpusSize; - } - - private async Task InsertPretranslationsAsync(string engineId, string buildId, CancellationToken cancellationToken) - { - await using var targetPretranslateStream = await _sharedFileService.OpenReadAsync( - $"builds/{buildId}/pretranslate.trg.json", - cancellationToken - ); - - IAsyncEnumerable pretranslations = JsonSerializer - .DeserializeAsyncEnumerable( - targetPretranslateStream, - new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }, - cancellationToken - ) - .OfType(); - - await _platformService.InsertPretranslationsAsync(engineId, pretranslations, cancellationToken); - } -} diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLNmtEngineService.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLNmtEngineService.cs deleted file mode 100644 index 42eb064e1..000000000 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLNmtEngineService.cs +++ /dev/null @@ -1,51 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -public class ClearMLNmtEngineService : TranslationEngineServiceBase -{ - private readonly IClearMLService _clearMLService; - - public ClearMLNmtEngineService( - IBackgroundJobClient jobClient, - IPlatformService platformService, - IDistributedReaderWriterLockFactory lockFactory, - IDataAccessContext dataAccessContext, - IRepository engines, - IClearMLService clearMLService - ) - : base(jobClient, lockFactory, platformService, dataAccessContext, engines) - { - _clearMLService = clearMLService; - } - - public override TranslationEngineType Type => TranslationEngineType.Nmt; - - public override async Task CreateAsync( - string engineId, - string? engineName, - string sourceLanguage, - string targetLanguage, - CancellationToken cancellationToken = default - ) - { - await base.CreateAsync(engineId, engineName, sourceLanguage, targetLanguage, cancellationToken); - await _clearMLService.CreateProjectAsync(engineId, engineName, cancellationToken: CancellationToken.None); - } - - public override async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) - { - await base.DeleteAsync(engineId, cancellationToken); - string? projectId = await _clearMLService.GetProjectIdAsync(engineId, CancellationToken.None); - if (projectId is not null) - await _clearMLService.DeleteProjectAsync(projectId, CancellationToken.None); - } - - protected override Expression> GetJobExpression( - string engineId, - string buildId, - IReadOnlyList corpora - ) - { - // Token "None" is used here because hangfire injects the proper cancellation token - return r => r.RunAsync(engineId, buildId, corpora, CancellationToken.None); - } -} diff --git a/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs b/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs index b7022dd6b..471cb0f5b 100644 --- a/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs @@ -3,8 +3,7 @@ public class ClearMLService : IClearMLService { private readonly HttpClient _httpClient; - private readonly IOptionsMonitor _options; - private readonly ILogger _logger; + private readonly IOptionsMonitor _options; private static readonly JsonNamingPolicy JsonNamingPolicy = new SnakeCaseJsonNamingPolicy(); private static readonly JsonSerializerOptions JsonSerializerOptions = new() @@ -13,21 +12,17 @@ public class ClearMLService : IClearMLService Converters = { new CustomEnumConverterFactory(JsonNamingPolicy) } }; - private IClearMLAuthenticationService _clearMLAuthService; + private readonly IClearMLAuthenticationService _clearMLAuthService; public ClearMLService( HttpClient httpClient, - IOptionsMonitor options, - ILogger logger, + IOptionsMonitor options, IClearMLAuthenticationService clearMLAuthService ) { _httpClient = httpClient; _options = options; - _logger = logger; _clearMLAuthService = clearMLAuthService; - if (!Sldr.IsInitialized) - Sldr.Initialize(); } public async Task GetProjectIdAsync(string name, CancellationToken cancellationToken = default) @@ -80,27 +75,10 @@ public async Task DeleteProjectAsync(string id, CancellationToken cancella public async Task CreateTaskAsync( string buildId, string projectId, - string engineId, - string sourceLanguageTag, - string targetLanguageTag, - string sharedFileUri, + string script, CancellationToken cancellationToken = default ) { - string script = - "from machine.jobs.build_nmt_engine import run\n" - + "args = {\n" - + $" 'model_type': '{_options.CurrentValue.ModelType}',\n" - + $" 'engine_id': '{engineId}',\n" - + $" 'build_id': '{buildId}',\n" - + $" 'src_lang': '{ConvertLanguageTag(sourceLanguageTag)}',\n" - + $" 'trg_lang': '{ConvertLanguageTag(targetLanguageTag)}',\n" - + $" 'max_steps': {_options.CurrentValue.MaxSteps},\n" - + $" 'shared_file_uri': '{sharedFileUri}',\n" - + $" 'clearml': True,\n" - + "}\n" - + "run(args)\n"; - var body = new JsonObject { ["name"] = buildId, @@ -116,6 +94,16 @@ public async Task CreateTaskAsync( return taskId; } + public async Task DeleteTaskAsync(string id, CancellationToken cancellationToken = default) + { + var body = new JsonObject { ["task"] = id }; + JsonObject? result = await CallAsync("tasks", "delete", body, cancellationToken); + var deleted = (bool?)result?["data"]?["deleted"]; + if (deleted is null) + throw new InvalidOperationException("Malformed response from ClearML server."); + return deleted.Value; + } + public async Task EnqueueTaskAsync(string id, CancellationToken cancellationToken = default) { var body = new JsonObject { ["task"] = id, ["queue_name"] = _options.CurrentValue.Queue }; @@ -146,49 +134,26 @@ public async Task StopTaskAsync(string id, CancellationToken cancellationT return updated == 1; } - public Task GetTaskByNameAsync(string name, CancellationToken cancellationToken = default) + public async Task GetTaskByNameAsync(string name, CancellationToken cancellationToken = default) { - return GetTaskAsync(new JsonObject { ["name"] = name }, cancellationToken); - } - - public Task GetTaskByIdAsync(string id, CancellationToken cancellationToken = default) - { - return GetTaskAsync(new JsonObject { ["id"] = id }, cancellationToken); + IReadOnlyList tasks = await GetTasksAsync(new JsonObject { ["name"] = name }, cancellationToken); + if (tasks.Count == 0) + return null; + return tasks[0]; } - public async Task> GetTaskMetricsAsync( - string id, + public Task> GetTasksByIdAsync( + IEnumerable ids, CancellationToken cancellationToken = default ) { - var body = new JsonObject { ["task"] = id }; - JsonObject? result = await CallAsync("events", "get_task_latest_scalar_values", body, cancellationToken); - var metrics = (JsonArray?)result?["data"]?["metrics"]; - if (metrics is null) - throw new InvalidOperationException("Malformed response from ClearML server."); - var performanceMetrics = (JsonObject?)metrics.FirstOrDefault(m => (string?)m?["name"] == "metrics"); - var results = new Dictionary(); - if (performanceMetrics is null) - return results; - var variants = (JsonArray?)performanceMetrics?["variants"]; - if (variants is null) - return results; - foreach (JsonObject? variant in variants) - { - if (variant is null) - continue; - var name = (string?)variant?["name"]; - if (name is null) - continue; - var value = (double?)variant?["last_value"]; - if (value is null) - continue; - results[name] = value.Value; - } - return results; + return GetTasksAsync(new JsonObject { ["id"] = JsonValue.Create(ids.ToArray()) }, cancellationToken); } - private async Task GetTaskAsync(JsonObject body, CancellationToken cancellationToken = default) + private async Task> GetTasksAsync( + JsonObject body, + CancellationToken cancellationToken = default + ) { body["only_fields"] = new JsonArray( "id", @@ -197,13 +162,13 @@ public async Task> GetTaskMetricsAsync( "project", "last_iteration", "status_reason", - "active_duration" + "active_duration", + "last_metrics" ); JsonObject? result = await CallAsync("tasks", "get_all_ex", body, cancellationToken); var tasks = (JsonArray?)result?["data"]?["tasks"]; - if (tasks is null || tasks.Count == 0) - return null; - return tasks[0].Deserialize(JsonSerializerOptions); + return tasks?.Select(t => t.Deserialize(JsonSerializerOptions)!).ToArray() + ?? Array.Empty(); } private async Task CallAsync( @@ -223,23 +188,6 @@ public async Task> GetTaskMetricsAsync( return (JsonObject?)JsonNode.Parse(result); } - private static string ConvertLanguageTag(string languageTag) - { - if ( - !IetfLanguageTag.TryGetSubtags( - languageTag, - out LanguageSubtag languageSubtag, - out ScriptSubtag scriptSubtag, - out _, - out _ - ) - ) - return languageTag; - - // Convert to NLLB language codes - return $"{languageSubtag.Iso3Code}_{scriptSubtag.Code}"; - } - private class SnakeCaseJsonNamingPolicy : JsonNamingPolicy { public override string ConvertName(string name) diff --git a/src/SIL.Machine.AspNetCore/Services/FileStorage.cs b/src/SIL.Machine.AspNetCore/Services/FileStorage.cs deleted file mode 100644 index a38c1f541..000000000 --- a/src/SIL.Machine.AspNetCore/Services/FileStorage.cs +++ /dev/null @@ -1,43 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -public abstract class FileStorage : IDisposable -{ - public abstract void Dispose(); - public abstract Task Exists(string path, CancellationToken cancellationToken = default); - - public abstract Task> Ls( - string path, - bool recurse = false, - CancellationToken cancellationToken = default - ); - - public abstract Task OpenRead(string path, CancellationToken cancellationToken = default); - - public abstract Task OpenWrite(string path, CancellationToken cancellationToken = default); - - public abstract Task Rm(string path, bool recurse = false, CancellationToken cancellationToken = default); - - protected string Normalize(string? path, bool includeLeadingSlash = true, bool includeTrailingSlash = false) - { - string normalizedPath = path ?? ""; - if (normalizedPath == "/") - return normalizedPath; - if (!includeLeadingSlash && normalizedPath.StartsWith("/")) - { - normalizedPath = normalizedPath.Remove(0, 1); - } - else if (includeLeadingSlash && !normalizedPath.StartsWith("/")) - { - normalizedPath = "/" + normalizedPath; - } - if (!includeTrailingSlash && normalizedPath.EndsWith("/")) - { - normalizedPath = normalizedPath.Remove(normalizedPath.Length - 1, 1); - } - else if (includeTrailingSlash && !normalizedPath.EndsWith("/")) - { - normalizedPath = normalizedPath + "/"; - } - return normalizedPath; - } -} diff --git a/src/SIL.Machine.AspNetCore/Services/HangfireBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/HangfireBuildJob.cs new file mode 100644 index 000000000..082cf64a3 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/HangfireBuildJob.cs @@ -0,0 +1,159 @@ +namespace SIL.Machine.AspNetCore.Services; + +public abstract class HangfireBuildJob : HangfireBuildJob +{ + protected HangfireBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + ILogger logger + ) + : base(platformService, engines, lockFactory, buildJobService, logger) { } + + public virtual Task RunAsync(string engineId, string buildId, CancellationToken cancellationToken) + { + return RunAsync(engineId, buildId, null, cancellationToken); + } +} + +public abstract class HangfireBuildJob +{ + protected HangfireBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + ILogger> logger + ) + { + PlatformService = platformService; + Engines = engines; + LockFactory = lockFactory; + BuildJobService = buildJobService; + Logger = logger; + } + + protected IPlatformService PlatformService { get; } + protected IRepository Engines { get; } + protected IDistributedReaderWriterLockFactory LockFactory { get; } + protected IBuildJobService BuildJobService { get; } + protected ILogger> Logger { get; } + + public virtual async Task RunAsync(string engineId, string buildId, T data, CancellationToken cancellationToken) + { + IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); + JobCompletionStatus completionStatus = JobCompletionStatus.Completed; + try + { + await InitializeAsync(engineId, buildId, data, @lock, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + if (!await BuildJobService.BuildJobStartedAsync(engineId, buildId, cancellationToken)) + { + completionStatus = JobCompletionStatus.Canceled; + return; + } + } + + await DoWorkAsync(engineId, buildId, data, @lock, cancellationToken); + } + catch (OperationCanceledException) + { + // Check if the cancellation was initiated by an API call or a shutdown. + TranslationEngine? engine = await Engines.GetAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, + CancellationToken.None + ); + if (engine?.CurrentBuild?.JobState is BuildJobState.Canceling) + { + completionStatus = JobCompletionStatus.Canceled; + await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) + { + await PlatformService.BuildCanceledAsync(buildId, CancellationToken.None); + await BuildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + } + Logger.LogInformation("Build canceled ({0})", buildId); + } + else if (engine is not null) + { + // the build was canceled, because of a server shutdown + // switch state back to pending + completionStatus = JobCompletionStatus.Restarting; + await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) + { + await PlatformService.BuildRestartingAsync(buildId, CancellationToken.None); + await BuildJobService.BuildJobRestartingAsync(engineId, buildId, CancellationToken.None); + } + throw; + } + else + { + completionStatus = JobCompletionStatus.Canceled; + } + } + catch (Exception e) + { + completionStatus = JobCompletionStatus.Faulted; + await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) + { + await PlatformService.BuildFaultedAsync(buildId, e.Message, CancellationToken.None); + await BuildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + } + Logger.LogError(0, e, "Build faulted ({0})", buildId); + throw; + } + finally + { + await CleanupAsync(engineId, buildId, data, @lock, completionStatus); + } + } + + protected virtual Task InitializeAsync( + string engineId, + string buildId, + T data, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + return Task.CompletedTask; + } + + protected abstract Task DoWorkAsync( + string engineId, + string buildId, + T data, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ); + + protected virtual Task CleanupAsync( + string engineId, + string buildId, + T data, + IDistributedReaderWriterLock @lock, + JobCompletionStatus completionStatus + ) + { + return Task.CompletedTask; + } + + protected enum JobCompletionStatus + { + Completed, + Faulted, + Canceled, + Restarting + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/HangfireBuildJobRunner.cs b/src/SIL.Machine.AspNetCore/Services/HangfireBuildJobRunner.cs new file mode 100644 index 000000000..068c8c510 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/HangfireBuildJobRunner.cs @@ -0,0 +1,74 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class HangfireBuildJobRunner : IBuildJobRunner +{ + public static Job CreateJob(string engineId, string buildId, string queue, object? data) + where TJob : HangfireBuildJob + { + if (data is null) + throw new ArgumentNullException(nameof(data)); + // Token "None" is used here because hangfire injects the proper cancellation token + return Job.FromExpression(j => j.RunAsync(engineId, buildId, (TData)data, CancellationToken.None), queue); + } + + public static Job CreateJob(string engineId, string buildId, string queue) + where TJob : HangfireBuildJob + { + // Token "None" is used here because hangfire injects the proper cancellation token + return Job.FromExpression(j => j.RunAsync(engineId, buildId, CancellationToken.None), queue); + } + + private readonly IBackgroundJobClient _jobClient; + private readonly Dictionary _buildJobFactories; + + public HangfireBuildJobRunner( + IBackgroundJobClient jobClient, + IEnumerable buildJobFactories + ) + { + _jobClient = jobClient; + _buildJobFactories = buildJobFactories.ToDictionary(f => f.EngineType); + } + + public BuildJobRunner Type => BuildJobRunner.Hangfire; + + public Task CreateEngineAsync(string engineId, string? name = null, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + public Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + public Task CreateJobAsync( + TranslationEngineType engineType, + string engineId, + string buildId, + string stage, + object? data = null, + CancellationToken cancellationToken = default + ) + { + IHangfireBuildJobFactory buildJobFactory = _buildJobFactories[engineType]; + Job job = buildJobFactory.CreateJob(engineId, buildId, stage, data); + return Task.FromResult(_jobClient.Create(job, new ScheduledState(TimeSpan.FromDays(10000)))); + } + + public Task DeleteJobAsync(string jobId, CancellationToken cancellationToken = default) + { + return Task.FromResult(_jobClient.Delete(jobId)); + } + + public Task EnqueueJobAsync(string jobId, CancellationToken cancellationToken = default) + { + return Task.FromResult(_jobClient.Requeue(jobId)); + } + + public Task StopJobAsync(string jobId, CancellationToken cancellationToken = default) + { + // Trigger the cancellation token for the job + return Task.FromResult(_jobClient.Delete(jobId)); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/IBuildJobRunner.cs b/src/SIL.Machine.AspNetCore/Services/IBuildJobRunner.cs new file mode 100644 index 000000000..74b9650c2 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/IBuildJobRunner.cs @@ -0,0 +1,24 @@ +namespace SIL.Machine.AspNetCore.Services; + +public interface IBuildJobRunner +{ + BuildJobRunner Type { get; } + + Task CreateEngineAsync(string engineId, string? name = null, CancellationToken cancellationToken = default); + Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default); + + Task CreateJobAsync( + TranslationEngineType engineType, + string engineId, + string buildId, + string stage, + object? data = null, + CancellationToken cancellationToken = default + ); + + Task DeleteJobAsync(string jobId, CancellationToken cancellationToken = default); + + Task EnqueueJobAsync(string jobId, CancellationToken cancellationToken = default); + + Task StopJobAsync(string jobId, CancellationToken cancellationToken = default); +} diff --git a/src/SIL.Machine.AspNetCore/Services/IBuildJobService.cs b/src/SIL.Machine.AspNetCore/Services/IBuildJobService.cs new file mode 100644 index 000000000..70731b69e --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/IBuildJobService.cs @@ -0,0 +1,56 @@ +namespace SIL.Machine.AspNetCore.Services; + +public enum BuildJobType +{ + Cpu, + Gpu +} + +public interface IBuildJobService +{ + Task> GetBuildingEnginesAsync( + BuildJobRunner runner, + CancellationToken cancellationToken = default + ); + + Task IsEngineBuilding(string engineId, CancellationToken cancellationToken = default); + + Task CreateEngineAsync( + IEnumerable jobTypes, + string engineId, + string? name = null, + CancellationToken cancellationToken = default + ); + + Task DeleteEngineAsync( + IEnumerable jobTypes, + string engineId, + CancellationToken cancellationToken = default + ); + + Task StartBuildJobAsync( + BuildJobType jobType, + TranslationEngineType engineType, + string engineId, + string buildId, + string stage, + object? data = default, + CancellationToken cancellationToken = default + ); + + Task<(string? BuildId, BuildJobState State)> CancelBuildJobAsync( + string engineId, + CancellationToken cancellationToken = default + ); + + Task BuildJobStartedAsync(string engineId, string buildId, CancellationToken cancellationToken = default); + + Task BuildJobFinishedAsync( + string engineId, + string buildId, + bool buildComplete, + CancellationToken cancellationToken = default + ); + + Task BuildJobRestartingAsync(string engineId, string buildId, CancellationToken cancellationToken = default); +} diff --git a/src/SIL.Machine.AspNetCore/Services/IClearMLBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/IClearMLBuildJobFactory.cs new file mode 100644 index 000000000..f4bc6091f --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/IClearMLBuildJobFactory.cs @@ -0,0 +1,14 @@ +namespace SIL.Machine.AspNetCore.Services; + +public interface IClearMLBuildJobFactory +{ + TranslationEngineType EngineType { get; } + + Task CreateJobScriptAsync( + string engineId, + string buildId, + string stage, + object? data = null, + CancellationToken cancellationToken = default + ); +} diff --git a/src/SIL.Machine.AspNetCore/Services/IClearMLService.cs b/src/SIL.Machine.AspNetCore/Services/IClearMLService.cs index ea9bc7608..7f45670e0 100644 --- a/src/SIL.Machine.AspNetCore/Services/IClearMLService.cs +++ b/src/SIL.Machine.AspNetCore/Services/IClearMLService.cs @@ -13,19 +13,16 @@ Task CreateProjectAsync( Task CreateTaskAsync( string buildId, string projectId, - string engineId, - string sourceLanguageTag, - string targetLanguageTag, - string sharedFileUri, + string script, CancellationToken cancellationToken = default ); + Task DeleteTaskAsync(string id, CancellationToken cancellationToken = default); Task EnqueueTaskAsync(string id, CancellationToken cancellationToken = default); Task DequeueTaskAsync(string id, CancellationToken cancellationToken = default); Task StopTaskAsync(string id, CancellationToken cancellationToken = default); Task GetTaskByNameAsync(string name, CancellationToken cancellationToken = default); - Task GetTaskByIdAsync(string id, CancellationToken cancellationToken = default); - Task> GetTaskMetricsAsync( - string id, + Task> GetTasksByIdAsync( + IEnumerable ids, CancellationToken cancellationToken = default ); } diff --git a/src/SIL.Machine.AspNetCore/Services/IFileStorage.cs b/src/SIL.Machine.AspNetCore/Services/IFileStorage.cs new file mode 100644 index 000000000..89a15ccc9 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/IFileStorage.cs @@ -0,0 +1,18 @@ +namespace SIL.Machine.AspNetCore.Services; + +public interface IFileStorage : IDisposable +{ + Task ExistsAsync(string path, CancellationToken cancellationToken = default); + + Task> ListFilesAsync( + string path, + bool recurse = false, + CancellationToken cancellationToken = default + ); + + Task OpenReadAsync(string path, CancellationToken cancellationToken = default); + + Task OpenWriteAsync(string path, CancellationToken cancellationToken = default); + + Task DeleteAsync(string path, bool recurse = false, CancellationToken cancellationToken = default); +} diff --git a/src/SIL.Machine.AspNetCore/Services/IHangfireBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/IHangfireBuildJobFactory.cs new file mode 100644 index 000000000..988750742 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/IHangfireBuildJobFactory.cs @@ -0,0 +1,8 @@ +namespace SIL.Machine.AspNetCore.Services; + +public interface IHangfireBuildJobFactory +{ + TranslationEngineType EngineType { get; } + + Job CreateJob(string engineId, string buildId, string stage, object? data); +} diff --git a/src/SIL.Machine.AspNetCore/Services/ISharedFileService.cs b/src/SIL.Machine.AspNetCore/Services/ISharedFileService.cs index 42c2d0e7c..acbac0687 100644 --- a/src/SIL.Machine.AspNetCore/Services/ISharedFileService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ISharedFileService.cs @@ -13,10 +13,4 @@ public interface ISharedFileService Task ExistsAsync(string path, CancellationToken cancellationToken = default); Task DeleteAsync(string path, CancellationToken cancellationToken = default); - - Task> Ls( - string path, - bool recurse = false, - CancellationToken cancellationToken = default - ); } diff --git a/src/SIL.Machine.AspNetCore/Services/InMemoryStorage.cs b/src/SIL.Machine.AspNetCore/Services/InMemoryStorage.cs index accad7b81..76755f245 100644 --- a/src/SIL.Machine.AspNetCore/Services/InMemoryStorage.cs +++ b/src/SIL.Machine.AspNetCore/Services/InMemoryStorage.cs @@ -1,6 +1,8 @@ +using static SIL.Machine.AspNetCore.Utils.SharedFileUtils; + namespace SIL.Machine.AspNetCore.Services; -public class InMemoryStorage : FileStorage +public class InMemoryStorage : DisposableBase, IFileStorage { public class Entry : Stream { @@ -39,9 +41,7 @@ public Entry(Entry other) protected override void Dispose(bool disposing) { - bool alreadyExisted = !_parent._memoryStreams.TryAdd(Path, new Entry(this)); - if (alreadyExisted) - _parent._memoryStreams[Path] = new Entry(this); + _parent._memoryStreams[Path] = new Entry(this); } public override void Flush() @@ -70,46 +70,33 @@ public override void Write(byte[] buffer, int offset, int count) } } - public ConcurrentDictionary _memoryStreams; + private readonly ConcurrentDictionary _memoryStreams = new(); - public InMemoryStorage() - { - _memoryStreams = new(); - } - - public override Task Exists(string path, CancellationToken cancellationToken = default) + public Task ExistsAsync(string path, CancellationToken cancellationToken = default) { return Task.FromResult(_memoryStreams.TryGetValue(Normalize(path), out _)); } - public override Task> Ls( + public Task> ListFilesAsync( string? path, bool recurse = false, CancellationToken cancellationToken = default ) { + path = string.IsNullOrEmpty(path) ? "" : Normalize(path, includeTrailingSlash: true); if (recurse) - return Task.FromResult( - (IReadOnlyCollection) - _memoryStreams - .Where(kvPair => kvPair.Key.StartsWith(Normalize(path, true, true))) - .Select(kvPair => kvPair.Key) - .ToList() + { + return Task.FromResult>( + _memoryStreams.Keys.Where(p => p.StartsWith(path)).ToList() ); - return Task.FromResult( - (IReadOnlyCollection) - _memoryStreams - .Where( - kvPair => - kvPair.Key.StartsWith(Normalize(path, true, true)) - && !kvPair.Key.Remove(0, Normalize(path, true, true).Length).Contains("/") - ) - .Select(kvPair => kvPair.Key) - .ToList() + } + + return Task.FromResult>( + _memoryStreams.Keys.Where(p => p.StartsWith(path) && !p[path.Length..].Contains('/')).ToList() ); } - public override Task OpenRead(string path, CancellationToken cancellationToken = default) + public Task OpenReadAsync(string path, CancellationToken cancellationToken = default) { if (!_memoryStreams.TryGetValue(Normalize(path), out Entry? ret)) throw new FileNotFoundException($"Unable to find file {path}"); @@ -117,12 +104,12 @@ public override Task OpenRead(string path, CancellationToken cancellatio return Task.FromResult(ret); } - public override Task OpenWrite(string path, CancellationToken cancellationToken = default) + public Task OpenWriteAsync(string path, CancellationToken cancellationToken = default) { return Task.FromResult(new Entry(Normalize(path), this)); } - public override async Task Rm(string path, bool recurse, CancellationToken cancellationToken = default) + public async Task DeleteAsync(string path, bool recurse, CancellationToken cancellationToken = default) { if (_memoryStreams.ContainsKey(Normalize(path))) { @@ -130,17 +117,16 @@ public override async Task Rm(string path, bool recurse, CancellationToken cance } else { - IEnumerable filesToRemove = await Ls(path, recurse, cancellationToken); + IEnumerable filesToRemove = await ListFilesAsync(path, recurse, cancellationToken); foreach (string filePath in filesToRemove) _memoryStreams.Remove(Normalize(filePath), out _); } } - public override void Dispose() + protected override void DisposeManagedResources() { - foreach (Entry stream in _memoryStreams.Select(kvPair => kvPair.Value)) - { + foreach (Entry stream in _memoryStreams.Values) stream.Dispose(); - } + _memoryStreams.Clear(); } } diff --git a/src/SIL.Machine.AspNetCore/Services/LocalStorage.cs b/src/SIL.Machine.AspNetCore/Services/LocalStorage.cs index c07a26034..6826869ee 100644 --- a/src/SIL.Machine.AspNetCore/Services/LocalStorage.cs +++ b/src/SIL.Machine.AspNetCore/Services/LocalStorage.cs @@ -1,100 +1,67 @@ +using static SIL.Machine.AspNetCore.Utils.SharedFileUtils; + namespace SIL.Machine.AspNetCore.Services; -public class LocalStorage : FileStorage +public class LocalStorage : DisposableBase, IFileStorage { - private readonly string _basePath; + private readonly Uri _basePath; public LocalStorage(string basePath) { - _basePath = basePath.EndsWith("/") ? basePath.Remove(basePath.Length - 1, 1) : basePath; - Random r = new Random(Guid.NewGuid().GetHashCode()); - while (Directory.Exists(_basePath + "/")) - { - _basePath += r.Next(); - } - Directory.CreateDirectory(_basePath + "/"); - } - - public override void Dispose() - { - DirectoryHelper.DeleteDirectoryRobust(_basePath + "/"); + _basePath = new Uri(basePath); + if (!_basePath.AbsoluteUri.EndsWith("/")) + _basePath = new Uri(_basePath.AbsoluteUri + "/"); } - public override Task Exists(string path, CancellationToken cancellationToken = default) + public Task ExistsAsync(string path, CancellationToken cancellationToken = default) { - return Task.FromResult(File.Exists(_basePath + Normalize(path))); + Uri pathUri = new(_basePath, Normalize(path)); + return Task.FromResult(File.Exists(pathUri.LocalPath)); } - public override async Task> Ls( + public Task> ListFilesAsync( string path = "", bool recurse = false, CancellationToken cancellationToken = default ) { - if (path.Contains(_basePath)) - path = path.Replace(_basePath, ""); - if (recurse) - { - List files = Directory.GetFiles(_basePath + Normalize(path)).ToList(); - foreach (var subDir in Directory.GetDirectories(_basePath + Normalize(path))) - { - var subFiles = await Ls(subDir, recurse: true); - foreach (var file in subFiles) - files.Add(file); - } - return files; - } - if (Directory.Exists(_basePath + Normalize(path))) - return Directory.GetFiles(_basePath + Normalize(path)); - return new List(); + Uri pathUri = new(_basePath, Normalize(path)); + string[] files = Directory.GetFiles( + pathUri.LocalPath, + "*", + new EnumerationOptions { RecurseSubdirectories = recurse } + ); + return Task.FromResult>( + files.Select(f => _basePath.MakeRelativeUri(new Uri(f)).ToString()).ToArray() + ); } - public override Task OpenRead(string path, CancellationToken cancellationToken) + public Task OpenReadAsync(string path, CancellationToken cancellationToken = default) { - Stream? ret = File.OpenRead(_basePath + Normalize(path)); - if (ret is null) - throw new FileNotFoundException($"Unable to locate file {_basePath + Normalize(path)}"); - return Task.FromResult(ret); + Uri pathUri = new(_basePath, Normalize(path)); + return Task.FromResult(File.OpenRead(pathUri.LocalPath)); } - public override Task OpenWrite(string path, CancellationToken cancellationToken = default) + public Task OpenWriteAsync(string path, CancellationToken cancellationToken = default) { - Stream s; - try - { - s = File.OpenWrite(_basePath + Normalize(path)); - } - catch (IOException) - { - string accumulator = _basePath; - List segments = path.Split("/").ToList(); - foreach (string segment in segments.Take(segments.Count() - 1)) - { - accumulator += Normalize(segment); - if (!Directory.Exists(accumulator)) - { - Directory.CreateDirectory(accumulator); - } - } - s = File.OpenWrite(_basePath + Normalize(path)); - } - return Task.FromResult(s); + Uri pathUri = new(_basePath, Normalize(path)); + Directory.CreateDirectory(Path.GetDirectoryName(pathUri.LocalPath)!); + return Task.FromResult(File.OpenWrite(pathUri.LocalPath)); } - public async override Task Rm(string path, bool recurse, CancellationToken cancellationToken = default) + public async Task DeleteAsync(string path, bool recurse, CancellationToken cancellationToken = default) { - if (path.Contains(_basePath)) - path = path.Replace(_basePath, ""); + Uri pathUri = new(_basePath, Normalize(path)); - if (File.Exists(_basePath + Normalize(path))) + if (File.Exists(pathUri.LocalPath)) { - File.Delete(_basePath + Normalize(path)); + File.Delete(pathUri.LocalPath); } - else + else if (Directory.Exists(pathUri.LocalPath)) { - foreach (string filePath in await Ls(path, recurse, cancellationToken)) + foreach (string filePath in await ListFilesAsync(path, recurse, cancellationToken)) { - await Rm(filePath, false, cancellationToken); + await DeleteAsync(filePath, false, cancellationToken); } } } diff --git a/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs new file mode 100644 index 000000000..c98c5cf19 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs @@ -0,0 +1,71 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class NmtClearMLBuildJobFactory : IClearMLBuildJobFactory +{ + private readonly ISharedFileService _sharedFileService; + private readonly IRepository _engines; + private readonly IOptionsMonitor _options; + + public NmtClearMLBuildJobFactory( + ISharedFileService sharedFileService, + IRepository engines, + IOptionsMonitor options + ) + { + _sharedFileService = sharedFileService; + _engines = engines; + _options = options; + } + + public TranslationEngineType EngineType => TranslationEngineType.Nmt; + + public async Task CreateJobScriptAsync( + string engineId, + string buildId, + string stage, + object? data = null, + CancellationToken cancellationToken = default + ) + { + if (stage == NmtBuildStages.Train) + { + TranslationEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new InvalidOperationException("The engine does not exist."); + + return "from machine.jobs.build_nmt_engine import run\n" + + "args = {\n" + + $" 'model_type': '{_options.CurrentValue.ModelType}',\n" + + $" 'engine_id': '{engineId}',\n" + + $" 'build_id': '{buildId}',\n" + + $" 'src_lang': '{ConvertLanguageTag(engine.SourceLanguage)}',\n" + + $" 'trg_lang': '{ConvertLanguageTag(engine.TargetLanguage)}',\n" + + $" 'max_steps': {_options.CurrentValue.MaxSteps},\n" + + $" 'shared_file_uri': '{_sharedFileService.GetBaseUri()}',\n" + + $" 'clearml': True\n" + + "}\n" + + "run(args)\n"; + } + else + { + throw new ArgumentException("Unknown build stage.", nameof(stage)); + } + } + + private static string ConvertLanguageTag(string languageTag) + { + if ( + !IetfLanguageTag.TryGetSubtags( + languageTag, + out LanguageSubtag languageSubtag, + out ScriptSubtag scriptSubtag, + out _, + out _ + ) + ) + return languageTag; + + // Convert to NLLB language codes + return $"{languageSubtag.Iso3Code}_{scriptSubtag.Code}"; + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs new file mode 100644 index 000000000..4a962520b --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs @@ -0,0 +1,153 @@ +namespace SIL.Machine.AspNetCore.Services; + +public static class NmtBuildStages +{ + public const string Preprocess = "preprocess"; + public const string Train = "train"; + public const string Postprocess = "postprocess"; +} + +public class NmtEngineService : ITranslationEngineService +{ + private readonly IDistributedReaderWriterLockFactory _lockFactory; + private readonly IPlatformService _platformService; + private readonly IDataAccessContext _dataAccessContext; + private readonly IRepository _engines; + private readonly IBuildJobService _buildJobService; + + public NmtEngineService( + IPlatformService platformService, + IDistributedReaderWriterLockFactory lockFactory, + IDataAccessContext dataAccessContext, + IRepository engines, + IBuildJobService buildJobService + ) + { + _lockFactory = lockFactory; + _platformService = platformService; + _dataAccessContext = dataAccessContext; + _engines = engines; + _buildJobService = buildJobService; + } + + public TranslationEngineType Type => TranslationEngineType.Nmt; + + public async Task CreateAsync( + string engineId, + string? engineName, + string sourceLanguage, + string targetLanguage, + CancellationToken cancellationToken = default + ) + { + await _dataAccessContext.BeginTransactionAsync(cancellationToken); + await _engines.InsertAsync( + new TranslationEngine + { + EngineId = engineId, + SourceLanguage = sourceLanguage, + TargetLanguage = targetLanguage + }, + cancellationToken + ); + await _buildJobService.CreateEngineAsync( + new[] { BuildJobType.Cpu, BuildJobType.Gpu }, + engineId, + engineName, + cancellationToken + ); + await _dataAccessContext.CommitTransactionAsync(CancellationToken.None); + } + + public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) + { + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + await CancelBuildJobAsync(engineId, cancellationToken); + + await _engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken); + await _buildJobService.DeleteEngineAsync( + new[] { BuildJobType.Cpu, BuildJobType.Gpu }, + engineId, + CancellationToken.None + ); + } + await _lockFactory.DeleteAsync(engineId, CancellationToken.None); + } + + public async Task StartBuildAsync( + string engineId, + string buildId, + IReadOnlyList corpora, + CancellationToken cancellationToken = default + ) + { + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + // If there is a pending/running build, then no need to start a new one. + if (await _buildJobService.IsEngineBuilding(engineId, cancellationToken)) + throw new InvalidOperationException("The engine has already started a build."); + + await _buildJobService.StartBuildJobAsync( + BuildJobType.Cpu, + TranslationEngineType.Nmt, + engineId, + buildId, + NmtBuildStages.Preprocess, + corpora, + cancellationToken + ); + } + } + + public async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default) + { + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + await CancelBuildJobAsync(engineId, cancellationToken); + } + } + + public Task> TranslateAsync( + string engineId, + int n, + string segment, + CancellationToken cancellationToken = default + ) + { + throw new NotSupportedException(); + } + + public Task GetWordGraphAsync( + string engineId, + string segment, + CancellationToken cancellationToken = default + ) + { + throw new NotSupportedException(); + } + + public Task TrainSegmentPairAsync( + string engineId, + string sourceSegment, + string targetSegment, + bool sentenceStart, + CancellationToken cancellationToken = default + ) + { + throw new NotSupportedException(); + } + + private async Task CancelBuildJobAsync(string engineId, CancellationToken cancellationToken) + { + (string? buildId, BuildJobState jobState) = await _buildJobService.CancelBuildJobAsync( + engineId, + cancellationToken + ); + if (buildId is not null && jobState is BuildJobState.None) + await _platformService.BuildCanceledAsync(buildId, CancellationToken.None); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs new file mode 100644 index 000000000..73168d8bd --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/NmtHangfireBuildJobFactory.cs @@ -0,0 +1,21 @@ +using static SIL.Machine.AspNetCore.Services.HangfireBuildJobRunner; + +namespace SIL.Machine.AspNetCore.Services; + +public class NmtHangfireBuildJobFactory : IHangfireBuildJobFactory +{ + public TranslationEngineType EngineType => TranslationEngineType.Nmt; + + public Job CreateJob(string engineId, string buildId, string stage, object? data) + { + return stage switch + { + NmtBuildStages.Preprocess + => CreateJob>(engineId, buildId, "nmt", data), + NmtBuildStages.Postprocess + => CreateJob(engineId, buildId, "nmt", data), + NmtBuildStages.Train => CreateJob(engineId, buildId, "nmt"), + _ => throw new ArgumentException("Unknown build stage.", nameof(stage)), + }; + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs new file mode 100644 index 000000000..5cbd0fe9a --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/NmtPostprocessBuildJob.cs @@ -0,0 +1,85 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class NmtPostprocessBuildJob : HangfireBuildJob<(int, double)> +{ + private readonly ISharedFileService _sharedFileService; + + public NmtPostprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + ILogger logger, + ISharedFileService sharedFileService + ) + : base(platformService, engines, lockFactory, buildJobService, logger) + { + _sharedFileService = sharedFileService; + } + + protected override async Task DoWorkAsync( + string engineId, + string buildId, + (int, double) data, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + (int corpusSize, double confidence) = data; + + // The NMT job has successfully completed, so insert the generated pretranslations into the database. + await InsertPretranslationsAsync(engineId, buildId, cancellationToken); + + await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) + { + await PlatformService.BuildCompletedAsync( + buildId, + corpusSize, + Math.Round(confidence, 2, MidpointRounding.AwayFromZero), + CancellationToken.None + ); + await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, CancellationToken.None); + } + + Logger.LogInformation("Build completed ({0}).", buildId); + } + + protected override async Task CleanupAsync( + string engineId, + string buildId, + (int, double) data, + IDistributedReaderWriterLock @lock, + JobCompletionStatus completionStatus + ) + { + if (completionStatus is JobCompletionStatus.Restarting) + return; + + try + { + await _sharedFileService.DeleteAsync($"builds/{buildId}/"); + } + catch (Exception e) + { + Logger.LogWarning(e, "Unable to to delete job data for build {0}.", buildId); + } + } + + private async Task InsertPretranslationsAsync(string engineId, string buildId, CancellationToken cancellationToken) + { + await using var targetPretranslateStream = await _sharedFileService.OpenReadAsync( + $"builds/{buildId}/pretranslate.trg.json", + cancellationToken + ); + + IAsyncEnumerable pretranslations = JsonSerializer + .DeserializeAsyncEnumerable( + targetPretranslateStream, + new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }, + cancellationToken + ) + .OfType(); + + await PlatformService.InsertPretranslationsAsync(engineId, pretranslations, cancellationToken); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs new file mode 100644 index 000000000..cbfd7bfef --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs @@ -0,0 +1,157 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class NmtPreprocessBuildJob : HangfireBuildJob> +{ + private readonly ISharedFileService _sharedFileService; + private readonly ICorpusService _corpusService; + + public NmtPreprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + ILogger logger, + IBuildJobService buildJobService, + ISharedFileService sharedFileService, + ICorpusService corpusService + ) + : base(platformService, engines, lockFactory, buildJobService, logger) + { + _sharedFileService = sharedFileService; + _corpusService = corpusService; + } + + protected override async Task DoWorkAsync( + string engineId, + string buildId, + IReadOnlyList data, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + await WriteDataFilesAsync(buildId, data, cancellationToken); + + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + bool canceling = !await BuildJobService.StartBuildJobAsync( + BuildJobType.Gpu, + TranslationEngineType.Nmt, + engineId, + buildId, + NmtBuildStages.Train, + cancellationToken: cancellationToken + ); + if (canceling) + throw new OperationCanceledException(); + } + } + + private async Task WriteDataFilesAsync( + string buildId, + IReadOnlyList corpora, + CancellationToken cancellationToken + ) + { + await using var sourceTrainWriter = new StreamWriter( + await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken) + ); + await using var targetTrainWriter = new StreamWriter( + await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken) + ); + + int corpusSize = 0; + async IAsyncEnumerable ProcessRowsAsync() + { + foreach (Corpus corpus in corpora) + { + ITextCorpus sourceCorpus = _corpusService.CreateTextCorpus(corpus.SourceFiles); + ITextCorpus targetCorpus = _corpusService.CreateTextCorpus(corpus.TargetFiles); + + IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows( + targetCorpus, + allSourceRows: true, + allTargetRows: true + ); + + foreach (ParallelTextRow row in parallelCorpus) + { + await sourceTrainWriter.WriteAsync($"{row.SourceText}\n"); + await targetTrainWriter.WriteAsync($"{row.TargetText}\n"); + if ( + (corpus.PretranslateAll || corpus.PretranslateTextIds.Contains(row.TextId)) + && row.SourceSegment.Count > 0 + && row.TargetSegment.Count == 0 + ) + { + IReadOnlyList refs; + if (row.TargetRefs.Count == 0) + { + if (targetCorpus is ScriptureTextCorpus tstc) + { + refs = row.SourceRefs + .Cast() + .Select(srcRef => + { + var trgRef = srcRef.Clone(); + trgRef.ChangeVersification(tstc.Versification); + return (object)trgRef; + }) + .ToList(); + } + else + { + refs = row.SourceRefs; + } + } + else + { + refs = row.TargetRefs; + } + yield return new Pretranslation + { + CorpusId = corpus.Id, + TextId = row.TextId, + Refs = refs.Select(r => r.ToString()!).ToList(), + Translation = row.SourceText + }; + } + if (!row.IsEmpty) + corpusSize++; + } + } + } + + await using var sourcePretranslateStream = await _sharedFileService.OpenWriteAsync( + $"builds/{buildId}/pretranslate.src.json", + cancellationToken + ); + + await JsonSerializer.SerializeAsync( + sourcePretranslateStream, + ProcessRowsAsync(), + new JsonSerializerOptions { WriteIndented = true, PropertyNamingPolicy = JsonNamingPolicy.CamelCase }, + cancellationToken: cancellationToken + ); + return corpusSize; + } + + protected override async Task CleanupAsync( + string engineId, + string buildId, + IReadOnlyList data, + IDistributedReaderWriterLock @lock, + JobCompletionStatus completionStatus + ) + { + if (completionStatus is JobCompletionStatus.Faulted or JobCompletionStatus.Canceled) + { + try + { + await _sharedFileService.DeleteAsync($"builds/{buildId}/"); + } + catch (Exception e) + { + Logger.LogWarning(e, "Unable to to delete job data for build {0}.", buildId); + } + } + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs new file mode 100644 index 000000000..f75003546 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs @@ -0,0 +1,139 @@ +namespace SIL.Machine.AspNetCore.Services; + +// TODO: The Hangfire implementation of the NMT train stage is not complete, DO NOT USE +// see https://github.com/sillsdev/machine/issues/103 +public class NmtTrainBuildJob : HangfireBuildJob +{ + private readonly ISharedFileService _sharedFileService; + private readonly IOptionsMonitor _options; + + public NmtTrainBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + ILogger logger, + ISharedFileService sharedFileService, + IOptionsMonitor options + ) + : base(platformService, engines, lockFactory, buildJobService, logger) + { + _sharedFileService = sharedFileService; + _options = options; + } + + protected override async Task DoWorkAsync( + string engineId, + string buildId, + object? data, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + TranslationEngine? engine = await Engines.GetAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.BuildId == buildId, + cancellationToken + ); + if (engine is null) + throw new OperationCanceledException(); + + try + { + Installer.LogMessage += Log; + await Installer.SetupPython(); + await Installer.TryInstallPip(); + await PipInstallModuleAsync( + "sil-machine[jobs,huggingface,sentencepiece]", + cancellationToken: cancellationToken + ); + await PipInstallModuleAsync( + "torch", + indexUrl: "https://download.pytorch.org/whl/cu117", + cancellationToken: cancellationToken + ); + await PipInstallModuleAsync("accelerate", cancellationToken: cancellationToken); + + PythonEngine.Initialize(); + + using (Py.GIL()) + { + PythonEngine.Exec( + "from machine.jobs.build_nmt_engine import run\n" + + "args = {\n" + + $" 'model_type': '{_options.CurrentValue.ModelType}',\n" + + $" 'engine_id': '{engineId}',\n" + + $" 'build_id': '{buildId}',\n" + + $" 'src_lang': '{ConvertLanguageTag(engine.SourceLanguage)}',\n" + + $" 'trg_lang': '{ConvertLanguageTag(engine.TargetLanguage)}',\n" + + $" 'max_steps': {_options.CurrentValue.MaxSteps},\n" + + $" 'shared_file_uri': '{_sharedFileService.GetBaseUri()}',\n" + + $" 'clearml': False\n" + + "}\n" + + "run(args)\n" + ); + } + } + finally + { + Installer.LogMessage -= Log; + } + } + + private void Log(string message) + { + Logger.LogInformation(message); + } + + private static string ConvertLanguageTag(string languageTag) + { + if ( + !IetfLanguageTag.TryGetSubtags( + languageTag, + out LanguageSubtag languageSubtag, + out ScriptSubtag scriptSubtag, + out _, + out _ + ) + ) + return languageTag; + + // Convert to NLLB language codes + return $"{languageSubtag.Iso3Code}_{scriptSubtag.Code}"; + } + + public async Task PipInstallModuleAsync( + string module_name, + string version = "", + string indexUrl = "", + bool force = false, + CancellationToken cancellationToken = default + ) + { + try + { + Python.Deployment.Installer.LogMessage += Log; + if (!Installer.IsModuleInstalled(module_name) || force) + { + string text = Path.Combine(Python.Deployment.Installer.EmbeddedPythonHome, "Scripts", "pip"); + string text2 = (force ? " --force-reinstall" : ""); + if (version.Length > 0) + { + version = "==" + version; + } + if (indexUrl.Length > 0) + { + text2 += " --index-url " + indexUrl; + } + + await Python.Deployment.Installer.RunCommand( + text + " install " + module_name + version + " " + text2, + cancellationToken + ); + } + } + finally + { + Python.Deployment.Installer.LogMessage -= Log; + } + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/S3FileStorage.cs b/src/SIL.Machine.AspNetCore/Services/S3FileStorage.cs index 9bf3e2889..3df6c67c1 100644 --- a/src/SIL.Machine.AspNetCore/Services/S3FileStorage.cs +++ b/src/SIL.Machine.AspNetCore/Services/S3FileStorage.cs @@ -1,6 +1,8 @@ +using static SIL.Machine.AspNetCore.Utils.SharedFileUtils; + namespace SIL.Machine.AspNetCore.Services; -public class S3FileStorage : FileStorage +public class S3FileStorage : DisposableBase, IFileStorage { private readonly AmazonS3Client _client; private readonly string _bucketName; @@ -23,20 +25,18 @@ ILoggerFactory loggerFactory ); _bucketName = bucketName; - //Ultimately, object keys can neither begin nor end with slashes; this is what broke the earlier low-level implementation - _basePath = basePath.EndsWith("/") ? basePath.Remove(basePath.Length - 1, 1) : basePath; - _basePath = _basePath.StartsWith("/") ? _basePath.Remove(0, 1) : _basePath; + // Ultimately, object keys can neither begin nor end with slashes; this is what broke the earlier low-level + // implementation + _basePath = Normalize(basePath, includeTrailingSlash: true); _loggerFactory = loggerFactory; } - public override void Dispose() { } - - public override async Task Exists(string path, CancellationToken cancellationToken = default) + public async Task ExistsAsync(string path, CancellationToken cancellationToken = default) { var request = new ListObjectsV2Request { BucketName = _bucketName, - Prefix = _basePath + Normalize(path, includeTrailingSlash: path.EndsWith("/")), + Prefix = _basePath + Normalize(path), MaxKeys = 1 }; @@ -45,7 +45,7 @@ public override async Task Exists(string path, CancellationToken cancellat return response.S3Objects.Any(); } - public override async Task> Ls( + public async Task> ListFilesAsync( string? path = null, bool recurse = false, CancellationToken cancellationToken = default @@ -57,45 +57,43 @@ public override async Task> Ls( var request = new ListObjectsV2Request { BucketName = _bucketName, - Prefix = _basePath + Normalize(path, includeTrailingSlash: true), - MaxKeys = 1, + Prefix = _basePath + (string.IsNullOrEmpty(path) ? "" : Normalize(path, includeTrailingSlash: true)), Delimiter = recurse ? "" : "/" }; ListObjectsV2Response response = await _client.ListObjectsV2Async(request, cancellationToken); - return response.S3Objects.Select(s3Obj => s3Obj.Key).ToList(); + return response.S3Objects.Select(s3Obj => s3Obj.Key[_basePath.Length..]).ToList(); } - public override async Task OpenRead(string path, CancellationToken cancellationToken = default) + public async Task OpenReadAsync(string path, CancellationToken cancellationToken = default) { - string objectId = _basePath + Normalize(path); - GetObjectRequest request = new() { BucketName = _bucketName, Key = objectId }; + GetObjectRequest request = new() { BucketName = _bucketName, Key = _basePath + Normalize(path) }; GetObjectResponse response = await _client.GetObjectAsync(request, cancellationToken); if (response.HttpStatusCode != HttpStatusCode.OK) - throw new FileNotFoundException($"File {objectId} does not exist"); + throw new FileNotFoundException($"File {path} does not exist"); return response.ResponseStream; } - public override async Task OpenWrite(string path, CancellationToken cancellationToken = default) + public async Task OpenWriteAsync(string path, CancellationToken cancellationToken = default) { - string objectId = _basePath + Normalize(path); - InitiateMultipartUploadRequest request = new() { BucketName = _bucketName, Key = objectId }; - InitiateMultipartUploadResponse response = await _client.InitiateMultipartUploadAsync(request); + string fullPath = _basePath + Normalize(path); + InitiateMultipartUploadRequest request = new() { BucketName = _bucketName, Key = fullPath }; + InitiateMultipartUploadResponse response = await _client.InitiateMultipartUploadAsync( + request, + cancellationToken + ); return new BufferedStream( - new S3WriteStream(_client, objectId, _bucketName, response.UploadId, _loggerFactory), + new S3WriteStream(_client, fullPath, _bucketName, response.UploadId, _loggerFactory), S3WriteStream.MaxPartSize ); } - public override async Task Rm(string path, bool recurse = false, CancellationToken cancellationToken = default) + public async Task DeleteAsync(string path, bool recurse = false, CancellationToken cancellationToken = default) { - if (path is null) - throw new ArgumentNullException(nameof(path)); - string objectId = _basePath + Normalize(path); - DeleteObjectRequest request = new() { BucketName = _bucketName, Key = objectId }; + DeleteObjectRequest request = new() { BucketName = _bucketName, Key = _basePath + Normalize(path) }; DeleteObjectResponse response = await _client.DeleteObjectAsync(request, cancellationToken); - if (!response.HttpStatusCode.Equals(HttpStatusCode.OK)) - new HttpRequestException( + if (!response.HttpStatusCode.Equals(HttpStatusCode.NoContent)) + throw new HttpRequestException( $"Received status code {response.HttpStatusCode} when attempting to delete {path}" ); } diff --git a/src/SIL.Machine.AspNetCore/Services/SharedFileService.cs b/src/SIL.Machine.AspNetCore/Services/SharedFileService.cs index cef73bbeb..db349dbaa 100644 --- a/src/SIL.Machine.AspNetCore/Services/SharedFileService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SharedFileService.cs @@ -3,7 +3,7 @@ public class SharedFileService : ISharedFileService { private readonly Uri? _baseUri; - private readonly FileStorage _fileStorage; + private readonly IFileStorage _fileStorage; private readonly bool _supportFolderDelete = true; private readonly ILoggerFactory _loggerFactory; @@ -25,7 +25,6 @@ public SharedFileService(ILoggerFactory loggerFactory, IOptions OpenReadAsync(string path, CancellationToken cancellationToken = default) { - return _fileStorage.OpenRead(path, cancellationToken); + return _fileStorage.OpenReadAsync(path, cancellationToken); } public Task OpenWriteAsync(string path, CancellationToken cancellationToken = default) { - return _fileStorage.OpenWrite(path, cancellationToken); + return _fileStorage.OpenWriteAsync(path, cancellationToken); } public async Task DeleteAsync(string path, CancellationToken cancellationToken = default) { if (!_supportFolderDelete && path.EndsWith("/")) { - IReadOnlyCollection files = await _fileStorage.Ls(path, recurse: true, cancellationToken); + IReadOnlyCollection files = await _fileStorage.ListFilesAsync( + path, + recurse: true, + cancellationToken + ); foreach (string file in files) - await _fileStorage.Rm(file, cancellationToken: cancellationToken); + await _fileStorage.DeleteAsync(file, cancellationToken: cancellationToken); } else { - await _fileStorage.Rm(path, recurse: true, cancellationToken: cancellationToken); + await _fileStorage.DeleteAsync(path, recurse: true, cancellationToken: cancellationToken); } } public Task ExistsAsync(string path, CancellationToken cancellationToken = default) { - return _fileStorage.Exists(path, cancellationToken); - } - - public Task> Ls( - string path, - bool recurse = false, - CancellationToken cancellationToken = default - ) - { - return _fileStorage.Ls(path, recurse, cancellationToken); + return _fileStorage.ExistsAsync(path, cancellationToken); } } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferBuildJob.cs new file mode 100644 index 000000000..7244cc460 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferBuildJob.cs @@ -0,0 +1,126 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class SmtTransferBuildJob : HangfireBuildJob> +{ + private readonly IRepository _trainSegmentPairs; + private readonly ITruecaserFactory _truecaserFactory; + private readonly ISmtModelFactory _smtModelFactory; + private readonly ICorpusService _corpusService; + + public SmtTransferBuildJob( + IPlatformService platformService, + IRepository engines, + IDistributedReaderWriterLockFactory lockFactory, + IBuildJobService buildJobService, + ILogger logger, + IRepository trainSegmentPairs, + ITruecaserFactory truecaserFactory, + ISmtModelFactory smtModelFactory, + ICorpusService corpusService + ) + : base(platformService, engines, lockFactory, buildJobService, logger) + { + _trainSegmentPairs = trainSegmentPairs; + _truecaserFactory = truecaserFactory; + _smtModelFactory = smtModelFactory; + _corpusService = corpusService; + } + + protected override Task InitializeAsync( + string engineId, + string buildId, + IReadOnlyList data, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + return _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, cancellationToken); + } + + protected override async Task DoWorkAsync( + string engineId, + string buildId, + IReadOnlyList data, + IDistributedReaderWriterLock @lock, + CancellationToken cancellationToken + ) + { + await PlatformService.BuildStartedAsync(buildId, cancellationToken); + Logger.LogInformation("Build started ({0})", buildId); + var stopwatch = new Stopwatch(); + stopwatch.Start(); + + cancellationToken.ThrowIfCancellationRequested(); + + var targetCorpora = new List(); + var parallelCorpora = new List(); + foreach (Corpus corpus in data) + { + ITextCorpus sc = _corpusService.CreateTextCorpus(corpus.SourceFiles); + ITextCorpus tc = _corpusService.CreateTextCorpus(corpus.TargetFiles); + + targetCorpora.Add(tc); + parallelCorpora.Add(sc.AlignRows(tc)); + } + + IParallelTextCorpus parallelCorpus = parallelCorpora.Flatten(); + ITextCorpus targetCorpus = targetCorpora.Flatten(); + + var tokenizer = new LatinWordTokenizer(); + var detokenizer = new LatinWordDetokenizer(); + + using ITrainer smtModelTrainer = _smtModelFactory.CreateTrainer(engineId, tokenizer, parallelCorpus); + using ITrainer truecaseTrainer = _truecaserFactory.CreateTrainer(engineId, tokenizer, targetCorpus); + + cancellationToken.ThrowIfCancellationRequested(); + + var progress = new BuildProgress(PlatformService, buildId); + await smtModelTrainer.TrainAsync(progress, cancellationToken); + await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); + + TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new OperationCanceledException(); + + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + { + cancellationToken.ThrowIfCancellationRequested(); + await smtModelTrainer.SaveAsync(CancellationToken.None); + await truecaseTrainer.SaveAsync(CancellationToken.None); + ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId); + IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( + p => p.TranslationEngineRef == engine.Id, + CancellationToken.None + ); + using ( + IInteractiveTranslationModel smtModel = _smtModelFactory.Create( + engineId, + tokenizer, + detokenizer, + truecaser + ) + ) + { + foreach (TrainSegmentPair segmentPair in segmentPairs) + { + await smtModel.TrainSegmentAsync( + segmentPair.Source, + segmentPair.Target, + cancellationToken: CancellationToken.None + ); + } + } + + await PlatformService.BuildCompletedAsync( + buildId, + smtModelTrainer.Stats.TrainCorpusSize + segmentPairs.Count, + smtModelTrainer.Stats.Metrics["bleu"] * 100.0, + CancellationToken.None + ); + await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, CancellationToken.None); + } + + stopwatch.Stop(); + Logger.LogInformation("Build completed in {0}s ({1})", stopwatch.Elapsed.TotalSeconds, buildId); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineBuildJob.cs deleted file mode 100644 index 6ad7ab411..000000000 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineBuildJob.cs +++ /dev/null @@ -1,215 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -public class SmtTransferEngineBuildJob -{ - private readonly IPlatformService _platformService; - private readonly IRepository _engines; - private readonly IRepository _trainSegmentPairs; - private readonly IDistributedReaderWriterLockFactory _lockFactory; - private readonly ITruecaserFactory _truecaserFactory; - private readonly ISmtModelFactory _smtModelFactory; - private readonly ICorpusService _corpusService; - - private readonly ILogger _logger; - - public SmtTransferEngineBuildJob( - IPlatformService platformService, - IRepository engines, - IRepository trainSegmentPairs, - IDistributedReaderWriterLockFactory lockFactory, - ITruecaserFactory truecaserFactory, - ISmtModelFactory smtModelFactory, - ICorpusService corpusService, - ILogger logger - ) - { - _platformService = platformService; - _engines = engines; - _trainSegmentPairs = trainSegmentPairs; - _lockFactory = lockFactory; - _truecaserFactory = truecaserFactory; - _smtModelFactory = smtModelFactory; - _corpusService = corpusService; - _logger = logger; - } - - [Queue("smt_transfer")] - [AutomaticRetry(Attempts = 0)] - public async Task RunAsync( - string engineId, - string buildId, - IReadOnlyList corpora, - CancellationToken cancellationToken - ) - { - IDistributedReaderWriterLock rwLock = await _lockFactory.CreateAsync(engineId, cancellationToken); - var tokenizer = new LatinWordTokenizer(); - var detokenizer = new LatinWordDetokenizer(); - ITrainer? smtModelTrainer = null; - ITrainer? truecaseTrainer = null; - try - { - var stopwatch = new Stopwatch(); - TranslationEngine? engine; - await using (await rwLock.WriterLockAsync(cancellationToken: cancellationToken)) - { - engine = await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId && !e.IsCanceled, - u => u.Set(e => e.BuildState, BuildState.Active), - cancellationToken: cancellationToken - ); - if (engine is null) - throw new OperationCanceledException(); - - await _platformService.BuildStartedAsync(buildId, cancellationToken); - _logger.LogInformation("Build started ({0})", buildId); - stopwatch.Start(); - - await _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, cancellationToken); - - cancellationToken.ThrowIfCancellationRequested(); - - var targetCorpora = new List(); - var parallelCorpora = new List(); - foreach (Corpus corpus in corpora) - { - ITextCorpus sc = _corpusService.CreateTextCorpus(corpus.SourceFiles); - ITextCorpus tc = _corpusService.CreateTextCorpus(corpus.TargetFiles); - - targetCorpora.Add(tc); - parallelCorpora.Add(sc.AlignRows(tc)); - } - - IParallelTextCorpus parallelCorpus = parallelCorpora.Flatten(); - ITextCorpus targetCorpus = targetCorpora.Flatten(); - - smtModelTrainer = _smtModelFactory.CreateTrainer(engineId, tokenizer, parallelCorpus); - truecaseTrainer = _truecaserFactory.CreateTrainer(engineId, tokenizer, targetCorpus); - } - - cancellationToken.ThrowIfCancellationRequested(); - - var progress = new BuildProgress(_platformService, buildId); - await smtModelTrainer.TrainAsync(progress, cancellationToken); - await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); - int trainSegmentPairCount; - await using (await rwLock.WriterLockAsync(cancellationToken: cancellationToken)) - { - cancellationToken.ThrowIfCancellationRequested(); - await smtModelTrainer.SaveAsync(CancellationToken.None); - await truecaseTrainer.SaveAsync(CancellationToken.None); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId); - IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( - p => p.TranslationEngineRef == engine!.Id, - CancellationToken.None - ); - using ( - IInteractiveTranslationModel smtModel = _smtModelFactory.Create( - engineId, - tokenizer, - detokenizer, - truecaser - ) - ) - { - foreach (TrainSegmentPair segmentPair in segmentPairs) - { - await smtModel.TrainSegmentAsync( - segmentPair.Source, - segmentPair.Target, - cancellationToken: CancellationToken.None - ); - } - } - - trainSegmentPairCount = segmentPairs.Count; - - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - u => - u.Set(e => e.BuildState, BuildState.None) - .Set(e => e.IsCanceled, false) - .Inc(e => e.BuildRevision) - .Unset(e => e.JobId) - .Unset(e => e.BuildId), - cancellationToken: CancellationToken.None - ); - } - - await _platformService.BuildCompletedAsync( - buildId, - smtModelTrainer.Stats.TrainCorpusSize + trainSegmentPairCount, - smtModelTrainer.Stats.Metrics["bleu"] * 100.0, - CancellationToken.None - ); - - stopwatch.Stop(); - _logger.LogInformation("Build completed in {0}s ({1})", stopwatch.Elapsed.TotalSeconds, buildId); - } - catch (OperationCanceledException) - { - // Check if the cancellation was initiated by an API call or a shutdown. - TranslationEngine? engine = await _engines.GetAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - CancellationToken.None - ); - if (engine is null || engine.IsCanceled) - { - await using (await rwLock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - u => - u.Set(e => e.BuildState, BuildState.None) - .Set(e => e.IsCanceled, false) - .Unset(e => e.JobId) - .Unset(e => e.BuildId), - cancellationToken: CancellationToken.None - ); - } - await _platformService.BuildCanceledAsync(buildId, CancellationToken.None); - _logger.LogInformation("Build canceled ({0})", buildId); - } - else if (engine is not null) - { - // the build was canceled, because of a server shutdown - // switch state back to pending - await using (await rwLock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId && e.BuildState == BuildState.Active, - u => u.Set(e => e.BuildState, BuildState.Pending), - cancellationToken: CancellationToken.None - ); - } - await _platformService.BuildRestartingAsync(buildId, CancellationToken.None); - } - - throw; - } - catch (Exception e) - { - await using (await rwLock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await _engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - u => - u.Set(e => e.BuildState, BuildState.None) - .Set(e => e.IsCanceled, false) - .Unset(e => e.JobId) - .Unset(e => e.BuildId), - cancellationToken: CancellationToken.None - ); - } - - await _platformService.BuildFaultedAsync(buildId, e.Message, CancellationToken.None); - _logger.LogError(0, e, "Build faulted ({0})", buildId); - throw; - } - finally - { - smtModelTrainer?.Dispose(); - truecaseTrainer?.Dispose(); - } - } -} diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineCommitService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineCommitService.cs index 8b08cb004..de9704264 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineCommitService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineCommitService.cs @@ -1,45 +1,40 @@ namespace SIL.Machine.AspNetCore.Services; -public class SmtTransferEngineCommitService : DisposableBase, IHostedService +public class SmtTransferEngineCommitService : RecurrentTask { - private readonly IServiceProvider _services; private readonly IOptionsMonitor _engineOptions; private readonly SmtTransferEngineStateService _stateService; - private readonly AsyncTimer _commitTimer; + private readonly ILogger _logger; public SmtTransferEngineCommitService( IServiceProvider services, IOptionsMonitor engineOptions, - SmtTransferEngineStateService stateService + SmtTransferEngineStateService stateService, + ILogger logger ) + : base("SMT transfer engine commit service", services, engineOptions.CurrentValue.EngineCommitFrequency, logger) { - _services = services; _engineOptions = engineOptions; _stateService = stateService; - _commitTimer = new AsyncTimer(EngineCommitAsync); + _logger = logger; } - public Task StartAsync(CancellationToken cancellationToken) + protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken cancellationToken) { - _commitTimer.Start(_engineOptions.CurrentValue.EngineCommitFrequency); - return Task.CompletedTask; - } - - public async Task StopAsync(CancellationToken cancellationToken) - { - await _commitTimer.StopAsync(); - } - - private async Task EngineCommitAsync() - { - using IServiceScope scope = _services.CreateScope(); - var engines = scope.ServiceProvider.GetRequiredService>(); - var lockFactory = scope.ServiceProvider.GetRequiredService(); - await _stateService.CommitAsync(lockFactory, engines, _engineOptions.CurrentValue.InactiveEngineTimeout); - } - - protected override void DisposeManagedResources() - { - _commitTimer.Dispose(); + try + { + var engines = scope.ServiceProvider.GetRequiredService>(); + var lockFactory = scope.ServiceProvider.GetRequiredService(); + await _stateService.CommitAsync( + lockFactory, + engines, + _engineOptions.CurrentValue.InactiveEngineTimeout, + cancellationToken + ); + } + catch (Exception e) + { + _logger.LogError(e, "Error occurred while committing SMT transfer engines."); + } } } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs index 2cfccdc7e..14fc20a8c 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs @@ -1,28 +1,42 @@ namespace SIL.Machine.AspNetCore.Services; -public class SmtTransferEngineService : TranslationEngineServiceBase +public static class SmtTransferBuildStages { + public const string Train = "train"; +} + +public class SmtTransferEngineService : ITranslationEngineService +{ + private readonly IDistributedReaderWriterLockFactory _lockFactory; + private readonly IPlatformService _platformService; + private readonly IDataAccessContext _dataAccessContext; + private readonly IRepository _engines; private readonly IRepository _trainSegmentPairs; private readonly SmtTransferEngineStateService _stateService; + private readonly IBuildJobService _buildJobService; public SmtTransferEngineService( - IBackgroundJobClient jobClient, IDistributedReaderWriterLockFactory lockFactory, IPlatformService platformService, IDataAccessContext dataAccessContext, IRepository engines, IRepository trainSegmentPairs, - SmtTransferEngineStateService stateService + SmtTransferEngineStateService stateService, + IBuildJobService buildJobService ) - : base(jobClient, lockFactory, platformService, dataAccessContext, engines) { + _lockFactory = lockFactory; + _platformService = platformService; + _dataAccessContext = dataAccessContext; + _engines = engines; _trainSegmentPairs = trainSegmentPairs; _stateService = stateService; + _buildJobService = buildJobService; } - public override TranslationEngineType Type => TranslationEngineType.SmtTransfer; + public TranslationEngineType Type => TranslationEngineType.SmtTransfer; - public override async Task CreateAsync( + public async Task CreateAsync( string engineId, string? engineName, string sourceLanguage, @@ -30,47 +44,60 @@ public override async Task CreateAsync( CancellationToken cancellationToken = default ) { - await base.CreateAsync(engineId, engineName, sourceLanguage, targetLanguage, cancellationToken); + await _dataAccessContext.BeginTransactionAsync(cancellationToken); + await _engines.InsertAsync( + new TranslationEngine + { + EngineId = engineId, + SourceLanguage = sourceLanguage, + TargetLanguage = targetLanguage + }, + cancellationToken + ); + await _buildJobService.CreateEngineAsync(new[] { BuildJobType.Cpu }, engineId, engineName, cancellationToken); + await _dataAccessContext.CommitTransactionAsync(CancellationToken.None); - SmtTransferEngineState state = _stateService.Get(engineId); - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, CancellationToken.None); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, CancellationToken.None); await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) { + SmtTransferEngineState state = _stateService.Get(engineId); state.InitNew(); } } - public override async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) + public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) { - await base.DeleteAsync(engineId, cancellationToken); - if (_stateService.TryRemove(engineId, out SmtTransferEngineState? state)) + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, CancellationToken.None); - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - // ensure that there is no build running before unloading - string? buildId = await CancelBuildInternalAsync(engineId, CancellationToken.None); - if (buildId is not null) - await WaitForBuildToFinishAsync(engineId, buildId, CancellationToken.None); + await CancelBuildJobAsync(engineId, cancellationToken); + + await _dataAccessContext.BeginTransactionAsync(cancellationToken); + await _engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken); + await _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, cancellationToken); + await _dataAccessContext.CommitTransactionAsync(CancellationToken.None); + if (_stateService.TryRemove(engineId, out SmtTransferEngineState? state)) + { await state.DeleteDataAsync(); await state.DisposeAsync(); } } + await _lockFactory.DeleteAsync(engineId, CancellationToken.None); } - public override async Task> TranslateAsync( + public async Task> TranslateAsync( string engineId, int n, string segment, CancellationToken cancellationToken = default ) { - SmtTransferEngineState state = _stateService.Get(engineId); - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken)) { TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); IReadOnlyList results = await hybridEngine.TranslateAsync(n, segment, cancellationToken); state.LastUsedTime = DateTime.Now; @@ -78,17 +105,17 @@ public override async Task> TranslateAsync( } } - public override async Task GetWordGraphAsync( + public async Task GetWordGraphAsync( string engineId, string segment, CancellationToken cancellationToken = default ) { - SmtTransferEngineState state = _stateService.Get(engineId); - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken)) { TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); WordGraph result = await hybridEngine.GetWordGraphAsync(segment, cancellationToken); state.LastUsedTime = DateTime.Now; @@ -96,7 +123,7 @@ public override async Task GetWordGraphAsync( } } - public override async Task TrainSegmentPairAsync( + public async Task TrainSegmentPairAsync( string engineId, string sourceSegment, string targetSegment, @@ -104,18 +131,17 @@ public override async Task TrainSegmentPairAsync( CancellationToken cancellationToken = default ) { - SmtTransferEngineState state = _stateService.Get(engineId); - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); - if (engine.BuildState is BuildState.Active) + if (engine.CurrentBuild?.JobState is BuildJobState.Active) { - await DataAccessContext.BeginTransactionAsync(cancellationToken); + await _dataAccessContext.BeginTransactionAsync(cancellationToken); await _trainSegmentPairs.InsertAsync( new TrainSegmentPair { - TranslationEngineRef = engine.Id, + TranslationEngineRef = engineId, Source = sourceSegment, Target = targetSegment, SentenceStart = sentenceStart @@ -124,50 +150,79 @@ await _trainSegmentPairs.InsertAsync( ); } + SmtTransferEngineState state = _stateService.Get(engineId); HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); await hybridEngine.TrainSegmentAsync(sourceSegment, targetSegment, sentenceStart, cancellationToken); - await PlatformService.IncrementTrainSizeAsync(engineId, cancellationToken: CancellationToken.None); - if (engine.BuildState is BuildState.Active) - await DataAccessContext.CommitTransactionAsync(CancellationToken.None); + await _platformService.IncrementTrainSizeAsync(engineId, cancellationToken: CancellationToken.None); + if (engine.CurrentBuild?.JobState is BuildJobState.Active) + await _dataAccessContext.CommitTransactionAsync(CancellationToken.None); state.IsUpdated = true; state.LastUsedTime = DateTime.Now; } } - public override async Task StartBuildAsync( + public async Task StartBuildAsync( string engineId, string buildId, IReadOnlyList corpora, CancellationToken cancellationToken = default ) { - SmtTransferEngineState state = _stateService.Get(engineId); - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { - await StartBuildInternalAsync(engineId, buildId, corpora, cancellationToken); + // If there is a pending/running build, then no need to start a new one. + if (await _buildJobService.IsEngineBuilding(engineId, cancellationToken)) + throw new InvalidOperationException("The engine has already started a build."); + + await _buildJobService.StartBuildJobAsync( + BuildJobType.Cpu, + TranslationEngineType.SmtTransfer, + engineId, + buildId, + SmtTransferBuildStages.Train, + corpora, + cancellationToken + ); + SmtTransferEngineState state = _stateService.Get(engineId); state.LastUsedTime = DateTime.UtcNow; } } - public override async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default) + public async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default) { - SmtTransferEngineState state = _stateService.Get(engineId); - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { - await CancelBuildInternalAsync(engineId, cancellationToken); + await CancelBuildJobAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); state.LastUsedTime = DateTime.UtcNow; } } - protected override Expression> GetJobExpression( - string engineId, - string buildId, - IReadOnlyList corpora - ) + private async Task CancelBuildJobAsync(string engineId, CancellationToken cancellationToken) + { + (string? buildId, BuildJobState jobState) = await _buildJobService.CancelBuildJobAsync( + engineId, + cancellationToken + ); + if (buildId is not null && jobState is BuildJobState.None) + await _platformService.BuildCanceledAsync(buildId, CancellationToken.None); + } + + private async Task GetEngineAsync(string engineId, CancellationToken cancellationToken) + { + TranslationEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new InvalidOperationException($"The engine {engineId} does not exist."); + return engine; + } + + private async Task GetBuiltEngineAsync(string engineId, CancellationToken cancellationToken) { - // Token "None" is used here because hangfire injects the proper cancellation token - return r => r.RunAsync(engineId, buildId, corpora, CancellationToken.None); + TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); + if (engine.BuildRevision == 0) + throw new EngineNotBuiltException("The engine must be built first."); + return engine; } } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineState.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineState.cs index 28abc4b23..072dacf45 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineState.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineState.cs @@ -71,7 +71,11 @@ public async Task DeleteDataAsync() _truecaserFactory.Cleanup(EngineId); } - public async Task CommitAsync(int buildRevision, TimeSpan inactiveTimeout) + public async Task CommitAsync( + int buildRevision, + TimeSpan inactiveTimeout, + CancellationToken cancellationToken = default + ) { if (_hybridEngine is null) return; @@ -80,34 +84,34 @@ public async Task CommitAsync(int buildRevision, TimeSpan inactiveTimeout) CurrentBuildRevision = buildRevision; if (buildRevision != CurrentBuildRevision) { - await UnloadAsync(); + await UnloadAsync(cancellationToken); CurrentBuildRevision = buildRevision; } else if (DateTime.Now - LastUsedTime > inactiveTimeout) { - await UnloadAsync(); + await UnloadAsync(cancellationToken); } else { - await SaveModelAsync(); + await SaveModelAsync(cancellationToken); } } - private async Task SaveModelAsync() + private async Task SaveModelAsync(CancellationToken cancellationToken = default) { if (_smtModel is not null && IsUpdated) { - await _smtModel.SaveAsync(); + await _smtModel.SaveAsync(cancellationToken); IsUpdated = false; } } - private async Task UnloadAsync() + private async Task UnloadAsync(CancellationToken cancellationToken = default) { if (_hybridEngine is null) return; - await SaveModelAsync(); + await SaveModelAsync(cancellationToken); _hybridEngine.Dispose(); diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineStateService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineStateService.cs index f00e3d472..cb0429d44 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineStateService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineStateService.cs @@ -33,17 +33,26 @@ public bool TryRemove(string engineId, [MaybeNullWhen(false)] out SmtTransferEng public async Task CommitAsync( IDistributedReaderWriterLockFactory lockFactory, IRepository engines, - TimeSpan inactiveTimeout + TimeSpan inactiveTimeout, + CancellationToken cancellationToken = default ) { foreach (SmtTransferEngineState state in _engineStates.Values) { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(state.EngineId); - await using (await @lock.WriterLockAsync()) + IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(state.EngineId, cancellationToken); + await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) { - TranslationEngine? engine = await engines.GetAsync(e => e.EngineId == state.EngineId); - if (engine is not null && engine.BuildState is not BuildState.Active) - await state.CommitAsync(engine.BuildRevision, inactiveTimeout); + TranslationEngine? engine = await engines.GetAsync( + e => e.EngineId == state.EngineId, + cancellationToken + ); + if ( + engine is not null + && (engine.CurrentBuild is null || engine.CurrentBuild.JobState is BuildJobState.Pending) + ) + { + await state.CommitAsync(engine.BuildRevision, inactiveTimeout, cancellationToken); + } } } } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs new file mode 100644 index 000000000..533184140 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferHangfireBuildJobFactory.cs @@ -0,0 +1,18 @@ +using static SIL.Machine.AspNetCore.Services.HangfireBuildJobRunner; + +namespace SIL.Machine.AspNetCore.Services; + +public class SmtTransferHangfireBuildJobFactory : IHangfireBuildJobFactory +{ + public TranslationEngineType EngineType => TranslationEngineType.SmtTransfer; + + public Job CreateJob(string engineId, string buildId, string stage, object? data) + { + return stage switch + { + SmtTransferBuildStages.Train + => CreateJob>(engineId, buildId, "smt_transfer", data), + _ => throw new ArgumentException("Unknown build stage.", nameof(stage)), + }; + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/TranslationEngineServiceBase.cs b/src/SIL.Machine.AspNetCore/Services/TranslationEngineServiceBase.cs deleted file mode 100644 index bb11a4395..000000000 --- a/src/SIL.Machine.AspNetCore/Services/TranslationEngineServiceBase.cs +++ /dev/null @@ -1,228 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -public abstract class TranslationEngineServiceBase : ITranslationEngineService -{ - private readonly IBackgroundJobClient _jobClient; - - protected TranslationEngineServiceBase( - IBackgroundJobClient jobClient, - IDistributedReaderWriterLockFactory lockFactory, - IPlatformService platformService, - IDataAccessContext dataAccessContext, - IRepository engines - ) - { - _jobClient = jobClient; - LockFactory = lockFactory; - PlatformService = platformService; - DataAccessContext = dataAccessContext; - Engines = engines; - } - - protected IRepository Engines { get; } - protected IDistributedReaderWriterLockFactory LockFactory { get; } - protected IPlatformService PlatformService { get; } - protected IDataAccessContext DataAccessContext { get; } - - public abstract TranslationEngineType Type { get; } - - public virtual async Task CreateAsync( - string engineId, - string? engineName, - string sourceLanguage, - string targetLanguage, - CancellationToken cancellationToken = default - ) - { - await Engines.InsertAsync( - new TranslationEngine - { - EngineId = engineId, - SourceLanguage = sourceLanguage, - TargetLanguage = targetLanguage - }, - cancellationToken - ); - } - - public virtual async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) - { - await DataAccessContext.BeginTransactionAsync(cancellationToken); - await Engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken); - await LockFactory.DeleteAsync(engineId, cancellationToken); - await DataAccessContext.CommitTransactionAsync(CancellationToken.None); - } - - public virtual async Task StartBuildAsync( - string engineId, - string buildId, - IReadOnlyList corpora, - CancellationToken cancellationToken = default - ) - { - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await StartBuildInternalAsync(engineId, buildId, corpora, cancellationToken); - } - } - - public virtual async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default) - { - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await CancelBuildInternalAsync(engineId, cancellationToken); - } - } - - public virtual Task> TranslateAsync( - string engineId, - int n, - string segment, - CancellationToken cancellationToken = default - ) - { - throw new NotSupportedException(); - } - - public virtual Task GetWordGraphAsync( - string engineId, - string segment, - CancellationToken cancellationToken = default - ) - { - throw new NotSupportedException(); - } - - public virtual Task TrainSegmentPairAsync( - string engineId, - string sourceSegment, - string targetSegment, - bool sentenceStart, - CancellationToken cancellationToken = default - ) - { - throw new NotSupportedException(); - } - - protected abstract Expression> GetJobExpression( - string engineId, - string buildId, - IReadOnlyList corpora - ); - - protected async Task StartBuildInternalAsync( - string engineId, - string buildId, - IReadOnlyList corpora, - CancellationToken cancellationToken - ) - { - // If there is a pending job, then no need to start a new one. - if ( - await Engines.ExistsAsync( - e => - e.EngineId == engineId && (e.BuildState == BuildState.Pending || e.BuildState == BuildState.Active), - cancellationToken - ) - ) - throw new InvalidOperationException("Engine is already building or pending."); - - // Schedule the job to occur way in the future, just so we can get the job id. - string jobId = _jobClient.Schedule(GetJobExpression(engineId, buildId, corpora), TimeSpan.FromDays(10000)); - try - { - await Engines.UpdateAsync( - e => e.EngineId == engineId, - u => - u.Set(e => e.BuildState, BuildState.Pending) - .Set(e => e.IsCanceled, false) - .Set(e => e.JobId, jobId) - .Set(e => e.BuildId, buildId), - cancellationToken: CancellationToken.None - ); - // Enqueue the job now that the build has been created. - _jobClient.Requeue(jobId); - } - catch - { - _jobClient.Delete(jobId); - throw; - } - } - - protected async Task CancelBuildInternalAsync(string engineId, CancellationToken cancellationToken) - { - await DataAccessContext.BeginTransactionAsync(cancellationToken); - // First, try to cancel a job that hasn't started yet - TranslationEngine? engine = await Engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildState == BuildState.Pending, - u => u.Set(b => b.BuildState, BuildState.None).Set(e => e.IsCanceled, true), - cancellationToken: cancellationToken - ); - bool notifyPlatform = false; - if (engine is not null) - { - notifyPlatform = true; - } - else - { - // Second, try to cancel a job that is already running - engine = await Engines.UpdateAsync( - e => e.EngineId == engineId && e.BuildState == BuildState.Active, - u => u.Set(b => b.IsCanceled, true), - cancellationToken: cancellationToken - ); - } - if (engine is not null) - { - // If pending, the job will be deleted from the queue, otherwise this will trigger the cancellation token - _jobClient.Delete(engine.JobId); - if (notifyPlatform) - await PlatformService.BuildCanceledAsync(engine.BuildId!, CancellationToken.None); - } - await DataAccessContext.CommitTransactionAsync(CancellationToken.None); - return engine?.BuildId; - } - - protected async Task WaitForBuildToFinishAsync( - string engineId, - string buildId, - CancellationToken cancellationToken - ) - { - using ISubscription sub = await Engines.SubscribeAsync( - e => e.EngineId == engineId && e.BuildId == buildId, - cancellationToken - ); - if (sub.Change.Entity is null) - return true; - - var timeout = DateTime.UtcNow + TimeSpan.FromSeconds(20); - while (DateTime.UtcNow < timeout) - { - await sub.WaitForChangeAsync(TimeSpan.FromSeconds(2), cancellationToken); - TranslationEngine? engine = sub.Change.Entity; - if (engine is null || engine.BuildState is BuildState.None) - return true; - } - return false; - } - - protected async Task GetEngineAsync(string engineId, CancellationToken cancellationToken) - { - TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); - if (engine is null) - throw new InvalidOperationException($"Engine with id {engineId} does not exist"); - return engine; - } - - protected async Task GetBuiltEngineAsync(string engineId, CancellationToken cancellationToken) - { - TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); - if (engine.BuildState != BuildState.None || engine.BuildRevision == 0) - throw new EngineNotBuiltException("The engine must be built first"); - return engine; - } -} diff --git a/src/SIL.Machine.AspNetCore/Usings.cs b/src/SIL.Machine.AspNetCore/Usings.cs index 825a0b2f4..ea5c6e984 100644 --- a/src/SIL.Machine.AspNetCore/Usings.cs +++ b/src/SIL.Machine.AspNetCore/Usings.cs @@ -1,4 +1,3 @@ -global using System.Collections; global using System.Collections.Concurrent; global using System.Diagnostics; global using System.Diagnostics.CodeAnalysis; @@ -7,20 +6,23 @@ global using System.Net; global using System.Reflection; global using System.Runtime.CompilerServices; +global using System.Security.Cryptography; global using System.Text; global using System.Text.Json; global using System.Text.Json.Nodes; global using Amazon; +global using Amazon.Runtime; global using Amazon.S3; global using Amazon.S3.Model; -global using Amazon.Runtime; global using Grpc.Core; global using Grpc.Core.Interceptors; global using Grpc.Net.Client.Configuration; global using Hangfire; +global using Hangfire.Common; global using Hangfire.Mongo; global using Hangfire.Mongo.Migration.Strategies; global using Hangfire.Mongo.Migration.Strategies.Backup; +global using Hangfire.States; global using Microsoft.AspNetCore.Routing; global using Microsoft.Extensions.Configuration; global using Microsoft.Extensions.DependencyInjection; @@ -33,6 +35,8 @@ global using Nito.AsyncEx; global using Nito.AsyncEx.Synchronous; global using Polly; +global using Python.Included; +global using Python.Runtime; global using SIL.DataAccess; global using SIL.Machine.AspNetCore.Configuration; global using SIL.Machine.AspNetCore.Models; @@ -45,5 +49,5 @@ global using SIL.Machine.Translation.Thot; global using SIL.Machine.Utils; global using SIL.ObjectModel; -global using SIL.WritingSystems; global using SIL.Scripture; +global using SIL.WritingSystems; diff --git a/src/SIL.Machine.AspNetCore/Utils/RecurrentTask.cs b/src/SIL.Machine.AspNetCore/Utils/RecurrentTask.cs new file mode 100644 index 000000000..8c91d92f8 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Utils/RecurrentTask.cs @@ -0,0 +1,49 @@ +namespace SIL.Machine.AspNetCore.Utils; + +public abstract class RecurrentTask : BackgroundService +{ + private readonly bool _enable; + private readonly string _serviceName; + private readonly IServiceProvider _services; + private readonly TimeSpan _period; + private readonly ILogger _logger; + + protected RecurrentTask( + string serviceName, + IServiceProvider services, + TimeSpan period, + ILogger logger, + bool enable = true + ) + { + _enable = enable; + _serviceName = serviceName; + _services = services; + _period = period; + _logger = logger; + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + if (!_enable) + return; + + using PeriodicTimer timer = new(_period); + + _logger.LogInformation($"{_serviceName} started."); + + try + { + while (await timer.WaitForNextTickAsync(stoppingToken)) + { + using IServiceScope scope = _services.CreateScope(); + await DoWorkAsync(scope, stoppingToken); + } + } + catch (OperationCanceledException) { } + + _logger.LogInformation($"{_serviceName} stopped."); + } + + protected abstract Task DoWorkAsync(IServiceScope scope, CancellationToken cancellationToken); +} diff --git a/src/SIL.Machine.AspNetCore/Utils/SharedFileUtils.cs b/src/SIL.Machine.AspNetCore/Utils/SharedFileUtils.cs new file mode 100644 index 000000000..78521f458 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Utils/SharedFileUtils.cs @@ -0,0 +1,28 @@ +namespace SIL.Machine.AspNetCore.Utils; + +public static class SharedFileUtils +{ + public static string Normalize(string path, bool includeLeadingSlash = false, bool includeTrailingSlash = false) + { + string normalizedPath = path; + if (normalizedPath == "/") + return normalizedPath; + if (!includeLeadingSlash && normalizedPath.StartsWith("/")) + { + normalizedPath = normalizedPath.Remove(0, 1); + } + else if (includeLeadingSlash && !normalizedPath.StartsWith("/")) + { + normalizedPath = "/" + normalizedPath; + } + if (!includeTrailingSlash && normalizedPath.EndsWith("/")) + { + normalizedPath = normalizedPath.Remove(normalizedPath.Length - 1, 1); + } + else if (includeTrailingSlash && !normalizedPath.EndsWith("/")) + { + normalizedPath += "/"; + } + return normalizedPath; + } +} diff --git a/src/SIL.Machine.Serval.EngineServer/Program.cs b/src/SIL.Machine.Serval.EngineServer/Program.cs index fed98a249..0cb22049e 100644 --- a/src/SIL.Machine.Serval.EngineServer/Program.cs +++ b/src/SIL.Machine.Serval.EngineServer/Program.cs @@ -7,8 +7,9 @@ builder.Services .AddMachine(builder.Configuration) .AddMongoDataAccess() - .AddMongoBackgroundJobClient() - .AddServalTranslationEngineService(); + .AddMongoHangfireJobClient() + .AddServalTranslationEngineService() + .AddBuildJobService(); if (builder.Environment.IsDevelopment()) builder.Services .AddOpenTelemetry() diff --git a/src/SIL.Machine.Serval.EngineServer/appsettings.Development.json b/src/SIL.Machine.Serval.EngineServer/appsettings.Development.json index 52a414564..fa0d0a515 100644 --- a/src/SIL.Machine.Serval.EngineServer/appsettings.Development.json +++ b/src/SIL.Machine.Serval.EngineServer/appsettings.Development.json @@ -4,13 +4,12 @@ "Mongo": "mongodb://localhost:27017/machine", "Serval": "https://localhost:8444" }, - "TranslationEngines": [ - "SmtTransfer", - "Nmt" - ], - "ClearMLNmtEngine": { - "ApiServer": "http://localhost:8008", - "Queue": "default" + "ClearML": { + "Queue": "jobs_backlog", + "MaxSteps": 1000 + }, + "SharedFile": { + "Uri": "s3://aqua-ml-data/dev/" }, "Logging": { "LogLevel": { diff --git a/src/SIL.Machine.Serval.EngineServer/appsettings.json b/src/SIL.Machine.Serval.EngineServer/appsettings.json index 4c066ba27..c9c085ed3 100644 --- a/src/SIL.Machine.Serval.EngineServer/appsettings.json +++ b/src/SIL.Machine.Serval.EngineServer/appsettings.json @@ -3,9 +3,23 @@ "Service": { "ServiceId": "machine_engine" }, + "TranslationEngines": [ + "SmtTransfer", + "Nmt" + ], + "BuildJob": { + "Runners": { + "Cpu": "Hangfire", + "Gpu": "ClearML" + } + }, "SmtTransferEngine": { "EnginesDir": "/var/lib/machine/engines" }, + "ClearML": { + "ApiServer": "https://api.sil.hosted.allegro.ai", + "BuildPollingEnabled": true + }, "Logging": { "LogLevel": { "System.Net.Http.HttpClient.Default": "Warning" diff --git a/src/SIL.Machine.Serval.JobServer/Program.cs b/src/SIL.Machine.Serval.JobServer/Program.cs index b93c36584..0e9ceb94d 100644 --- a/src/SIL.Machine.Serval.JobServer/Program.cs +++ b/src/SIL.Machine.Serval.JobServer/Program.cs @@ -5,9 +5,10 @@ builder.Services .AddMachine(builder.Configuration) .AddMongoDataAccess() - .AddMongoBackgroundJobClient() - .AddBackgroundJobServer() - .AddServalPlatformService(); + .AddMongoHangfireJobClient() + .AddHangfireJobServer() + .AddServalPlatformService() + .AddBuildJobService(); if (builder.Environment.IsDevelopment()) builder.Services .AddOpenTelemetry() diff --git a/src/SIL.Machine.Serval.JobServer/appsettings.Development.json b/src/SIL.Machine.Serval.JobServer/appsettings.Development.json index 03e0111dd..fa0d0a515 100644 --- a/src/SIL.Machine.Serval.JobServer/appsettings.Development.json +++ b/src/SIL.Machine.Serval.JobServer/appsettings.Development.json @@ -4,17 +4,12 @@ "Mongo": "mongodb://localhost:27017/machine", "Serval": "https://localhost:8444" }, - "TranslationEngines": [ - "SmtTransfer", - "Nmt" - ], - "ClearMLNmtEngine": { - "ApiServer": "http://localhost:8008", - "Queue": "default", - "DockerImage": "ghcr.io/sillsdev/machine.py:0.9.3.2" + "ClearML": { + "Queue": "jobs_backlog", + "MaxSteps": 1000 }, "SharedFile": { - "Uri": "s3://aqua-ml-data/" + "Uri": "s3://aqua-ml-data/dev/" }, "Logging": { "LogLevel": { diff --git a/src/SIL.Machine.Serval.JobServer/appsettings.json b/src/SIL.Machine.Serval.JobServer/appsettings.json index db8395d9e..007927ce5 100644 --- a/src/SIL.Machine.Serval.JobServer/appsettings.json +++ b/src/SIL.Machine.Serval.JobServer/appsettings.json @@ -3,9 +3,23 @@ "Service": { "ServiceId": "machine_job" }, + "TranslationEngines": [ + "SmtTransfer", + "Nmt" + ], + "BuildJob": { + "Runners": { + "Cpu": "Hangfire", + "Gpu": "ClearML" + } + }, "SmtTransferEngine": { "EnginesDir": "/var/lib/machine/engines" }, + "ClearML": { + "ApiServer": "https://api.sil.hosted.allegro.ai", + "BuildPollingEnabled": false + }, "Logging": { "LogLevel": { "System.Net.Http.HttpClient.Default": "Warning" diff --git a/src/SIL.Machine/Utils/TempDirectory.cs b/src/SIL.Machine/Utils/TempDirectory.cs index 031c66eac..d495ad5aa 100644 --- a/src/SIL.Machine/Utils/TempDirectory.cs +++ b/src/SIL.Machine/Utils/TempDirectory.cs @@ -1,7 +1,5 @@ -using SIL.ObjectModel; -using System; -using System.IO; -using System.Threading.Tasks; +using System.IO; +using SIL.ObjectModel; namespace SIL.Machine.Utils { diff --git a/tests/SIL.Machine.AspNetCore.Tests/SIL.Machine.AspNetCore.Tests.csproj b/tests/SIL.Machine.AspNetCore.Tests/SIL.Machine.AspNetCore.Tests.csproj index f3aa98ed9..f5db8989e 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/SIL.Machine.AspNetCore.Tests.csproj +++ b/tests/SIL.Machine.AspNetCore.Tests/SIL.Machine.AspNetCore.Tests.csproj @@ -9,11 +9,11 @@ + runtime; build; native; contentfiles; analyzers; buildtransitive all - diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLNmtEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLNmtEngineServiceTests.cs deleted file mode 100644 index 614dabc52..000000000 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLNmtEngineServiceTests.cs +++ /dev/null @@ -1,192 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -[TestFixture] -public class ClearMLNmtEngineServiceTests -{ - [Test] - public async Task CancelBuildAsync() - { - using var env = new TestEnvironment(); - env.ClearMLService - .CreateTaskAsync( - Arg.Any(), - "project1", - "engine1", - "es", - "en", - "memory:///", - Arg.Any() - ) - .Returns(Task.FromResult("task1")); - var task = new ClearMLTask - { - Id = "task1", - Project = new ClearMLProject { Id = "project1" }, - Status = ClearMLTaskStatus.InProgress - }; - bool first = true; - env.ClearMLService - .GetTaskByNameAsync(Arg.Any(), Arg.Any()) - .Returns(x => - { - if (first) - { - first = false; - return Task.FromResult(null); - } - return Task.FromResult(task); - }); - env.ClearMLService - .GetTaskByIdAsync("task1", Arg.Any()) - .Returns(Task.FromResult(task)); - await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); - await env.WaitForBuildToStartAsync(); - TranslationEngine engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildState, Is.EqualTo(BuildState.Active)); - await env.Service.CancelBuildAsync("engine1"); - await env.WaitForBuildToFinishAsync(); - engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildState, Is.EqualTo(BuildState.None)); - await env.ClearMLService.Received().StopTaskAsync("task1", Arg.Any()); - } - - private class TestEnvironment : DisposableBase - { - private readonly MemoryStorage _memoryStorage; - private readonly BackgroundJobClient _jobClient; - private BackgroundJobServer _jobServer; - private readonly IDistributedReaderWriterLockFactory _lockFactory; - private readonly ISharedFileService _sharedFileService; - private readonly IOptionsMonitor _options; - - public TestEnvironment() - { - Engines = new MemoryRepository(); - Engines.Add( - new TranslationEngine - { - Id = "engine1", - EngineId = "engine1", - SourceLanguage = "es", - TargetLanguage = "en" - } - ); - EngineOptions = new SmtTransferEngineOptions(); - _memoryStorage = new MemoryStorage(); - _jobClient = new BackgroundJobClient(_memoryStorage); - PlatformService = Substitute.For(); - ClearMLService = Substitute.For(); - ClearMLService - .GetProjectIdAsync(Arg.Any(), Arg.Any()) - .Returns(Task.FromResult("project1")); - _lockFactory = new DistributedReaderWriterLockFactory( - new OptionsWrapper(new ServiceOptions { ServiceId = "host" }), - new MemoryRepository(), - new ObjectIdGenerator() - ); - _sharedFileService = new SharedFileService(Substitute.For()); - _options = Substitute.For>(); - _options.CurrentValue.Returns( - new ClearMLNmtEngineOptions { BuildPollingTimeout = TimeSpan.FromMilliseconds(50) } - ); - _jobServer = CreateJobServer(); - Service = CreateService(); - } - - public ClearMLNmtEngineService Service { get; private set; } - public MemoryRepository Engines { get; } - public SmtTransferEngineOptions EngineOptions { get; } - public IPlatformService PlatformService { get; } - public IClearMLService ClearMLService { get; } - - public void StopServer() - { - _jobServer.Dispose(); - } - - public void StartServer() - { - _jobServer = CreateJobServer(); - Service = CreateService(); - } - - private BackgroundJobServer CreateJobServer() - { - var jobServerOptions = new BackgroundJobServerOptions - { - Activator = new EnvActivator(this), - Queues = new[] { "nmt" }, - CancellationCheckInterval = TimeSpan.FromMilliseconds(100), - }; - return new BackgroundJobServer(jobServerOptions, _memoryStorage); - } - - private ClearMLNmtEngineService CreateService() - { - return new ClearMLNmtEngineService( - _jobClient, - PlatformService, - _lockFactory, - new MemoryDataAccessContext(), - Engines, - ClearMLService - ); - } - - public Task WaitForBuildToFinishAsync() - { - return WaitForBuildState(e => e.BuildState is BuildState.None); - } - - public Task WaitForBuildToStartAsync() - { - return WaitForBuildState(e => e.BuildState is BuildState.Active); - } - - private async Task WaitForBuildState(Func predicate) - { - using ISubscription subscription = await Engines.SubscribeAsync( - e => e.EngineId == "engine1" - ); - while (true) - { - TranslationEngine? build = subscription.Change.Entity; - if (build is not null && predicate(build)) - break; - await subscription.WaitForChangeAsync(); - } - } - - protected override void DisposeManagedResources() - { - _jobServer.Dispose(); - } - - private class EnvActivator : JobActivator - { - private readonly TestEnvironment _env; - - public EnvActivator(TestEnvironment env) - { - _env = env; - } - - public override object ActivateJob(Type jobType) - { - if (jobType == typeof(ClearMLNmtEngineBuildJob)) - { - return new ClearMLNmtEngineBuildJob( - _env.PlatformService, - _env.Engines, - Substitute.For>(), - _env.ClearMLService, - _env._sharedFileService, - _env._options, - Substitute.For() - ); - } - return base.ActivateJob(jobType); - } - } - } -} diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLServiceTests.cs index 0fcdcb317..84fc02bcc 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/ClearMLServiceTests.cs @@ -3,7 +3,7 @@ [TestFixture] public class ClearMLServiceTests { - private const string ApiServier = "https://clearml.com"; + private const string ApiServer = "https://clearml.com"; private const string AccessKey = "accessKey"; private const string SecretKey = "secretKey"; @@ -12,38 +12,40 @@ public async Task CreateTaskAsync() { var mockHttp = new MockHttpMessageHandler(); mockHttp - .Expect(HttpMethod.Post, $"{ApiServier}/tasks.create") + .Expect(HttpMethod.Post, $"{ApiServer}/tasks.create") .WithHeaders("Authorization", $"Bearer accessToken") .WithPartialContent("\\u0027src_lang\\u0027: \\u0027spa_Latn\\u0027") .WithPartialContent("\\u0027trg_lang\\u0027: \\u0027eng_Latn\\u0027") .Respond("application/json", "{ \"data\": { \"id\": \"projectId\" } }"); - var options = Substitute.For>(); + var options = Substitute.For>(); options.CurrentValue.Returns( - new ClearMLNmtEngineOptions + new ClearMLOptions { - ApiServer = ApiServier, + ApiServer = ApiServer, AccessKey = AccessKey, SecretKey = SecretKey } ); var authService = Substitute.For(); authService.GetAuthTokenAsync().Returns(Task.FromResult("accessToken")); - var service = new ClearMLService( - mockHttp.ToHttpClient(), - options, - Substitute.For>(), - authService - ); + var service = new ClearMLService(mockHttp.ToHttpClient(), options, authService); - string projectId = await service.CreateTaskAsync( - "build1", - "project1", - "engine1", - "es", - "en", - "s3://aqua-ml-data" - ); + string script = + "from machine.jobs.build_nmt_engine import run\n" + + "args = {\n" + + " 'model_type': 'huggingface',\n" + + " 'engine_id': 'engine1',\n" + + " 'build_id': 'build1',\n" + + " 'src_lang': 'spa_Latn',\n" + + " 'trg_lang': 'eng_Latn',\n" + + " 'max_steps': 20000,\n" + + " 'shared_file_uri': 's3://aqua-ml-data',\n" + + " 'clearml': True\n" + + "}\n" + + "run(args)\n"; + + string projectId = await service.CreateTaskAsync("build1", "project1", script); Assert.That(projectId, Is.EqualTo("projectId")); mockHttp.VerifyNoOutstandingExpectation(); } diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/FileStorageTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/FileStorageTests.cs deleted file mode 100644 index b3e4991db..000000000 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/FileStorageTests.cs +++ /dev/null @@ -1,178 +0,0 @@ -namespace SIL.Machine.AspNetCore.Services; - -[TestFixture] -public class FileStorageTests -{ - [Test] - public async Task ExistsFileInMemoryAsync() - { - using InMemoryStorage fs = new InMemoryStorage(); - Stream ws = await fs.OpenWrite("file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - bool exists = await fs.Exists("file1"); - Assert.True(exists); - } - - [Test] - public async Task CreateFileReadFileInMemoryAsync() - { - using InMemoryStorage fs = new InMemoryStorage(); - Stream ws = await fs.OpenWrite("file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - Stream rs = await fs.OpenRead("file1"); - StreamReader sr = new(rs); - string? output = sr.ReadLine(); - sr.Dispose(); - Assert.That(input, Is.EqualTo(output), $"{input} | {output}"); - } - - [Test] - public async Task CreateFilesListFilesRecursiveInMemoryAsync() - { - using InMemoryStorage fs = new InMemoryStorage(); - Stream ws = await fs.OpenWrite("test/file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - ws = await fs.OpenWrite("test/test/file2"); - sw = new(ws); - string input2 = "Hola"; - sw.WriteLine(input2); - sw.Dispose(); - var files = await fs.Ls("test", recurse: true); - Assert.That(files.Count, Is.EqualTo(2)); - } - - [Test] - public async Task CreateFilesListFilesNotRecursiveInMemoryAsync() - { - using InMemoryStorage fs = new InMemoryStorage(); - Stream ws = await fs.OpenWrite("test/file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - ws = await fs.OpenWrite("test/test/file2"); - sw = new(ws); - string input2 = "Hola"; - sw.WriteLine(input2); - sw.Dispose(); - var files = await fs.Ls("test", recurse: false); - Assert.That(files.Count, Is.EqualTo(1)); - } - - [Test] - public async Task CreateFileRemoveFileInMemoryAsync() - { - using InMemoryStorage fs = new InMemoryStorage(); - Stream ws = await fs.OpenWrite("test/file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - ws = await fs.OpenWrite("test/test/file2"); - sw = new(ws); - string input2 = "Hola"; - sw.WriteLine(input2); - sw.Dispose(); - await fs.Rm("test", recurse: true); - var files = await fs.Ls("test", recurse: true); - Assert.That(files.Count, Is.EqualTo(0)); - } - - [Test] - public async Task ExistsFileLocalAsync() - { - var tmpDir = new TempDirectory("test"); - using FileStorage fs = new LocalStorage(tmpDir.Path); - Stream ws = await fs.OpenWrite("file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - bool exists = await fs.Exists("file1"); - Assert.True(exists); - } - - [Test] - public async Task CreateFileReadFileLocalAsync() - { - var tmpDir = new TempDirectory("test"); - using FileStorage fs = new LocalStorage(tmpDir.Path); - Stream ws = await fs.OpenWrite("file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - Stream rs = await fs.OpenRead("file1"); - StreamReader sr = new(rs); - string? output = sr.ReadLine(); - sr.Dispose(); - Assert.That(input, Is.EqualTo(output), $"{input} | {output}"); - } - - [Test] - public async Task CreateFilesListFilesRecursiveLocalAsync() - { - var tmpDir = new TempDirectory("test"); - using FileStorage fs = new LocalStorage(tmpDir.Path); - Stream ws = await fs.OpenWrite("test/file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - ws = await fs.OpenWrite("test/test/file2"); - sw = new(ws); - string input2 = "Hola"; - sw.WriteLine(input2); - sw.Dispose(); - var files = await fs.Ls("test", recurse: true); - Assert.That(files.Count, Is.EqualTo(2)); - } - - [Test] - public async Task CreateFilesListFilesNotRecursiveLocalAsync() - { - var tmpDir = new TempDirectory("test"); - using FileStorage fs = new LocalStorage(tmpDir.Path); - Stream ws = await fs.OpenWrite("test/file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - ws = await fs.OpenWrite("test/test/file2"); - sw = new(ws); - string input2 = "Hola"; - sw.WriteLine(input2); - sw.Dispose(); - var files = await fs.Ls("test", recurse: false); - Assert.That(files.Count, Is.EqualTo(1)); - } - - [Test] - public async Task CreateFileRemoveFileLocalAsync() - { - var tmpDir = new TempDirectory("test"); - using FileStorage fs = new LocalStorage(tmpDir.Path); - Stream ws = await fs.OpenWrite("test/file1"); - StreamWriter sw = new(ws); - string input = "Hello"; - sw.WriteLine(input); - sw.Dispose(); - ws = await fs.OpenWrite("test/test/file2"); - sw = new(ws); - string input2 = "Hola"; - sw.WriteLine(input2); - sw.Dispose(); - await fs.Rm("test", recurse: true); - var files = await fs.Ls("test", recurse: true); - Assert.That(files.Count, Is.EqualTo(0)); - } -} diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/InMemoryStorageTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/InMemoryStorageTests.cs new file mode 100644 index 000000000..3b5052865 --- /dev/null +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/InMemoryStorageTests.cs @@ -0,0 +1,91 @@ +namespace SIL.Machine.AspNetCore.Services; + +[TestFixture] +public class InMemoryStorageTests +{ + [Test] + public async Task ExistsAsync() + { + using InMemoryStorage fs = new(); + using (StreamWriter sw = new(await fs.OpenWriteAsync("file1"))) + { + string input = "Hello"; + sw.WriteLine(input); + } + bool exists = await fs.ExistsAsync("file1"); + Assert.That(exists, Is.True); + } + + [Test] + public async Task OpenReadAsync() + { + using InMemoryStorage fs = new(); + string input; + using (StreamWriter sw = new(await fs.OpenWriteAsync("file1"))) + { + input = "Hello"; + sw.WriteLine(input); + } + string? output; + using (StreamReader sr = new(await fs.OpenReadAsync("file1"))) + { + output = sr.ReadLine(); + } + Assert.That(input, Is.EqualTo(output), $"{input} | {output}"); + } + + [Test] + public async Task ListFilesAsync_Recurse() + { + using InMemoryStorage fs = new(); + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/file1"))) + { + string input = "Hello"; + sw.WriteLine(input); + } + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/test/file2"))) + { + string input2 = "Hola"; + sw.WriteLine(input2); + } + IReadOnlyCollection files = await fs.ListFilesAsync("test", recurse: true); + Assert.That(files, Is.EquivalentTo(new[] { "test/file1", "test/test/file2" })); + } + + [Test] + public async Task ListFilesAsync_DoNotRecurse() + { + using InMemoryStorage fs = new(); + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/file1"))) + { + string input = "Hello"; + sw.WriteLine(input); + } + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/test/file2"))) + { + string input2 = "Hola"; + sw.WriteLine(input2); + } + IReadOnlyCollection files = await fs.ListFilesAsync("test", recurse: false); + Assert.That(files, Is.EquivalentTo(new[] { "test/file1" })); + } + + [Test] + public async Task DeleteAsync() + { + using InMemoryStorage fs = new(); + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/file1"))) + { + string input = "Hello"; + sw.WriteLine(input); + } + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/test/file2"))) + { + string input2 = "Hola"; + sw.WriteLine(input2); + } + await fs.DeleteAsync("test", recurse: true); + var files = await fs.ListFilesAsync("test", recurse: true); + Assert.That(files, Is.Empty); + } +} diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/LocalStorageTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/LocalStorageTests.cs new file mode 100644 index 000000000..280a54bb1 --- /dev/null +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/LocalStorageTests.cs @@ -0,0 +1,96 @@ +namespace SIL.Machine.AspNetCore.Services; + +[TestFixture] +public class LocalStorageTests +{ + [Test] + public async Task ExistsAsync() + { + using var tmpDir = new TempDirectory("test"); + using LocalStorage fs = new(tmpDir.Path); + using (StreamWriter sw = new(await fs.OpenWriteAsync("file1"))) + { + string input = "Hello"; + sw.WriteLine(input); + } + bool exists = await fs.ExistsAsync("file1"); + Assert.That(exists, Is.True); + } + + [Test] + public async Task OpenReadAsync() + { + using var tmpDir = new TempDirectory("test"); + using LocalStorage fs = new(tmpDir.Path); + string input; + using (StreamWriter sw = new(await fs.OpenWriteAsync("file1"))) + { + input = "Hello"; + sw.WriteLine(input); + } + string? output; + using (StreamReader sr = new(await fs.OpenReadAsync("file1"))) + { + output = sr.ReadLine(); + } + Assert.That(input, Is.EqualTo(output), $"{input} | {output}"); + } + + [Test] + public async Task ListFilesAsync_Recurse() + { + using var tmpDir = new TempDirectory("test"); + using LocalStorage fs = new(tmpDir.Path); + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/file1"))) + { + string input = "Hello"; + sw.WriteLine(input); + } + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/test/file2"))) + { + string input2 = "Hola"; + sw.WriteLine(input2); + } + IReadOnlyCollection files = await fs.ListFilesAsync("test", recurse: true); + Assert.That(files, Is.EquivalentTo(new[] { "test/file1", "test/test/file2" })); + } + + [Test] + public async Task ListFilesAsync_DoNotRecurse() + { + using var tmpDir = new TempDirectory("test"); + using LocalStorage fs = new(tmpDir.Path); + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/file1"))) + { + string input = "Hello"; + sw.WriteLine(input); + } + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/test/file2"))) + { + string input2 = "Hola"; + sw.WriteLine(input2); + } + IReadOnlyCollection files = await fs.ListFilesAsync("test", recurse: false); + Assert.That(files, Is.EquivalentTo(new[] { "test/file1" })); + } + + [Test] + public async Task DeleteFileAsync() + { + using var tmpDir = new TempDirectory("test"); + using LocalStorage fs = new(tmpDir.Path); + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/file1"))) + { + string input = "Hello"; + sw.WriteLine(input); + } + using (StreamWriter sw = new(await fs.OpenWriteAsync("test/test/file2"))) + { + string input2 = "Hola"; + sw.WriteLine(input2); + } + await fs.DeleteAsync("test", recurse: true); + IReadOnlyCollection files = await fs.ListFilesAsync("test", recurse: true); + Assert.That(files, Is.Empty); + } +} diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs new file mode 100644 index 000000000..320d56ad5 --- /dev/null +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs @@ -0,0 +1,272 @@ +namespace SIL.Machine.AspNetCore.Services; + +[TestFixture] +public class NmtEngineServiceTests +{ + [Test] + public async Task StartBuildAsync() + { + using var env = new TestEnvironment(); + TranslationEngine engine = env.Engines.Get("engine1"); + Assert.That(engine.BuildRevision, Is.EqualTo(1)); + await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); + await env.WaitForBuildToFinishAsync(); + engine = env.Engines.Get("engine1"); + Assert.That(engine.CurrentBuild, Is.Null); + Assert.That(engine.BuildRevision, Is.EqualTo(2)); + } + + [Test] + public async Task CancelBuildAsync() + { + using var env = new TestEnvironment(); + + var cts = new CancellationTokenSource(); + env.ClearMLService.When(x => x.StopTaskAsync("job1", Arg.Any())).Do(_ => cts.Cancel()); + env.TrainJobFunc = async () => + { + await env.BuildJobService.BuildJobStartedAsync("engine1", "build1"); + + while (!cts.IsCancellationRequested) + await Task.Delay(50); + + await env.BuildJobService.BuildJobFinishedAsync("engine1", "build1", buildComplete: false); + }; + + TranslationEngine engine = env.Engines.Get("engine1"); + Assert.That(engine.BuildRevision, Is.EqualTo(1)); + await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); + await env.WaitForBuildToStartAsync(); + engine = env.Engines.Get("engine1"); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); + await env.Service.CancelBuildAsync("engine1"); + await env.WaitForBuildToFinishAsync(); + engine = env.Engines.Get("engine1"); + Assert.That(engine.CurrentBuild, Is.Null); + Assert.That(engine.BuildRevision, Is.EqualTo(1)); + } + + [Test] + public async Task DeleteAsync_WhileBuilding() + { + using var env = new TestEnvironment(); + + var cts = new CancellationTokenSource(); + env.ClearMLService.When(x => x.StopTaskAsync("job1", Arg.Any())).Do(_ => cts.Cancel()); + env.TrainJobFunc = async () => + { + await env.BuildJobService.BuildJobStartedAsync("engine1", "build1"); + + while (!cts.IsCancellationRequested) + await Task.Delay(50); + + await env.BuildJobService.BuildJobFinishedAsync("engine1", "build1", buildComplete: false); + }; + + TranslationEngine engine = env.Engines.Get("engine1"); + Assert.That(engine.BuildRevision, Is.EqualTo(1)); + await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); + await env.WaitForBuildToStartAsync(); + engine = env.Engines.Get("engine1"); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); + await env.Service.DeleteAsync("engine1"); + // ensure that the train job has completed + if (env.TrainJobTask is not null) + await env.TrainJobTask; + Assert.That(env.Engines.Contains("engine1"), Is.False); + } + + private class TestEnvironment : DisposableBase + { + private readonly Hangfire.InMemory.InMemoryStorage _memoryStorage; + private readonly BackgroundJobClient _jobClient; + private BackgroundJobServer _jobServer; + private readonly IDistributedReaderWriterLockFactory _lockFactory; + + public TestEnvironment() + { + if (!Sldr.IsInitialized) + Sldr.Initialize(offlineMode: true); + + TrainJobFunc = RunMockTrainJob; + Engines = new MemoryRepository(); + Engines.Add( + new TranslationEngine + { + Id = "engine1", + EngineId = "engine1", + SourceLanguage = "es", + TargetLanguage = "en", + BuildRevision = 1 + } + ); + _memoryStorage = new Hangfire.InMemory.InMemoryStorage(); + _jobClient = new BackgroundJobClient(_memoryStorage); + PlatformService = Substitute.For(); + _lockFactory = new DistributedReaderWriterLockFactory( + new OptionsWrapper(new ServiceOptions { ServiceId = "host" }), + new MemoryRepository(), + new ObjectIdGenerator() + ); + ClearMLService = Substitute.For(); + ClearMLService + .GetProjectIdAsync("engine1", Arg.Any()) + .Returns(Task.FromResult("project1")); + ClearMLService + .CreateTaskAsync("build1", "project1", Arg.Any(), Arg.Any()) + .Returns(Task.FromResult("job1")); + ClearMLService + .When(x => x.EnqueueTaskAsync("job1", Arg.Any())) + .Do(_ => TrainJobTask = Task.Run(TrainJobFunc)); + SharedFileService = new SharedFileService(Substitute.For()); + var clearMLOptions = Substitute.For>(); + clearMLOptions.CurrentValue.Returns(new ClearMLOptions()); + BuildJobService = new BuildJobService( + new IBuildJobRunner[] + { + new HangfireBuildJobRunner(_jobClient, new[] { new NmtHangfireBuildJobFactory() }), + new ClearMLBuildJobRunner( + ClearMLService, + new[] { new NmtClearMLBuildJobFactory(SharedFileService, Engines, clearMLOptions) } + ) + }, + Engines, + new OptionsWrapper(new BuildJobOptions()) + ); + _jobServer = CreateJobServer(); + Service = CreateService(); + } + + public NmtEngineService Service { get; private set; } + public MemoryRepository Engines { get; } + public IPlatformService PlatformService { get; } + public IClearMLService ClearMLService { get; } + public ISharedFileService SharedFileService { get; } + public IBuildJobService BuildJobService { get; } + public Func TrainJobFunc { get; set; } + public Task? TrainJobTask { get; private set; } + + public void StopServer() + { + _jobServer.Dispose(); + } + + public void StartServer() + { + _jobServer = CreateJobServer(); + Service = CreateService(); + } + + private BackgroundJobServer CreateJobServer() + { + var jobServerOptions = new BackgroundJobServerOptions + { + Activator = new EnvActivator(this), + Queues = new[] { "nmt" }, + CancellationCheckInterval = TimeSpan.FromMilliseconds(50), + }; + return new BackgroundJobServer(jobServerOptions, _memoryStorage); + } + + private NmtEngineService CreateService() + { + return new NmtEngineService( + PlatformService, + _lockFactory, + new MemoryDataAccessContext(), + Engines, + BuildJobService + ); + } + + public Task WaitForBuildToFinishAsync() + { + return WaitForBuildState(e => e.CurrentBuild is null); + } + + public Task WaitForBuildToStartAsync() + { + return WaitForBuildState( + e => e.CurrentBuild!.JobState is BuildJobState.Active && e.CurrentBuild!.Stage == NmtBuildStages.Train + ); + } + + private async Task WaitForBuildState(Func predicate) + { + using ISubscription subscription = await Engines.SubscribeAsync( + e => e.EngineId == "engine1" + ); + while (true) + { + TranslationEngine? engine = subscription.Change.Entity; + if (engine is not null && predicate(engine)) + break; + await subscription.WaitForChangeAsync(); + } + } + + private async Task RunMockTrainJob() + { + await BuildJobService.BuildJobStartedAsync("engine1", "build1"); + + await using (var stream = await SharedFileService.OpenWriteAsync("builds/build1/pretranslate.trg.json")) + { + await JsonSerializer.SerializeAsync(stream, Array.Empty()); + } + + await BuildJobService.StartBuildJobAsync( + BuildJobType.Cpu, + TranslationEngineType.Nmt, + "engine1", + "build1", + NmtBuildStages.Postprocess, + (0, 0.0) + ); + } + + protected override void DisposeManagedResources() + { + _jobServer.Dispose(); + } + + private class EnvActivator : JobActivator + { + private readonly TestEnvironment _env; + + public EnvActivator(TestEnvironment env) + { + _env = env; + } + + public override object ActivateJob(Type jobType) + { + if (jobType == typeof(NmtPreprocessBuildJob)) + { + return new NmtPreprocessBuildJob( + _env.PlatformService, + _env.Engines, + _env._lockFactory, + Substitute.For>(), + _env.BuildJobService, + _env.SharedFileService, + Substitute.For() + ); + } + if (jobType == typeof(NmtPostprocessBuildJob)) + { + return new NmtPostprocessBuildJob( + _env.PlatformService, + _env.Engines, + _env._lockFactory, + _env.BuildJobService, + Substitute.For>(), + _env.SharedFileService + ); + } + return base.ActivateJob(jobType); + } + } + } +} diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs index 7c69e79e4..1db5aa04f 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs @@ -1,16 +1,27 @@ -using SIL.Machine.Tokenization; - -namespace SIL.Machine.AspNetCore.Services; +namespace SIL.Machine.AspNetCore.Services; [TestFixture] public class SmtTransferEngineServiceTests { + [Test] + public async Task CreateAsync() + { + using var env = new TestEnvironment(); + await env.Service.CreateAsync("engine2", "Engine 2", "es", "en"); + TranslationEngine? engine = await env.Engines.GetAsync(e => e.EngineId == "engine2"); + Assert.That(engine, Is.Not.Null); + Assert.That(engine.EngineId, Is.EqualTo("engine2")); + Assert.That(engine.BuildRevision, Is.EqualTo(0)); + env.SmtModelFactory.Received().InitNew("engine2"); + env.TransferEngineFactory.Received().InitNew("engine2"); + } + [Test] public async Task StartBuildAsync() { using var env = new TestEnvironment(); TranslationEngine engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildRevision, Is.EqualTo(1)); //For testing purposes BuildRevision is set to 1 (i.e., an already built engine) + Assert.That(engine.BuildRevision, Is.EqualTo(1)); // ensure that the SMT model was loaded before training await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba."); await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); @@ -24,8 +35,8 @@ await env.TruecaserTrainer await env.SmtBatchTrainer.Received().SaveAsync(Arg.Any()); await env.TruecaserTrainer.Received().SaveAsync(Arg.Any()); engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildState, Is.EqualTo(BuildState.None)); - Assert.That(engine.BuildRevision, Is.EqualTo(2)); //For testing purposes BuildRevision was initially set to 1 (i.e., an already built engine), so now it ought to be 2 + Assert.That(engine.CurrentBuild, Is.Null); + Assert.That(engine.BuildRevision, Is.EqualTo(2)); // check if SMT model was reloaded upon first use after training env.SmtModel.ClearReceivedCalls(); await env.Service.TranslateAsync("engine1", n: 1, "esto es una prueba."); @@ -43,19 +54,23 @@ await env.SmtBatchTrainer.TrainAsync( Arg.Do(ct => { while (true) + { ct.ThrowIfCancellationRequested(); + Thread.Sleep(100); + } }) ); await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); await env.WaitForBuildToStartAsync(); TranslationEngine engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildState, Is.EqualTo(BuildState.Active)); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); await env.Service.CancelBuildAsync("engine1"); await env.WaitForBuildToFinishAsync(); await env.SmtBatchTrainer.DidNotReceive().SaveAsync(); await env.TruecaserTrainer.DidNotReceive().SaveAsync(); engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildState, Is.EqualTo(BuildState.None)); + Assert.That(engine.CurrentBuild, Is.Null); } [Test] @@ -63,30 +78,90 @@ public async Task StartBuildAsync_RestartUnfinishedBuild() { using var env = new TestEnvironment(); - env.SmtBatchTrainer - .WhenForAnyArgs(t => t.TrainAsync(null, default)) - .Do(ci => + await env.SmtBatchTrainer.TrainAsync( + Arg.Any>(), + Arg.Do(ct => { - CancellationToken ct = ci.ArgAt(1); while (true) { ct.ThrowIfCancellationRequested(); Thread.Sleep(100); } - }); + }) + ); await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); await env.WaitForBuildToStartAsync(); TranslationEngine engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildState, Is.EqualTo(BuildState.Active)); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); env.StopServer(); + await env.WaitForBuildToRestartAsync(); engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildState, Is.EqualTo(BuildState.Pending)); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Pending)); await env.PlatformService.Received().BuildRestartingAsync("build1"); env.SmtBatchTrainer.ClearSubstitute(ClearOptions.CallActions); env.StartServer(); await env.WaitForBuildToFinishAsync(); engine = env.Engines.Get("engine1"); - Assert.That(engine.BuildState, Is.EqualTo(BuildState.None)); + Assert.That(engine.CurrentBuild, Is.Null); + } + + [Test] + public async Task DeleteAsync_WhileBuilding() + { + using var env = new TestEnvironment(); + await env.SmtBatchTrainer.TrainAsync( + Arg.Any>(), + Arg.Do(ct => + { + while (true) + { + ct.ThrowIfCancellationRequested(); + Thread.Sleep(100); + } + }) + ); + await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); + await env.WaitForBuildToStartAsync(); + TranslationEngine engine = env.Engines.Get("engine1"); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); + await env.Service.DeleteAsync("engine1"); + // ensure that the build job was canceled + await env.WaitForAllHangfireJobsToFinishAsync(); + await env.SmtBatchTrainer.DidNotReceive().SaveAsync(); + await env.TruecaserTrainer.DidNotReceive().SaveAsync(); + Assert.That(env.Engines.Contains("engine1"), Is.False); + } + + [Test] + public async Task TrainSegmentPairAsync() + { + using var env = new TestEnvironment(); + bool training = true; + await env.SmtBatchTrainer.TrainAsync( + Arg.Any>(), + Arg.Do(ct => + { + while (training) + { + ct.ThrowIfCancellationRequested(); + Thread.Sleep(100); + } + }) + ); + await env.Service.StartBuildAsync("engine1", "build1", Array.Empty()); + await env.WaitForBuildToStartAsync(); + TranslationEngine engine = env.Engines.Get("engine1"); + Assert.That(engine.CurrentBuild, Is.Not.Null); + Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); + await env.Service.TrainSegmentPairAsync("engine1", "esto es una prueba.", "this is a test.", true); + training = false; + await env.WaitForBuildToFinishAsync(); + engine = env.Engines.Get("engine1"); + Assert.That(engine.CurrentBuild, Is.Null); + await env.SmtModel.Received(2).TrainSegmentAsync("esto es una prueba.", "this is a test.", true); } [Test] @@ -118,15 +193,25 @@ public async Task TranslateAsync() Assert.That(result.Translation, Is.EqualTo("this is a TEST.")); } + [Test] + public async Task GetWordGraphAsync() + { + using var env = new TestEnvironment(); + WordGraph result = await env.Service.GetWordGraphAsync("engine1", "esto es una prueba."); + Assert.That( + result.Arcs.Select(a => string.Join(' ', a.TargetTokens)), + Is.EqualTo(new[] { "this is", "a test", "." }) + ); + } + private class TestEnvironment : DisposableBase { - private readonly MemoryStorage _memoryStorage; + private readonly Hangfire.InMemory.InMemoryStorage _memoryStorage; private readonly BackgroundJobClient _jobClient; private BackgroundJobServer _jobServer; - private readonly ISmtModelFactory _smtModelFactory; - private readonly ITransferEngineFactory _transferEngineFactory; private readonly ITruecaserFactory _truecaserFactory; private readonly IDistributedReaderWriterLockFactory _lockFactory; + private readonly IBuildJobService _buildJobService; public TestEnvironment() { @@ -138,12 +223,11 @@ public TestEnvironment() EngineId = "engine1", SourceLanguage = "es", TargetLanguage = "en", - BuildRevision = 1, - BuildState = BuildState.None, + BuildRevision = 1 } ); TrainSegmentPairs = new MemoryRepository(); - _memoryStorage = new MemoryStorage(); + _memoryStorage = new Hangfire.InMemory.InMemoryStorage(); _jobClient = new BackgroundJobClient(_memoryStorage); PlatformService = Substitute.For(); SmtModel = Substitute.For(); @@ -152,14 +236,27 @@ public TestEnvironment() Truecaser = Substitute.For(); TruecaserTrainer = Substitute.For(); TruecaserTrainer.SaveAsync().Returns(Task.CompletedTask); - _smtModelFactory = CreateSmtModelFactory(); - _transferEngineFactory = CreateTransferEngineFactory(); + SmtModelFactory = CreateSmtModelFactory(); + TransferEngineFactory = CreateTransferEngineFactory(); _truecaserFactory = CreateTruecaserFactory(); _lockFactory = new DistributedReaderWriterLockFactory( new OptionsWrapper(new ServiceOptions { ServiceId = "host" }), new MemoryRepository(), new ObjectIdGenerator() ); + _buildJobService = new BuildJobService( + new[] { new HangfireBuildJobRunner(_jobClient, new[] { new SmtTransferHangfireBuildJobFactory() }) }, + Engines, + new OptionsWrapper( + new BuildJobOptions + { + Runners = new Dictionary + { + { BuildJobType.Cpu, BuildJobRunner.Hangfire } + } + } + ) + ); _jobServer = CreateJobServer(); StateService = CreateStateService(); Service = CreateService(); @@ -169,6 +266,8 @@ public TestEnvironment() public SmtTransferEngineStateService StateService { get; private set; } public MemoryRepository Engines { get; } public MemoryRepository TrainSegmentPairs { get; } + public ISmtModelFactory SmtModelFactory { get; } + public ITransferEngineFactory TransferEngineFactory { get; } public ITrainer SmtBatchTrainer { get; } public IInteractiveTranslationModel SmtModel { get; } public ITruecaser Truecaser { get; } @@ -206,19 +305,19 @@ private BackgroundJobServer CreateJobServer() private SmtTransferEngineStateService CreateStateService() { - return new SmtTransferEngineStateService(_smtModelFactory, _transferEngineFactory, _truecaserFactory); + return new SmtTransferEngineStateService(SmtModelFactory, TransferEngineFactory, _truecaserFactory); } private SmtTransferEngineService CreateService() { return new SmtTransferEngineService( - _jobClient, _lockFactory, PlatformService, new MemoryDataAccessContext(), Engines, TrainSegmentPairs, - StateService + StateService, + _buildJobService ); } @@ -376,14 +475,26 @@ private static IEnumerable GetSources(int count, bool isUnkn return sources; } + public async Task WaitForAllHangfireJobsToFinishAsync() + { + IMonitoringApi monitoringApi = _memoryStorage.GetMonitoringApi(); + while (monitoringApi.EnqueuedCount("smt_transfer") > 0 || monitoringApi.ProcessingCount() > 0) + await Task.Delay(50); + } + public Task WaitForBuildToFinishAsync() { - return WaitForBuildState(e => e.BuildState is BuildState.None); + return WaitForBuildState(e => e.CurrentBuild is null); } public Task WaitForBuildToStartAsync() { - return WaitForBuildState(e => e.BuildState is BuildState.Active); + return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Active); + } + + public Task WaitForBuildToRestartAsync() + { + return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Pending); } private async Task WaitForBuildState(Func predicate) @@ -393,8 +504,8 @@ private async Task WaitForBuildState(Func predicate) ); while (true) { - TranslationEngine? build = subscription.Change.Entity; - if (build is not null && predicate(build)) + TranslationEngine? engine = subscription.Change.Entity; + if (engine is not null && predicate(engine)) break; await subscription.WaitForChangeAsync(); } @@ -417,17 +528,18 @@ public EnvActivator(TestEnvironment env) public override object ActivateJob(Type jobType) { - if (jobType == typeof(SmtTransferEngineBuildJob)) + if (jobType == typeof(SmtTransferBuildJob)) { - return new SmtTransferEngineBuildJob( + return new SmtTransferBuildJob( _env.PlatformService, _env.Engines, - _env.TrainSegmentPairs, _env._lockFactory, + _env._buildJobService, + Substitute.For>(), + _env.TrainSegmentPairs, _env._truecaserFactory, - _env._smtModelFactory, - Substitute.For(), - Substitute.For>() + _env.SmtModelFactory, + Substitute.For() ); } return base.ActivateJob(jobType); diff --git a/tests/SIL.Machine.AspNetCore.Tests/Usings.cs b/tests/SIL.Machine.AspNetCore.Tests/Usings.cs index 6f2621852..462b51ad5 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Usings.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Usings.cs @@ -1,5 +1,6 @@ -global using Hangfire; -global using Hangfire.MemoryStorage; +global using System.Text.Json; +global using Hangfire; +global using Hangfire.Storage; global using Microsoft.Extensions.Logging; global using Microsoft.Extensions.Options; global using NSubstitute; @@ -12,6 +13,8 @@ global using SIL.Machine.AspNetCore.Configuration; global using SIL.Machine.AspNetCore.Models; global using SIL.Machine.Corpora; +global using SIL.Machine.Tokenization; global using SIL.Machine.Translation; global using SIL.Machine.Utils; global using SIL.ObjectModel; +global using SIL.WritingSystems;