Skip to content

Commit

Permalink
Merge pull request #213 from skni-kod/issue-185_2
Browse files Browse the repository at this point in the history
#185 kodemy-search and commons tests
  • Loading branch information
marcinbator authored Dec 13, 2024
2 parents bec20f7 + 4e579f6 commit 87584f0
Show file tree
Hide file tree
Showing 14 changed files with 964 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package pl.sknikod.kodemycommons.data

import org.springframework.security.core.Authentication
import org.springframework.security.core.context.SecurityContextHolder
import pl.sknikod.kodemycommons.security.UserPrincipal
import spock.lang.Specification

import java.time.LocalDateTime


class AuditableSpec extends Specification {


def "should update createdDate and createdBy on prePersist"() {
given: "fake authentication"
withAuthentication()

and: "an instance of Auditable and initial state"
def auditable = new Auditable() {}
def beforeDate = auditable.getCreatedDate()
def beforeCreatedBy = auditable.getCreatedBy()
def dateBeforePrePersist = LocalDateTime.now()

when: "onPrePersist is called"
auditable.onPrePersist()

then: "before call createdDate and createdBy are null"
beforeDate == null
beforeCreatedBy == null

and: "after call createdDate and createdBy are set"
auditable.getCreatedDate() != beforeDate
auditable.getCreatedDate().isAfter(dateBeforePrePersist) || auditable.getCreatedDate().isEqual(dateBeforePrePersist)
auditable.getCreatedBy() != null

cleanup:
clearAuthentication()
}

def "should update modifiedDate and modifiedBy on preUpdate"() {
given: "fake authentication"
withAuthentication()

and: "an instance of Auditable and initial state"
def auditable = new Auditable() {}
def beforeDate = auditable.getModifiedDate()
def beforeModifiedBy = auditable.getModifiedBy()
def dateBeforePreUpdate = LocalDateTime.now()

when: "onPreUpdate is called"
auditable.onPreUpdate()

then: "before call modifiedDate and modifiedBy are null"
beforeDate == null
beforeModifiedBy == null

and: "after call modifiedDate and modifiedBy are set"
auditable.getModifiedDate() != beforeDate
auditable.getModifiedDate().isAfter(dateBeforePreUpdate) || auditable.getModifiedDate().isEqual(dateBeforePreUpdate)
auditable.getModifiedBy() != null

cleanup:
clearAuthentication()
}


private void withAuthentication() {
def userPrincipal = new UserPrincipal(1L, "username", Collections.emptyList())
def authentication = Mock(Authentication) {
isAuthenticated() >> true
getPrincipal() >> userPrincipal
}
SecurityContextHolder.getContext().setAuthentication(authentication)
}

private static void clearAuthentication() {
SecurityContextHolder.clearContext()
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package pl.sknikod.kodemycommons.security

import org.springframework.security.core.Authentication
import org.springframework.security.core.GrantedAuthority
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.context.SecurityContextHolder
import spock.lang.Specification

class AuthFacadeSpec extends Specification {

def setup() {
clearAuthentication()
}

def cleanup() {
clearAuthentication()
}

def "should get authentication"() {
given:
withAuthentication()
when:
def authentication = AuthFacade.getAuthentication()
then:
authentication.isPresent()
}

def "should check is authenticated"() {
given:
auth.call()
when:
def isAuthenticated = AuthFacade.isAuthenticated()
then:
isAuthenticated == expected

where:
auth || expected ;
{withAuthentication(authenticated: true)} || true ;
{withAuthentication(authenticated: false)} || false ;
}

def "should get current username"() {
given:
withAuthentication(username: "abc")
when:
def username = AuthFacade.getCurrentUsername()
then:
username == "abc"
}

def "should get current user principal"() {
given:
withAuthentication()
when:
def userPrincipal = AuthFacade.getCurrentUserPrincipal()
then:
userPrincipal.isPresent()
}

def "should find any authority"() {
given:
withAuthentication(authorities: List.of(
new SimpleGrantedAuthority("ROLE_USER"),
new SimpleGrantedAuthority("ROLE_ADMIN")))
when:
def anyAuthority = AuthFacade.hasAnyAuthority("ROLE_SUPERADMIN", "ROLE_ADMIN")
then:
anyAuthority
}

def "should find authority"() {
given:
withAuthentication(authorities: List.of(
new SimpleGrantedAuthority("ROLE_USER"),
new SimpleGrantedAuthority("ROLE_ADMIN")))
when:
def authority = AuthFacade.hasAuthority(role)
then:
authority == expected

where:
role || expected
"ROLE_ADMIN" || true
"ROLE_SUPERADMIN" || false
}


def withAuthentication(Map args = [:]) {
Long id = args.containsKey('id') ? args.id : 1L
String username = args.containsKey('username') ? args.username : "username"
Collection<? extends GrantedAuthority> authorities = args.containsKey('authorities') ? args.authorities : Collections.emptyList()
boolean authenticated = args.containsKey('authenticated') ? args.authenticated : true

def userPrincipal = new UserPrincipal(id, username, authorities as Collection<SimpleGrantedAuthority>)
def authentication = Mock(Authentication) {
isAuthenticated() >> authenticated
getPrincipal() >> userPrincipal
getName() >> username
getAuthorities() >> authorities
}
SecurityContextHolder.getContext().setAuthentication(authentication)
}

private static void clearAuthentication() {
SecurityContextHolder.clearContext()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package pl.sknikod.kodemycommons.security

import org.springframework.security.core.authority.SimpleGrantedAuthority
import spock.lang.Specification

class JwtProviderSpec extends Specification {

def "should generate delegation token"() {
given:
def jwtProvider = new JwtProvider(new FakeProperties())
String subject = "user123"
String authority = "ROLE_ADMIN"
when:
def token = jwtProvider.generateDelegationToken(subject, authority)
then:
token != null
token.id() != null
token.value() != null
token.expiration() > new Date()
and:
def parsed = jwtProvider.parseToken(token.value())
parsed.isSuccess()
def result = parsed.get()
result.bearerId == token.id()
result.username == "user123"
result.authorities.contains(new SimpleGrantedAuthority("ROLE_ADMIN"))
}

def "should generate user token"() {
given:
def jwtProvider = new JwtProvider(new FakeProperties())
def input = new JwtProvider.Input(5L, "user123", false, false,
false, true, [new SimpleGrantedAuthority("ROLE_USER")] as Set)
when:
def token = jwtProvider.generateUserToken(input)
then:
token != null
token.id() != null
token.value() != null
token.expiration() > new Date()
and:
def parsed = jwtProvider.parseToken(token.value())
parsed.isSuccess()
def result = parsed.get()
result.id == 5
result.bearerId == token.id()
result.username == "user123"
result.authorities.contains(new SimpleGrantedAuthority("ROLE_USER"))
}



static class FakeProperties extends JwtProvider.Properties {
FakeProperties() {
secretKey = 'YWJjZGVmZ2hjvbwjrW5vcHFyc3R1dnd4eXoxMjM0NTY3OnDMTIzNDU2Nzg5MDEyMzQ1Njc4OTAf='
bearerExpirationMin = 15
delegationExpirationMin = 60
}
}
}
3 changes: 2 additions & 1 deletion kodemy-search/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ dependencies {
testImplementation 'org.springframework.cloud:spring-cloud-stream'
testImplementation 'org.springframework.cloud:spring-cloud-stream-test-binder'

testImplementation 'org.spockframework:spock-spring:2.4-M4-groovy-4.0'
testImplementation "org.apache.groovy:groovy"
testImplementation "org.apache.groovy:groovy-json"
testImplementation 'org.spockframework:spock-spring:2.4-M4-groovy-4.0'
testImplementation "org.testcontainers:spock:1.20.0"

testCompileOnly 'org.projectlombok:lombok'
testAnnotationProcessor 'org.projectlombok:lombok'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,8 @@ private SearchCriteria createSearchCriteria(@NonNull MaterialControllerDefinitio
));
var categoryIds = filterSearchParams.getCategoryIds();
if (Objects.nonNull(categoryIds) && !categoryIds.isEmpty()) {
categoryIds.forEach(categoryId -> {
criteria.addPhraseField(new SearchCriteria.PhraseField(
"categoryId", String.valueOf(categoryId), false, false
));
});
criteria.addArrayField(new SearchCriteria.ArrayField(
"categoryId", categoryIds.stream().map(String::valueOf).toList()));
}
if (Objects.nonNull(filterSearchParams.getMinAvgGrade()) || Objects.nonNull(filterSearchParams.getMaxAvgGrade()))
criteria.addRangeField(new SearchCriteria.RangeField<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class SearchCriteria {
String anyPhrase;
List<PhraseField> phraseFields = new ArrayList<>();
List<RangeField<?>> rangeFields = new ArrayList<>();
List<ArrayField> arrayFields = new ArrayList<>();
Pageable pageable;

public SearchCriteria(@NonNull String anyPhrase, @NonNull Pageable pageable) {
Expand All @@ -30,6 +31,10 @@ public void addRangeField(RangeField<?> field) {
rangeFields.add(field);
}

public void addArrayField(ArrayField field) {
arrayFields.add(field);
}

@Getter
private abstract static class Field {
private final String name;
Expand Down Expand Up @@ -65,6 +70,16 @@ public PhraseField(String name, String value, boolean wildcard, boolean mustNot)
}
}

@Getter
public static class ArrayField extends Field {
private final List<String> values;

public ArrayField(String name, List<String> values) {
super(name);
this.values = values;
}
}

@Getter
public static class ContentField {
private final String value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Strings;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.SortOptions;
import org.opensearch.client.opensearch._types.SortOrder;
import org.opensearch.client.opensearch._types.query_dsl.MatchPhraseQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch._types.query_dsl.RangeQuery;
import org.opensearch.client.opensearch._types.query_dsl.TermsQuery;
import org.opensearch.client.opensearch._types.query_dsl.WildcardQuery;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.springframework.data.domain.Pageable;
Expand Down Expand Up @@ -42,6 +44,7 @@ public SearchRequestBuilder(String indexName, SearchCriteria criteria) {
any(criteria.getAnyPhrase());
criteria.getPhraseFields().forEach(this::append);
criteria.getRangeFields().forEach(this::append);
criteria.getArrayFields().forEach(this::append);
}

private void with(Pageable pageable) {
Expand Down Expand Up @@ -74,7 +77,7 @@ private void append(SearchCriteria.PhraseField field) {
var query = field.isWildcard()
? WildcardQuery.of(w -> w.field(field.getName()).value(field.getValue())).toQuery()
: MatchPhraseQuery.of(m -> m.field(field.getName()).query(field.getValue())).toQuery();
(field.isMustNot() ? mustNotQueries : shouldQueries).add(query);
(field.isMustNot() ? mustNotQueries : mustQueries).add(query);
}

private void append(SearchCriteria.RangeField<?> field) {
Expand All @@ -87,6 +90,17 @@ private void append(SearchCriteria.RangeField<?> field) {
mustQueries.add(rangeQueryBuilder.build().toQuery());
}

private void append(SearchCriteria.ArrayField field) {
if (field == null || field.getValues().isEmpty()) {
return;
}
var query = TermsQuery.of(t -> t.field(field.getName())
.terms(terms -> terms.value(field.getValues().stream().map(FieldValue::of).toList()))
).toQuery();
mustQueries.add(query);
}


public SearchRequest build() {
return new SearchRequest.Builder()
.index(indexName)
Expand Down
Loading

0 comments on commit 87584f0

Please sign in to comment.