Skip to content

Commit

Permalink
Add token validation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
emyl3 committed Dec 3, 2024
1 parent d587d5c commit 870dfba
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
import com.fasterxml.jackson.annotation.JsonView;
import gov.cdc.usds.simplereport.api.model.errors.DryRunException;
import gov.cdc.usds.simplereport.db.model.DeviceType;
import gov.cdc.usds.simplereport.service.DeviceTypeProdSyncService;
import gov.cdc.usds.simplereport.service.DeviceTypeService;
import gov.cdc.usds.simplereport.service.DeviceTypeSyncService;
import jakarta.servlet.http.HttpServletRequest;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
Expand All @@ -17,6 +22,7 @@
@RequiredArgsConstructor
public class DeviceTypeController {
private final DeviceTypeSyncService deviceTypeSyncService;
private final DeviceTypeProdSyncService deviceTypeProdSyncService;
private final DeviceTypeService deviceTypeService;

@GetMapping("/devices/sync")
Expand All @@ -30,11 +36,14 @@ public void syncDevices(@RequestParam boolean dryRun) {

@GetMapping("/devices")
@JsonView(PublicDeviceType.class)
public List<DeviceType> getDevices() {
public ResponseEntity<Object> getDevices(HttpServletRequest request) {
try {
return deviceTypeService.fetchDeviceTypes();
} catch (Exception e) {
return null;
String headerToken = request.getHeader("Sr-Prod-Devices-Token");
deviceTypeProdSyncService.validateToken(headerToken);
List<DeviceType> devices = deviceTypeService.fetchDeviceTypes();
return ResponseEntity.status(HttpStatus.OK).body(devices);
} catch (AccessDeniedException e) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body(null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
.permitAll()
.requestMatchers(HttpMethod.GET, WebConfiguration.USER_ACCOUNT_REQUEST + "/**")
.permitAll()
// Devices endpoint authorization is handled at the service or controller level
.requestMatchers(HttpMethod.GET, WebConfiguration.DEVICES + "/**")
.permitAll()
// Anything else goes through Okta
.anyRequest()
.authenticated())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class WebConfiguration implements WebMvcConfigurer {
public static final String PATIENT_UPLOAD = "/upload/patients";
public static final String RESULT_UPLOAD = "/upload/results";
public static final String CONDITION_AGNOSTIC_RESULT_UPLOAD = "/upload/condition-agnostic";

public static final String DEVICES = "/devices";
public static final String GRAPH_QL = "/graphql";

@Autowired private RestLoggingInterceptor _loggingInterceptor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ public class DeviceType extends EternalAuditedEntity {
@JsonView(PublicDeviceType.class)
private int testLength;

// @JsonIgnore
@OneToMany(mappedBy = "deviceTypeId", cascade = CascadeType.ALL, orphanRemoval = true)
@JsonView(PublicDeviceType.class)
List<DeviceTypeDisease> supportedDiseaseTestPerformed = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
public class DeviceTypeDisease extends IdentifiedEntity {

@Column(name = "device_type_id")
@JsonView(PublicDeviceType.class)
private UUID deviceTypeId;

@ManyToOne
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package gov.cdc.usds.simplereport.service;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

/** Service to fetch and save DeviceTypes from our prod env */
@Service
@Slf4j
@Transactional(readOnly = true)
public class DeviceTypeProdSyncService {
@Value("${simple-report.production.devices-token}")
private String token;

public boolean validateToken(String headerToken) throws AccessDeniedException {
if (token.equals(headerToken)) {
return true;
}
throw new AccessDeniedException("Access denied");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package gov.cdc.usds.simplereport.api;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;

import gov.cdc.usds.simplereport.service.DeviceTypeProdSyncService;
import gov.cdc.usds.simplereport.test_util.TestUserIdentities;
import org.json.JSONArray;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;

class DeviceTypeControllerTest extends BaseFullStackTest {

@Autowired private MockMvc _mockMvc;
@MockBean private DeviceTypeProdSyncService _mockDeviceTypeProdSyncService;

@BeforeEach
void init() {
TestUserIdentities.withStandardUser(
() -> {
_dataFactory.initGenericDeviceTypeAndSpecimenType();
});
}

@Test
void getDevices_withValidateToken_success() throws Exception {
when(_mockDeviceTypeProdSyncService.validateToken(any())).thenReturn(true);
MockHttpServletRequestBuilder builder =
get(ResourceLinks.DEVICES)
.contentType(MediaType.valueOf(MediaType.APPLICATION_JSON_VALUE))
.accept(MediaType.APPLICATION_JSON)
.characterEncoding("UTF-8");

MvcResult result = this._mockMvc.perform(builder).andReturn();
MockHttpServletResponse res = result.getResponse();

assertThat(res.getStatus()).isEqualTo(200);

JSONArray jsonRes = new JSONArray(res.getContentAsString());
assertThat(jsonRes.length()).isEqualTo(1);

JSONObject deviceType = jsonRes.getJSONObject(0);
assertThat(deviceType.getString("manufacturer")).isEqualTo("Acme");
assertThat(deviceType.getString("model")).isEqualTo("SFN");
assertThat(deviceType.getString("name")).isEqualTo("Acme SuperFine");
assertThat(deviceType.getInt("testLength")).isEqualTo(15);
assertThat(deviceType.getJSONArray("supportedDiseaseTestPerformed")).isEmpty();
// ensure deviceType internalId is not returned
assertTrue(deviceType.isNull("internalId"));

JSONArray swabTypes = deviceType.getJSONArray("swabTypes");
assertThat(swabTypes.length()).isEqualTo(1);
JSONObject swabType = swabTypes.getJSONObject(0);
assertThat(swabType.getString("collectionLocationCode")).isEqualTo("986543321");
assertThat(swabType.getString("collectionLocationName")).isEqualTo("Da Nose");
assertThat(swabType.getString("name")).isEqualTo("Nasal swab");
assertThat(swabType.getString("typeCode")).isEqualTo("000111222");
// ensure swabType internalId is not returned
assertTrue(swabType.isNull("internalId"));
}

@Test
void getDevices_withValidateToken_failure() throws Exception {
when(_mockDeviceTypeProdSyncService.validateToken(any()))
.thenThrow(new AccessDeniedException("Bad token"));
MockHttpServletRequestBuilder builder =
get(ResourceLinks.DEVICES)
.contentType(MediaType.valueOf(MediaType.APPLICATION_JSON_VALUE))
.accept(MediaType.APPLICATION_JSON)
.characterEncoding("UTF-8");

MvcResult result = this._mockMvc.perform(builder).andReturn();
MockHttpServletResponse res = result.getResponse();

assertThat(res.getStatus()).isEqualTo(401);
assertThat(res.getContentAsString()).isEmpty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

/** Container class for test constants related to REST handler testing */
public final class ResourceLinks {
public static final String DEVICES = "/devices";
public static final String VERIFY_LINK_V2 = "/pxp/link/verify/v2";

public static final String SELF_REGISTER = "/pxp/register";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package gov.cdc.usds.simplereport.service;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.access.AccessDeniedException;

class DeviceTypeProdSyncServiceTest extends BaseServiceTest<DeviceTypeProdSyncService> {
@Value("${simple-report.production.devices-token}")
private String token;

@Test
void validateToken_success() {
assertThat(_service.validateToken(token)).isTrue();
}

@Test
void validateToken_throwsException() {
assertThrows(AccessDeniedException.class, () -> _service.validateToken("foo"));
}
}

0 comments on commit 870dfba

Please sign in to comment.