Skip to content

Commit

Permalink
Stop using unneeded stream_executor/platform/port.h in stream_executo…
Browse files Browse the repository at this point in the history
…r files.

PiperOrigin-RevId: 681728232
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 3, 2024
1 parent 0946107 commit 25a0df2
Show file tree
Hide file tree
Showing 23 changed files with 164 additions and 179 deletions.
126 changes: 63 additions & 63 deletions xla/stream_executor/blas.h

Large diffs are not rendered by default.

39 changes: 19 additions & 20 deletions xla/stream_executor/cuda/cuda_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ limitations under the License.
#include "xla/stream_executor/gpu/scoped_activate_context.h"
#include "xla/stream_executor/numeric_options.h"
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/platform/port.h"
#include "xla/stream_executor/plugin_registry.h"
#include "xla/stream_executor/scratch_allocator.h"
#include "xla/stream_executor/stream_executor.h"
Expand Down Expand Up @@ -516,7 +515,7 @@ bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,

absl::Status CUDABlas::DoBlasGemm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64 n, uint64_t k, blas::DataType dtype, const void *alpha,
uint64_t n, uint64_t k, blas::DataType dtype, const void *alpha,
const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb,
const void *beta, DeviceMemoryBase *c, int ldc,
const NumericOptions &numeric_options, blas::CallContext context) {
Expand Down Expand Up @@ -709,7 +708,7 @@ static absl::Status PopulateProfileFromTimer(

absl::Status CUDABlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a,
blas::DataType type_a, int lda, const DeviceMemoryBase &b,
blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
blas::DataType type_c, int ldc, blas::ComputationType computation_type,
Expand Down Expand Up @@ -744,7 +743,7 @@ absl::Status CUDABlas::DoBlasGemmWithAlgorithm(

absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a,
blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
Expand Down Expand Up @@ -913,7 +912,7 @@ T inline CUDAComplexValue(T v) {
template <typename T, typename Scalar, typename FuncT>
absl::Status CUDABlas::DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha,
blas::Transpose transb, uint64_t m, uint64_t n, uint64_t k, Scalar alpha,
const DeviceMemorySlice<T> &a_ptrs_to_wrappers, int lda,
const DeviceMemorySlice<T> &b_ptrs_to_wrappers, int ldb, Scalar beta,
const DeviceMemorySlice<T> &c_ptrs_to_wrappers, int ldc, int batch_count,
Expand Down Expand Up @@ -1025,7 +1024,7 @@ absl::Status CUDABlas::DoBlasGemmBatchedInternal(

bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, float alpha, DeviceMemorySlice<Eigen::half> a_array,
uint64_t n, uint64_t k, float alpha, DeviceMemorySlice<Eigen::half> a_array,
int lda, DeviceMemorySlice<Eigen::half> b_array, int ldb, float beta,
DeviceMemorySlice<Eigen::half> c_array, int ldc, int batch_count,
const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
Expand All @@ -1044,7 +1043,7 @@ bool CUDABlas::DoBlasGemmBatched(

bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, float alpha,
uint64_t n, uint64_t k, float alpha,
DeviceMemorySlice<Eigen::bfloat16> a_array, int lda,
DeviceMemorySlice<Eigen::bfloat16> b_array, int ldb, float beta,
DeviceMemorySlice<Eigen::bfloat16> c_array, int ldc, int batch_count,
Expand All @@ -1064,7 +1063,7 @@ bool CUDABlas::DoBlasGemmBatched(

bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, float alpha, DeviceMemorySlice<float> a_array,
uint64_t n, uint64_t k, float alpha, DeviceMemorySlice<float> a_array,
int lda, DeviceMemorySlice<float> b_array, int ldb, float beta,
DeviceMemorySlice<float> c_array, int ldc, int batch_count,
const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
Expand All @@ -1081,7 +1080,7 @@ bool CUDABlas::DoBlasGemmBatched(

bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, double alpha, DeviceMemorySlice<double> a_array,
uint64_t n, uint64_t k, double alpha, DeviceMemorySlice<double> a_array,
int lda, DeviceMemorySlice<double> b_array, int ldb, double beta,
DeviceMemorySlice<double> c_array, int ldc, int batch_count,
const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
Expand All @@ -1099,7 +1098,7 @@ bool CUDABlas::DoBlasGemmBatched(

bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, std::complex<float> alpha,
uint64_t n, uint64_t k, std::complex<float> alpha,
DeviceMemorySlice<std::complex<float>> a_array, int lda,
DeviceMemorySlice<std::complex<float>> b_array, int ldb,
std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c_array,
Expand All @@ -1118,7 +1117,7 @@ bool CUDABlas::DoBlasGemmBatched(

bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, std::complex<double> alpha,
uint64_t n, uint64_t k, std::complex<double> alpha,
DeviceMemorySlice<std::complex<double>> a_array, int lda,
DeviceMemorySlice<std::complex<double>> b_array, int ldb,
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c_array,
Expand All @@ -1136,7 +1135,7 @@ bool CUDABlas::DoBlasGemmBatched(

absl::Status CUDABlas::DoBlasGemmStridedBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, blas::DataType dtype, const void *alpha,
uint64_t n, uint64_t k, blas::DataType dtype, const void *alpha,
const DeviceMemoryBase &a, int lda, int64_t stride_a,
const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,
Expand Down Expand Up @@ -1274,7 +1273,7 @@ absl::Status CUDABlas::DoBlasGemmStridedBatched(

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
blas::UpperLower uplo, blas::Transpose transa,
blas::Diagonal diag, uint64_t m, uint64 n,
blas::Diagonal diag, uint64_t m, uint64_t n,
float alpha, const DeviceMemory<float> &a, int lda,
DeviceMemory<float> *b, int ldb) {
return DoBlasInternal(cublasStrsm, stream, true /* = pointer_mode_host */,
Expand All @@ -1285,7 +1284,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
blas::UpperLower uplo, blas::Transpose transa,
blas::Diagonal diag, uint64_t m, uint64 n,
blas::Diagonal diag, uint64_t m, uint64_t n,
double alpha, const DeviceMemory<double> &a, int lda,
DeviceMemory<double> *b, int ldb) {
return DoBlasInternal(cublasDtrsm, stream, true /* = pointer_mode_host */,
Expand All @@ -1296,7 +1295,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
blas::UpperLower uplo, blas::Transpose transa,
blas::Diagonal diag, uint64_t m, uint64 n,
blas::Diagonal diag, uint64_t m, uint64_t n,
std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
DeviceMemory<std::complex<float>> *b, int ldb) {
Expand All @@ -1310,7 +1309,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
blas::UpperLower uplo, blas::Transpose transa,
blas::Diagonal diag, uint64_t m, uint64 n,
blas::Diagonal diag, uint64_t m, uint64_t n,
std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
DeviceMemory<std::complex<double>> *b, int ldb) {
Expand All @@ -1324,7 +1323,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,

bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
blas::UpperLower uplo, blas::Transpose transa,
blas::Diagonal diag, uint64_t m, uint64 n,
blas::Diagonal diag, uint64_t m, uint64_t n,
float alpha, const DeviceMemory<float *> &as,
int lda, DeviceMemory<float *> *bs, int ldb,
int batch_count) {
Expand All @@ -1337,7 +1336,7 @@ bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,

bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
blas::UpperLower uplo, blas::Transpose transa,
blas::Diagonal diag, uint64_t m, uint64 n,
blas::Diagonal diag, uint64_t m, uint64_t n,
double alpha, const DeviceMemory<double *> &as,
int lda, DeviceMemory<double *> *bs, int ldb,
int batch_count) {
Expand All @@ -1350,7 +1349,7 @@ bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,

bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
blas::UpperLower uplo, blas::Transpose transa,
blas::Diagonal diag, uint64_t m, uint64 n,
blas::Diagonal diag, uint64_t m, uint64_t n,
std::complex<float> alpha,
const DeviceMemory<std::complex<float> *> &as,
int lda,
Expand All @@ -1367,7 +1366,7 @@ bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,

bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
blas::UpperLower uplo, blas::Transpose transa,
blas::Diagonal diag, uint64_t m, uint64 n,
blas::Diagonal diag, uint64_t m, uint64_t n,
std::complex<double> alpha,
const DeviceMemory<std::complex<double> *> &as,
int lda,
Expand Down
3 changes: 1 addition & 2 deletions xla/stream_executor/cuda/cuda_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ limitations under the License.
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/cuda/cuda_blas_lt.h"
#include "xla/stream_executor/numeric_options.h"
#include "xla/stream_executor/platform/port.h"

namespace stream_executor {

Expand Down Expand Up @@ -103,7 +102,7 @@ class CUDABlas : public blas::BlasSupport {
template <typename T, typename Scalar, typename FuncT>
absl::Status DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha,
blas::Transpose transb, uint64_t m, uint64_t n, uint64_t k, Scalar alpha,
const DeviceMemorySlice<T> &a_array, int lda,
const DeviceMemorySlice<T> &b_array, int ldb, Scalar beta,
const DeviceMemorySlice<T> &c_array, int ldc, int batch_count,
Expand Down
11 changes: 5 additions & 6 deletions xla/stream_executor/cuda/cuda_fft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ limitations under the License.
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/scoped_activate_context.h"
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/platform/port.h"
#include "xla/stream_executor/plugin_registry.h"
#include "xla/stream_executor/scratch_allocator.h"
#include "xla/stream_executor/stream.h"
Expand Down Expand Up @@ -103,8 +102,8 @@ absl::StatusOr<std::array<int32_t, 3>> Downsize64bArray(

absl::Status CUDAFftPlan::Initialize(
GpuExecutor *parent, Stream *stream, int rank, uint64_t *elem_count,
uint64_t *input_embed, uint64 input_stride, uint64 input_distance,
uint64_t *output_embed, uint64 output_stride, uint64 output_distance,
uint64_t *input_embed, uint64_t input_stride, uint64_t input_distance,
uint64_t *output_embed, uint64_t output_stride, uint64_t output_distance,
fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
if (IsInitialized()) {
return absl::InternalError("cuFFT is already initialized.");
Expand Down Expand Up @@ -309,9 +308,9 @@ int CUDAFftPlan::GetFftDirection() const {
}

std::unique_ptr<fft::Plan> CUDAFft::CreateBatchedPlanWithScratchAllocator(
Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed,
uint64_t input_stride, uint64 input_distance, uint64 *output_embed,
uint64_t output_stride, uint64 output_distance, fft::Type type,
Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed,
uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed,
uint64_t output_stride, uint64_t output_distance, fft::Type type,
bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
absl::Status status = fft_plan_ptr->Initialize(
Expand Down
3 changes: 1 addition & 2 deletions xla/stream_executor/cuda/cuda_fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cufft.h"
#include "xla/stream_executor/fft.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/platform/port.h"
#include "xla/stream_executor/scratch_allocator.h"
#include "xla/stream_executor/stream.h"

Expand Down Expand Up @@ -66,7 +65,7 @@ class CUDAFftPlan : public fft::Plan {
// Initialize function for batched plan
absl::Status Initialize(GpuExecutor* parent, Stream* stream, int rank,
uint64_t* elem_count, uint64_t* input_embed,
uint64_t input_stride, uint64 input_distance,
uint64_t input_stride, uint64_t input_distance,
uint64_t* output_embed, uint64_t output_stride,
uint64_t output_distance, fft::Type type,
int batch_count, ScratchAllocator* scratch_allocator);
Expand Down
14 changes: 6 additions & 8 deletions xla/stream_executor/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ limitations under the License.
#include <cstdint>
#include <memory>

#include "xla/stream_executor/platform/port.h"

namespace stream_executor {

class Stream;
Expand Down Expand Up @@ -109,9 +107,9 @@ class FftSupport {
// output_distance: Indicates the distance between the first element of two
// consecutive signals in a batch of the output data.
virtual std::unique_ptr<Plan> CreateBatchedPlanWithScratchAllocator(
Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed,
uint64_t input_stride, uint64 input_distance, uint64 *output_embed,
uint64_t output_stride, uint64 output_distance, Type type,
Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed,
uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed,
uint64_t output_stride, uint64_t output_distance, Type type,
bool in_place_fft, int batch_count,
ScratchAllocator *scratch_allocator) = 0;

Expand Down Expand Up @@ -162,9 +160,9 @@ class FftSupport {
// ::stream_executor namespace.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
std::unique_ptr<fft::Plan> CreateBatchedPlanWithScratchAllocator( \
Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, \
uint64_t input_stride, uint64 input_distance, uint64 *output_embed, \
uint64_t output_stride, uint64 output_distance, fft::Type type, \
Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed, \
uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed, \
uint64_t output_stride, uint64_t output_distance, fft::Type type, \
bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) \
override; \
void UpdatePlanWithScratchAllocator(Stream *stream, fft::Plan *plan, \
Expand Down
3 changes: 2 additions & 1 deletion xla/stream_executor/gpu/redzone_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class RedzoneAllocator : public ScratchAllocator {
return allocated_bytes_excluding_redzones_;
}

absl::StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override;
absl::StatusOr<DeviceMemory<uint8_t>> AllocateBytes(
int64_t byte_size) override;

// Non-empty redzone check status implies that there was a write into a
// redzone, with a string communicating the location of the write.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace stream_executor {
absl::StatusOr<const ComparisonKernel*> GetComparisonKernel(
StreamExecutor* executor, GpuAsmOpts /*gpu_asm_opts*/) {
static auto kernel = TypedKernelFactory<
DeviceMemory<uint8>, uint8, uint64_t,
DeviceMemory<uint8_t>, uint8_t, uint64_t,
DeviceMemory<uint64_t>>::Create(executor, "redzone_checker",
reinterpret_cast<void*>(
redzone_checker_kernel));
Expand Down
1 change: 0 additions & 1 deletion xla/stream_executor/rocm/hipblaslt_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ limitations under the License.
#else
#include "rocm/include/hipblaslt.h"
#endif
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"

Expand Down
1 change: 0 additions & 1 deletion xla/stream_executor/rocm/hipsolver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.

#include "rocm/include/hipsolver.h"
#endif
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"

Expand Down
1 change: 0 additions & 1 deletion xla/stream_executor/rocm/hipsparse_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ limitations under the License.
#include "rocm/include/hipsparse.h"
#endif
#include "xla/stream_executor/platform/platform.h"
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"

Expand Down
1 change: 0 additions & 1 deletion xla/stream_executor/rocm/rocblas_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ limitations under the License.

#include "rocm/include/rocblas/rocblas.h"
#include "rocm/rocm_config.h"
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"
#include "tsl/platform/platform.h"
Expand Down
Loading

0 comments on commit 25a0df2

Please sign in to comment.