From 87b587862853b981f72434b4d1a4369385a08cf7 Mon Sep 17 00:00:00 2001 From: Patrick Henning Date: Fri, 27 Jan 2023 18:32:50 -0800 Subject: [PATCH] Utilities for functions and definitions --- src/elements.ts | 3 +- src/functions.ts | 104 +++++++++++++++++++++++++++++++++++++++++++++-- src/index.ts | 2 +- 3 files changed, 103 insertions(+), 6 deletions(-) diff --git a/src/elements.ts b/src/elements.ts index 65dd2009..5860fc6b 100644 --- a/src/elements.ts +++ b/src/elements.ts @@ -50,7 +50,8 @@ export abstract class ExprElement { */ recursiveSubstitute(vars: ExprMap): ExprElement { const varList = Object.keys(vars); - if (!this.unknowns.filter(v => varList.includes(v)).length) return this; + const unknown = [...this.unknowns, ...this.functions]; + if (!unknown.filter(v => varList.includes(v)).length) return this; return this.substitute(vars).recursiveSubstitute(vars); } diff --git a/src/functions.ts b/src/functions.ts index 3d0651b9..d84afa5a 100644 --- a/src/functions.ts +++ b/src/functions.ts @@ -50,6 +50,13 @@ export class ExprFunction extends ExprElement { if (this.fn in vars) { const fn = vars[this.fn]; + if (fn instanceof ExprFunctionDefinition) { + const nextMap = {...vars}; + for (const [position, paramName] of fn.params.entries()) { + nextMap[paramName] = args[position]; + } + return fn.evaluate(nextMap); + } if (typeof fn === 'function') return fn(...args); if (typeof fn === 'number' && args.length === 1) return evaluate.mul(fn, args[0]); throw ExprError.uncallableExpression(this.fn); @@ -86,13 +93,22 @@ export class ExprFunction extends ExprElement { throw ExprError.undefinedFunction(this.fn); } - substitute(vars: ExprMap = {}) { - return new ExprFunction(this.fn, this.args.map(a => a.substitute(vars))); + substitute(vars: ExprMap = {}): ExprElement { + const args = this.args.map(a => a.substitute(vars)); + if (this.fn in vars && vars[this.fn] instanceof ExprFunctionDefinition) { + return (vars[this.fn] as ExprFunctionDefinition).applyExpressions(args); + } + return new ExprFunction(this.fn, args); } - collapse() { + collapse(): ExprElement { if (this.fn === '(') return this.args[0].collapse(); - return new ExprFunction(this.fn, this.args.map(a => a.collapse())); + const collapsedArgs = this.args.map(a => a.collapse()); + const arg0 = collapsedArgs[0]; + if (this.fn === '=' && isFunctionHead(arg0)) { + return new ExprFunctionDefinition(arg0.fn, arg0.unknowns, collapsedArgs[1]); + } + return new ExprFunction(this.fn, collapsedArgs); } get simplified() { @@ -133,6 +149,18 @@ export class ExprFunction extends ExprElement { return `${this.fn}(${args.join(', ')})`; } + partialEvaluate(vars: ExprMap = {}) { + const base = this.substitute(vars); + const fn: CustomFunction = (...args) => { + const vm: Record = {}; + for (const [index, arg] of base.unknowns.entries()) { + vm[arg] = args[index]; + } + return base.evaluate(vm); + }; + return fn; + } + toMathML(custom: MathMLMap = {}) { const args = this.args.map(a => a.toMathML(custom)); const argsF = this.args.map((a, i) => addMFence(a, this.fn, args[i])); @@ -246,6 +274,74 @@ export class ExprFunction extends ExprElement { } } +const LEGAL_CUSTOM_FUNCTION_NAME = new RegExp('^\\p{L}+$', 'u'); + +function isFunctionHead(expr: E | ExprFunction): expr is ExprFunction { + return expr instanceof ExprFunction && LEGAL_CUSTOM_FUNCTION_NAME.test(expr.fn); +} + +export class ExprFunctionDefinition extends ExprElement { + private _body: ExprElement; + constructor(readonly name: string, readonly params: string[], body: ExprElement) { + super(); + this._body = body; + } + + get body() { + return this._body; + } + + private set body(b: ExprElement) { + this._body = b; + } + + get unknowns() { + return this.body.unknowns.filter(u => !this.params.includes(u)); + } + + get variables() { + return this.body.variables.filter(v => !this.params.includes(v)); + } + + get functions() { + return this.body.functions; + } + + withSubstitutedBody(vars: ExprMap) { + const n = new ExprFunctionDefinition(this.name, this.params, this.body); + n.bodySubstitute(vars); + return n; + } + + bodySubstitute(vars: ExprMap) { // TODO: ensure that param vars dont get substituted + this.body = this.body.substitute(vars); + } + + applyExpressions(args: ExprElement[]) { + const exprMap: ExprMap = {}; + for (const [index, param] of this.params.entries()) { + exprMap[param] = args[index]; + } + return this.body.substitute(exprMap); + } + + applyVals = (...args: number[]) => { + const varMap: VarMap = {}; + for (const [index, param] of this.params.entries()) { + varMap[param] = args[index]; + } + return this.body.evaluate(varMap); + }; + + evaluate(vars: VarMap) { + return this.body.evaluate(vars); + } + + toString() { + return `${this.name}(${this.params.join(',')})=${this.body}`; + } +} + // ----------------------------------------------------------------------------- export class ExprTerm extends ExprElement { diff --git a/src/index.ts b/src/index.ts index 21a2cced..72dbde9f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,6 +7,6 @@ export {ExprError} from './errors'; export {Expression} from './expression'; export {ExprElement, ExprIdentifier, ExprNumber, ExprOperator} from './elements'; -export {ExprFunction} from './functions'; +export {ExprFunction, ExprFunctionDefinition} from './functions'; export {CONSTANTS as HILBERT_CONSTANTS, SPECIAL_IDENTIFIERS, isSpecialFunction} from './symbols'; export {hasZero, Interval, isWhole, width} from './eval';