Skip to content

Commit

Permalink
Improve CosineDistance for scalar values
Browse files Browse the repository at this point in the history
  • Loading branch information
axkr committed Dec 2, 2024
1 parent d673637 commit 76f3e70
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.matheclipse.core.interfaces.IBuiltInSymbol;
import org.matheclipse.core.interfaces.IEvaluator;
import org.matheclipse.core.interfaces.IExpr;
import org.matheclipse.core.interfaces.INumber;

public class ClusteringFunctions {

Expand All @@ -34,6 +35,14 @@ private abstract static class AbstractDistance extends AbstractEvaluator

public abstract IExpr distance(IExpr a, IExpr b, EvalEngine engine);

public IExpr scalarDistance(INumber a, INumber b, EvalEngine engine) {
return F.NIL;
}

public IExpr numericFunctionDistance(IExpr a, IExpr b, EvalEngine engine) {
return F.NIL;
}

@Override
public IExpr evaluate(final IAST ast, EvalEngine engine) {
IExpr arg1 = ast.arg1();
Expand All @@ -42,12 +51,24 @@ public IExpr evaluate(final IAST ast, EvalEngine engine) {
int dim1 = arg1.isVector();
if (dim1 > (-1)) {
int dim2 = arg2.isVector();
if (dim1 == dim2) {
if (dim1 == 0) {
return F.NIL;
if (dim2 > (-1)) {
if (dim1 != dim2) {
// The arguments `1` and `2` do not have compatible dimensions.
return Errors.printMessage(ast.topHead(), "bldim", F.List(arg1, arg2), engine);
}
if (dim1 != 0) {
return vectorDistance(arg1, arg2, engine);
}
return vectorDistance(arg1, arg2, engine);
}
return F.NIL;
}
if (arg1.isNumber() && arg2.isNumber()) {
INumber n1 = (INumber) arg1;
INumber n2 = (INumber) arg2;
return scalarDistance(n1, n2, engine);
}
if (arg1.isNumericFunction() && arg2.isNumericFunction()) {
return numericFunctionDistance(arg1, arg2, engine);
}
return F.NIL;
}
Expand Down Expand Up @@ -373,6 +394,35 @@ public double compute(double[] a, double[] b) throws MathIllegalArgumentExceptio
return 1.0 - cosAngle;
}

@Override
public IExpr scalarDistance(INumber arg1, INumber arg2, EvalEngine engine) {
IExpr norm1 = arg1.abs();
if (norm1.isZero()) {
return F.C0;
}
IExpr norm2 = arg2.abs();
if (norm2.isZero()) {
return F.C0;
}
final INumber c = arg1.times(arg2.conjugate());
return engine.evaluate(F.Subtract(F.C1, F.Divide(c, F.Times(norm1, norm2))));
}


@Override
public IExpr numericFunctionDistance(IExpr arg1, IExpr arg2, EvalEngine engine) {
IExpr norm1 = arg1.abs();
if (norm1.isPossibleZero(true)) {
return F.C0;
}
IExpr norm2 = arg2.abs();
if (norm2.isPossibleZero(true)) {
return F.C0;
}
return engine.evaluate(
F.Subtract(F.C1, F.Divide(F.Times(arg1, F.Conjugate(arg2)), F.Times(norm1, norm2))));
}

@Override
public IExpr distance(IExpr arg1, IExpr arg2, EvalEngine engine) {
IExpr norm1 = F.Norm.of(engine, arg1);
Expand All @@ -383,7 +433,8 @@ public IExpr distance(IExpr arg1, IExpr arg2, EvalEngine engine) {
if (norm2.isZero()) {
return F.C0;
}
return F.Subtract(F.C1, F.Divide(F.Dot(arg1, F.Conjugate(arg2)), F.Times(norm1, norm2)));
return engine.evaluate(
F.Subtract(F.C1, F.Divide(F.Dot(arg1, F.Conjugate(arg2)), F.Times(norm1, norm2))));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ public static void initGeneralMessages() {
"bset",
"The second argument `1` of Element should be one of: Primes, Integers, Rationals, Algebraics, Reals, Complexes or Booleans.", //
"bfun", "`1` is not a boolean-valued pure function.", //
"bldim", "The arguments `1` and `2` do not have compatible dimensions.", //
"boxfmt", "`1` is not a box formatting type.", //
"bdomv", "Warning: `1` is not a valid domain specification.", //
"cfn", "Numerical error encountered, proceeding with uncompiled evaluation.", //
Expand Down Expand Up @@ -351,8 +352,7 @@ public static void initGeneralMessages() {
"The number of subdivisions given in position `1` of `2` should be a positive machine-sized integer.", //
"seqs",
"Sequence specification (+n,-n,{+n},{-n},{m,n}) or {m,n,s} expected at position `2` in `1`.", //
"seqso",
"Sequence specification (+n,-n,{+n},{-n} or {m,n}) expected at position `2` in `1`.", //
"seqso", "Sequence specification (+n,-n,{+n},{-n} or {m,n}) expected at position `2` in `1`.", //
"setp", "Part assignment to `1` could not be made", //
"setraw", "Cannot assign to raw object `1`.", //
"setps", "`1` in the part assignment is not a symbol.", //
Expand Down Expand Up @@ -542,7 +542,7 @@ public static IAST printMessage(ISymbol symbol, String messageShortcut,

/**
* Format a message according to the shortcut from the {@link MESSAGES} array and print it to the
* error stream with the <code>engine.printMessage()</code>method.
* error stream with the help of the {@link EvalEngine#getErrorPrintStream()} method.
*
* <p>
* Usage pattern:
Expand All @@ -557,7 +557,7 @@ public static IAST printMessage(ISymbol symbol, String messageShortcut,
* @param listOfParameters a list of arguments which should be inserted into the message shortcuts
* placeholder
* @param engine
* @return always <code>F.NIL</code>
* @return always {@link F#NIL}
*/
public static IAST printMessage(ISymbol symbol, String messageShortcut,
final IAST listOfParameters, EvalEngine engine) {
Expand Down Expand Up @@ -855,11 +855,12 @@ public static String templateRender(String templateStr, String[] args) {


public static void rethrowsInterruptException(Exception e) {
if (e instanceof ApfloatInterruptedException || e instanceof PreemptingException || e instanceof TimeoutException) {
if (e instanceof ApfloatInterruptedException || e instanceof PreemptingException
|| e instanceof TimeoutException) {
throw (RuntimeException) e;
}
if (e instanceof RuntimeException && e.getCause() instanceof InterruptedException) {
throw (RuntimeException) e;
throw (RuntimeException) e;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3079,6 +3079,11 @@ public ContextPath getContextPath() {
return fContextPath;
}

/**
* Get the current error print stream. If no user error print stream is defined, the standard
* error print stream {@link System#err} is returned.
*
*/
public PrintStream getErrorPrintStream() {
return fErrorPrintStream != null ? fErrorPrintStream : System.err;
}
Expand Down Expand Up @@ -3162,6 +3167,11 @@ public Iterator<IdentityHashMap<ISymbol, IASTAppendable>> optionsStackIterator()
return fOptionsStack.iterator();
}

/**
* Get the current output print stream. If no user output print stream is defined, the standard
* output print stream {@link System#out} is returned.
*
*/
public PrintStream getOutPrintStream() {
return fOutPrintStream != null ? fOutPrintStream : System.out;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3969,15 +3969,36 @@ public void testCorrelationDistance() {

@Test
public void testCosineDistance() {
check("CosineDistance(Sin(Pi/111) ,Cos(E))", //
"2");
check("CosineDistance(17,47)", //
"0");
check("CosineDistance(2/3,-3/7)", //
"2");
check("CosineDistance(2/3*I,-3/11)", //
"1+I");
check("CosineDistance(-Sin(Pi^2)+I , Cos(E))", //
"1+(I-Sin(Pi^2))/Sqrt(1+Sin(Pi^2)^2)");


// CosineDistance: The arguments {a,b} and {x,y,z} do not have compatible dimensions.
check("CosineDistance({a, b}, {x, y, z})", //
"CosineDistance({a,b},{x,y,z})");


check("CosineDistance({1, 0}, {x, y})", //
"1-Conjugate(x)/Sqrt(Abs(x)^2+Abs(y)^2)");
check("CosineDistance({a, b, c}, {x, y, z})", //
"1-(a*Conjugate(x)+b*Conjugate(y)+c*Conjugate(z))/Sqrt((Abs(a)^2+Abs(b)^2+Abs(c)^\n" //
+ "2)*(Abs(x)^2+Abs(y)^2+Abs(z)^2))");

check("CosineDistance({7.0, 9}, {71, 89})", //
"0.0000759646");
check("N(CosineDistance({7, 9}, {71, 89}))", //
"0.0000759646");
check("CosineDistance({a, b}, {c, d})", //
"1-(a*Conjugate(c)+b*Conjugate(d))/Sqrt((Abs(a)^2+Abs(b)^2)*(Abs(c)^2+Abs(d)^2))");//
check("CosineDistance({a, b, c}, {x, y, z})", //
"1-(a*Conjugate(x)+b*Conjugate(y)+c*Conjugate(z))/Sqrt((Abs(a)^2+Abs(b)^2+Abs(c)^\n" //
+ "2)*(Abs(x)^2+Abs(y)^2+Abs(z)^2))");

}

@Test
Expand Down Expand Up @@ -10991,6 +11012,16 @@ public void testIncrement() {

@Test
public void testIndeterminate() {
check("Infinity-Infinity", //
"Indeterminate");
check("ComplexInfinity+ComplexInfinity", //
"Indeterminate");
check("Indeterminate+1", //
"Indeterminate");
check("0*Indeterminate", //
"Indeterminate");
check("0*ComplexInfinity", //
"Indeterminate");
check("Tan(Indeterminate)", //
"Indeterminate");
check("{And(True, Indeterminate), And(False, Indeterminate)}", //
Expand All @@ -10999,6 +11030,8 @@ public void testIndeterminate() {
"Indeterminate==Indeterminate");
check("Indeterminate===Indeterminate", //
"True");
check("Indeterminate!=Indeterminate", //
"Indeterminate!=Indeterminate");
check("{Re(Indeterminate), Im(Indeterminate)}", //
"{Indeterminate,Indeterminate}");
check("NumberQ(Indeterminate)", //
Expand Down Expand Up @@ -24704,6 +24737,16 @@ public void testTogetherIssue856() {

@Test
public void testToExpression() {
// TODO print syntax error to error stream
check("ToExpression(\"1+2}\")", //
"$Failed");
check("ToExpression(\"1+2\")", //
"3");
check("ToExpression(\"{2, 3, 1}\", InputForm, Max)", //
"3");
check("ToExpression(\"2 3\", InputForm)", //
"6");

check("ToExpression(\"\\\\begin{matrix}\n" //
+ "1 & 2 \\\\\\\\\n" //
+ " 7 & 8\n" //
Expand Down Expand Up @@ -25957,6 +26000,8 @@ public void testWhittakerW() {

@Test
public void testWith() {
check("Block({i = 0}, With({}, Module({j = i}, i=i+1; j)))", //
"0");
EvalEngine.resetModuleCounter4JUnit();

// print message: Set: Cannot unset object 2.0.
Expand Down

0 comments on commit 76f3e70

Please sign in to comment.