Skip to content

Commit

Permalink
added boost field setter and getter for Knn Search
Browse files Browse the repository at this point in the history
  • Loading branch information
hkulekci committed Nov 18, 2023
1 parent 96d4f5a commit 190de85
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
22 changes: 21 additions & 1 deletion src/Knn/Knn.php
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Knn implements BuilderInterface
private $numCandidates;

/**
* @var int
* @var float|null
*/
private $boost;

Expand Down Expand Up @@ -153,6 +153,22 @@ public function setSimilarity(float $similarity): void
$this->similarity = $similarity;
}

/**
* @return float|null
*/
public function getBoost(): ?float
{
return $this->boost;
}

/**
* @param float $boost
*/
public function setBoost(float $boost): void
{
$this->boost = $boost;
}

/**
* @return BuilderInterface|null
*/
Expand Down Expand Up @@ -193,6 +209,10 @@ public function toArray()
$output['similarity'] = $this->getSimilarity();
}

if ($this->getBoost()) {
$output['boost'] = $this->getBoost();
}

if ($this->getFilter()) {
$output['filter'] = $this->getFilter()->toArray();
}
Expand Down
23 changes: 22 additions & 1 deletion tests/Functional/Knn/KnnTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

namespace ONGR\ElasticsearchDSL\Tests\Functional\Knn;

use ONGR\ElasticsearchDSL\Aggregation\Bucketing\DateHistogramAggregation;
use Composer\InstalledVersions;
use Elastic\Elasticsearch\Client;
use ONGR\ElasticsearchDSL\Knn\Knn;
use ONGR\ElasticsearchDSL\Query\TermLevel\TermQuery;
use ONGR\ElasticsearchDSL\Search;
Expand Down Expand Up @@ -82,4 +83,24 @@ public function testKnnSearchWithFilter(): void
$this->assertCount(1, $results['hits']['hits']);
$this->assertEquals('doc_3', $results['hits']['hits'][0]['_id']);
}

/**
* Match all test
*/
public function testMultipleKnnSearchWithBoost(): void
{
$knn1 = new Knn('vector_field', [1, 2, 3], 1, 1);
$knn1->setFilter(new TermQuery('label', 2));
$knn1->setBoost(0.5);
$knn2 = new Knn('vector_field', [1, 2, 4], 1, 1);
$knn2->setFilter(new TermQuery('label', 2));
$knn2->setBoost(0.1);

$search = new Search();
$search->addKnn($knn1);
$search->addKnn($knn2);
$results = $this->executeSearch($search, true);
$this->assertCount(1, $results['hits']['hits']);
$this->assertEquals('doc_3', $results['hits']['hits'][0]['_id']);
}
}

0 comments on commit 190de85

Please sign in to comment.