diff --git a/kernel/api/src/main/java/org/sakaiproject/springframework/data/SpringCrudRepositoryImpl.java b/kernel/api/src/main/java/org/sakaiproject/springframework/data/SpringCrudRepositoryImpl.java index f0872116be69..4d79c4c4a7b0 100644 --- a/kernel/api/src/main/java/org/sakaiproject/springframework/data/SpringCrudRepositoryImpl.java +++ b/kernel/api/src/main/java/org/sakaiproject/springframework/data/SpringCrudRepositoryImpl.java @@ -17,6 +17,7 @@ import lombok.Getter; import lombok.Setter; +import lombok.extern.slf4j.Slf4j; import org.hibernate.Criteria; import org.hibernate.Session; import org.hibernate.SessionFactory; @@ -28,14 +29,12 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.util.Assert; -import org.sakaiproject.springframework.data.PersistableEntity; -import org.sakaiproject.springframework.data.Repository; - import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Optional; +@Slf4j @Transactional(readOnly = true) public abstract class SpringCrudRepositoryImpl, ID extends Serializable> implements SpringCrudRepository { @@ -89,14 +88,14 @@ public Iterable saveAll(Iterable entities) { public Optional findById(ID id) { Assert.notNull(id, "The id cannot be null"); - return Optional.ofNullable((T) sessionFactory.getCurrentSession().get(domainClass, id)); + return Optional.ofNullable(sessionFactory.getCurrentSession().get(domainClass, id)); } @Override public T getById(ID id) { Assert.notNull(id, "The id cannot be null"); - return (T) sessionFactory.getCurrentSession().load(domainClass, id); + return sessionFactory.getCurrentSession().load(domainClass, id); } @Override @@ -116,7 +115,7 @@ public Page findAll(Pageable pageable) { Criteria criteria = sessionFactory.getCurrentSession().createCriteria(domainClass); criteria.setFirstResult((int) pageable.getOffset()); - criteria.setMaxResults((int) pageable.getPageSize()); + criteria.setMaxResults(pageable.getPageSize()); return new PageImpl(criteria.list()); } @@ -125,7 +124,7 @@ public Iterable findAllById(Iterable ids) { List list = new ArrayList<>(); if (ids != null) { - ids.forEach(id -> findById(id).ifPresent(found -> list.add(found))); + ids.forEach(id -> findById(id).ifPresent(list::add)); } return list; } @@ -133,13 +132,10 @@ public Iterable findAllById(Iterable ids) { @Override @Transactional public void delete(T entity) { - - Session session = sessionFactory.getCurrentSession(); - - try { - session.delete(entity); - } catch (Exception he) { - session.delete(session.merge(entity)); + if (entity != null) { + deleteById(entity.getId()); + } else { + log.warn("Can not perform delete on a null entity"); } } @@ -161,7 +157,12 @@ public void deleteAll(Iterable entities) { @Override @Transactional public void deleteById(ID id) { - findById(id).ifPresent(found -> delete(found)); + if (id != null) { + Session session = sessionFactory.getCurrentSession(); + findById(id).ifPresent(session::delete); + } else { + log.warn("Can not perform delete with a null id"); + } } /**