From 457507edc06dc985c8305a51bb75a9be19ed9158 Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 6 Oct 2021 21:12:13 -0700 Subject: [PATCH] Address PR comments (2) - declare a proper type for generation arguments that user can override when calling the genienlp parser - some refactoring for better reading --- lib/prediction/localparserclient.ts | 9 ++++---- lib/prediction/predictor.ts | 11 +++++----- lib/prediction/remoteparserclient.ts | 19 ++++++++++------- lib/prediction/types.ts | 20 ++++++++++++++++++ lib/utils/misc-utils.ts | 15 +++++++------- tool/server.ts | 31 +++++----------------------- 6 files changed, 55 insertions(+), 50 deletions(-) diff --git a/lib/prediction/localparserclient.ts b/lib/prediction/localparserclient.ts index 36b3df007..08e54d60f 100644 --- a/lib/prediction/localparserclient.ts +++ b/lib/prediction/localparserclient.ts @@ -35,7 +35,8 @@ import { PredictionCandidate, PredictionResult, GenerationResult, - ExactMatcher + ExactMatcher, + GENERATION_OPTIONS } from './types'; const SEMANTIC_PARSING_TASK = 'almond'; @@ -264,9 +265,9 @@ export default class LocalParserClient { }); } - async translateUtterance(input : string[], contextEntities : EntityMap|undefined, translationOptions : Record) : Promise { - input = Utils.qpisEntities(input, contextEntities); - const candidates = await this._predictor.predict('', input.join(' '), undefined, TRANSLATION_TASK, 'id-null', translationOptions); + async translateUtterance(input : string[], entities : string[]|undefined, generationOptions : GENERATION_OPTIONS) : Promise { + input = Utils.qpisEntities(input, entities); + const candidates = await this._predictor.predict('', input.join(' '), undefined, TRANSLATION_TASK, 'id-null', generationOptions); return candidates.map((cand) => { return { answer: cand.answer, diff --git a/lib/prediction/predictor.ts b/lib/prediction/predictor.ts index c46f6c638..2197502a7 100644 --- a/lib/prediction/predictor.ts +++ b/lib/prediction/predictor.ts @@ -24,6 +24,7 @@ import * as child_process from 'child_process'; import * as Tp from 'thingpedia'; import JsonDatagramSocket from '../utils/json_datagram_socket'; +import {GENERATION_OPTIONS} from "./types"; const DEFAULT_QUESTION = 'translate from english to thingtalk'; @@ -151,7 +152,7 @@ class LocalWorker extends events.EventEmitter { this._requests.clear(); } - request(task : string, minibatch : Example[], options : Record) : Promise { + request(task : string, minibatch : Example[], options : GENERATION_OPTIONS) : Promise { const id = this._nextId ++; return new Promise((resolve, reject) => { @@ -179,7 +180,7 @@ class RemoteWorker extends events.EventEmitter { start() {} stop() {} - async request(task : string, minibatch : Example[], options : Record) : Promise { + async request(task : string, minibatch : Example[], options : GENERATION_OPTIONS) : Promise { const response = await Tp.Helpers.Http.post(this._url, JSON.stringify({ task, instances: minibatch, @@ -244,7 +245,7 @@ export default class Predictor { }); } - private _startRequest(ex : Example, task : string, options : Record, now : number) { + private _startRequest(ex : Example, task : string, options : GENERATION_OPTIONS, now : number) { assert(this._minibatch.length === 0); this._minibatch.push(ex); this._minibatchTask = task; @@ -257,7 +258,7 @@ export default class Predictor { }, this._maxLatency); } - private _addRequest(ex : Example, task : string, options : Record) { + private _addRequest(ex : Example, task : string, options : GENERATION_OPTIONS) { const now = Date.now(); if (this._minibatch.length === 0) { this._startRequest(ex, task, options, now); @@ -271,7 +272,7 @@ export default class Predictor { } } - predict(context : string, question : string = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string, options : Record = {}) : Promise { + predict(context : string, question : string = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string, options : GENERATION_OPTIONS = {}) : Promise { // ensure we have a worker, in case it recently died if (!this._worker) diff --git a/lib/prediction/remoteparserclient.ts b/lib/prediction/remoteparserclient.ts index bab93ee4d..299b82178 100644 --- a/lib/prediction/remoteparserclient.ts +++ b/lib/prediction/remoteparserclient.ts @@ -30,7 +30,7 @@ import { PredictionCandidate, PredictionResult, GenerationResult, - ExactMatcher, + ExactMatcher, GENERATION_OPTIONS, } from './types'; import ExactMatcherBuilder from './exactbuilder'; @@ -170,16 +170,19 @@ export default class RemoteParserClient { return parsed.candidates; } - async translateUtterance(input : string[], contextEntities : EntityMap|undefined, translationOptions : Record) : Promise { - input = Utils.qpisEntities(input, contextEntities); + async translateUtterance(input : string[], entities : string[]|undefined, generationOptions : GENERATION_OPTIONS) : Promise { + input = Utils.qpisEntities(input, entities); const data = { input: input.join(' '), - tgt_locale: translationOptions.tgt_locale, - entities: contextEntities, - alignment: translationOptions.alignment, - src_locale: translationOptions.src_locale, - align_remove_output_quotation: translationOptions.align_remove_output_quotation + tgt_locale: generationOptions.tgt_locale, + entities: entities, + alignment: generationOptions.do_alignment, + src_locale: generationOptions.src_locale, + // always remove quotation marks in the output string used to mark entity boundaries + align_remove_output_quotation: true, + // always break input utterance into individual sentences before translation + translate_example_split: true }; const response = await Tp.Helpers.Http.post(`${this._baseUrl}/translate`, JSON.stringify(data), { diff --git a/lib/prediction/types.ts b/lib/prediction/types.ts index c59f675df..f4b7d1c79 100644 --- a/lib/prediction/types.ts +++ b/lib/prediction/types.ts @@ -24,6 +24,26 @@ export interface ExactMatcher { get(tokens : string[]) : string[][]|null; } +export interface GENERATION_OPTIONS { + num_beams ?: number + num_beam_groups ?: number + diversity_penalty ?: number + num_outputs ?: number + no_repeat_ngram_size ?: number + top_p ?: number + top_k ?: number + repetition_penalty ?: number + temperature ?: number + max_output_length ?: number + src_locale ?: string + tgt_locale ?: string + do_alignment ?: boolean + align_preserve_input_quotation ?: boolean + align_remove_output_quotation ?: boolean + translate_example_split ?: boolean +} + + export interface ParseOptions { thingtalk_version ?: string; store ?: string; diff --git a/lib/utils/misc-utils.ts b/lib/utils/misc-utils.ts index 26d38a58a..f5f74ca0e 100644 --- a/lib/utils/misc-utils.ts +++ b/lib/utils/misc-utils.ts @@ -31,7 +31,6 @@ import { makeDummyEntity, makeDummyEntities, renumberEntities, - EntityMap, } from './entity-utils'; class ValidationError extends Error { @@ -211,20 +210,22 @@ function substringSpan(sequence : string[], substring : string[]) : [number, num } } if (found) - return [i, i + substring.length + 1]; + return [i, i + substring.length]; } return null; } -function qpisEntities(input : string[], contextEntities : EntityMap|undefined) : string[] { - if (contextEntities) { - const allEntities = Object.keys(contextEntities).map((ent) => ent.split(' ')); - for (const entity of allEntities) { +function qpisEntities(input : string[], entities : string[]|undefined) : string[] { + if (entities) { + const entityTokens = entities.map((ent) => ent.split(' ')); + for (const entity of entityTokens) { const span = substringSpan(input, entity); if (span) { input.splice(span[0], 0, '"'); - input.splice(span[1], 0, '"'); + + // add 1 cause previous splice shift tokens to the right + input.splice(span[1] + 1, 0, '"'); } } } diff --git a/tool/server.ts b/tool/server.ts index 4e70056e6..4dd6c2867 100644 --- a/tool/server.ts +++ b/tool/server.ts @@ -32,6 +32,7 @@ import * as Utils from '../lib/utils/misc-utils'; import { EntityMap } from '../lib/utils/entity-utils'; import LocalParserClient from '../lib/prediction/localparserclient'; import * as I18n from '../lib/i18n'; +import {GENERATION_OPTIONS} from "../lib/prediction/types"; interface Backend { schemas : ThingTalk.SchemaRetriever; @@ -153,7 +154,7 @@ async function queryNLG(params : Record, interface TranslationData { input : string; tgt_locale : string - entities ?: EntityMap; + entities ?: string[]; limit ?: string; alignment ?: boolean; src_locale ?: string; @@ -162,35 +163,13 @@ interface TranslationData { const Translation_PARAMS = { input: 'string', tgt_locale: 'string', - entities: '?object', + entities: '?array', limit: '?number', alignment: '?boolean', src_locale: '?string', align_remove_output_quotation: '?boolean' }; - -// const VALID_PARSER_OPTIONS = new Set([ -// "num_beams", -// "num_beam_groups", -// "diversity_penalty", -// "num_outputs", -// "no_repeat_ngram_size", -// "top_p", -// "top_k", -// "repetition_penalty", -// "temperature", -// "max_output_length", -// "reduce_metrics", -// "database_dir", -// "do_alignment", -// "align_preserve_input_quotation", -// "align_remove_output_quotation", -// "src_locale", -// "tgt_locale", -// "translate_example_split" -// ]); - async function queryTranslate(params : Record, data : TranslationData, res : express.Response) { @@ -204,7 +183,7 @@ async function queryTranslate(params : Record, if (!data.src_locale) data.src_locale = 'en-US'; - const translationOptions : Record = { + const generationOptions : GENERATION_OPTIONS = { 'src_locale': data.src_locale, 'tgt_locale': data.tgt_locale, 'do_alignment': data.alignment, @@ -213,7 +192,7 @@ async function queryTranslate(params : Record, }; const result = await res.app.backend.translator!.translateUtterance( - data.input.split(' '), data.entities, translationOptions); + data.input.split(' '), data.entities, generationOptions); res.json({ candidates: result.slice(0, data.limit ? parseInt(data.limit) : undefined), });