hipBLASLtExt API reference#
hipBLASLt contains extension APIs with the namespace hipblaslt_ext. They are only C++ compatible. The extensions support the following:
hipBLASLtExt datatypes reference#
GemmType#
GemmProblemType#
-
class GemmProblemType#
hipblasLt extension ProblemType for gemm problems.
This structure sets the problem type of a gemm problem.
Public Functions
-
void setOpA(hipblasOperation_t op)#
Set the A matrix transpose.
-
void setOpB(hipblasOperation_t op)#
Set the B matrix transpose.
-
void setTypeA(hipDataType type)#
Set the A matrix datatype.
-
void setTypeB(hipDataType type)#
Set the B matrix datatype.
-
void setTypeC(hipDataType type)#
Set the C matrix datatype.
-
void setTypeD(hipDataType type)#
Set the D matrix datatype.
-
void setTypeCompute(hipblasComputeType_t type)#
Set the compute datatype.
-
void setOrderA(hipblasLtOrder_t order)#
Set the A matrix data order.
-
void setOrderB(hipblasLtOrder_t order)#
Set the B matrix data order.
-
hipblasOperation_t getOpA() const#
The A matrix transpose.
-
hipblasOperation_t getOpB() const#
The B matrix transpose.
-
hipDataType getTypeA() const#
The A matrix datatype.
-
hipDataType getTypeB() const#
The B matrix datatype.
-
hipDataType getTypeC() const#
The C matrix datatype.
-
hipDataType getTypeD() const#
The D matrix datatype.
-
hipblasComputeType_t getTypeCompute() const#
The compute datatype.
-
hipblasLtOrder_t getOrderA() const#
The A matrix data order.
-
hipblasLtOrder_t getOrderB() const#
The B matrix data order.
-
void setOpA(hipblasOperation_t op)#
GemmEpilogue#
-
class GemmEpilogue#
hipblasLt extension Epilogue for gemm problems.
This class sets the epilogue of a gemm problem.
Public Functions
-
void setMode(hipblasLtEpilogue_t mode)#
Set the mode of epilogue. Default is gemm.
-
void setBiasDataType(hipDataType biasDataType)#
Set the bias datatype. Only works if mode is set to bias related epilogues.
-
void setAuxDataType(hipDataType auxDataType)#
Set the aux datatype. Only works if mode is set to aux related epilogues.
-
void setAuxLeadingDimension(int auxLeadingDimension)#
Set the aux leading dimension. Only works if mode is set to aux related epilogues.
-
void setAuxBatchStride(int auxBatchStride)#
Set the aux batch stride. Only works if mode is set to aux related epilogues.
-
void setScalingAType(hipblasLtMatmulMatrixScale_t scalingAType)#
Only works if DataTypeA = DataTypeB = FP8.
-
void setScalingBType(hipblasLtMatmulMatrixScale_t scalingBType)#
Only works if DataTypeA = DataTypeB = FP8.
-
void setAct0(float act0)#
Set first extra argument for activation function.
-
void setAct1(float act1)#
Set second extra argument for activation function.
-
hipblasLtEpilogue_t getMode() const#
The mode of epilogue. Default is gemm.
-
hipDataType getBiasDataType() const#
The bias datatype. Only works if mode is set to bias related epilogues.
-
hipDataType getAuxDataType() const#
The aux datatype. Only works if mode is set to aux related epilogues.
-
int getAuxLeadingDimension() const#
The aux leading dimension. Only works if mode is set to aux related epilogues.
-
int getAuxBatchStride() const#
The aux batch stride. Only works if mode is set to aux related epilogues.
-
hipblasLtMatmulMatrixScale_t getScalingAType() const#
0 is scalar, 1 is vector. Only works if DataTypeA = DataTypeB = FP8.
-
hipblasLtMatmulMatrixScale_t getScalingBType() const#
0 is scalar, 1 is vector. Only works if DataTypeA = DataTypeB = FP8.
-
float getAct0()#
The first extra argument for the activation function.
-
float getAct1()#
The second extra argument for the activation function.
-
void setMode(hipblasLtEpilogue_t mode)#
GemmInputs#
-
class GemmInputs#
hipblasLt extension Inputs for gemm problems.
This class sets the input pointers of a gemm problem.
Public Functions
-
void setA(const void *a)#
Set the a matrix input pointer.
-
void setB(const void *b)#
Set the b matrix input pointer.
-
void setC(const void *c)#
Set the c matrix input pointer.
-
void setD(const void *d)#
Set the d matrix input pointer.
-
void setAlpha(const void *alpha)#
Set the alpha value.
-
void setBeta(const void *beta)#
Set the beta value.
-
void setBias(const void *bias)#
Set the bias input pointer.
-
void setScaleA(const void *scaleA)#
Set the Scale A input pointer.
-
void setScaleB(const void *scaleB)#
Set the Scale B input pointer.
-
void setScaleC(const void *scaleC)#
Set the Scale C input pointer.
-
void setScaleD(const void *scaleD)#
Set the Scale D input pointer.
-
void setScaleAux(const void *scaleAux)#
Set the Scale AUX input pointer.
-
void setScaleAlphaVec(const void *scaleAlphaVec)#
Set the scaleAlpha vector input pointer.
-
void setAux(const void *aux)#
Set the aux input pointer.
-
void setAmaxD(const void *amaxD)#
Set the AmaxD input pointer.
-
const void *getA() const#
The a matrix input pointer.
-
const void *getB() const#
The b matrix input pointer.
-
const void *getC() const#
The c matrix input pointer.
-
const void *getD() const#
The d matrix input pointer.
-
const void *getAlpha() const#
The alpha value.
-
const void *getBeta() const#
The beta value.
-
const void *getBias() const#
The bias input pointer.
-
const void *getScaleA() const#
The Scale A input pointer.
-
const void *getScaleB() const#
The Scale B input pointer.
-
const void *getScaleC() const#
The Scale C input pointer.
-
const void *getScaleD() const#
The Scale D input pointer.
-
const void *getScaleAux() const#
The Scale AUX input pointer.
-
const void *getScaleAlphaVec() const#
The scaleAlpha vector input pointer.
-
const void *getAux() const#
The aux input pointer.
-
const void *getAmaxD() const#
The AmaxD input pointer.
-
void setA(const void *a)#
hipBLASLtExt GEMM class reference#
GemmPreference#
-
class GemmPreference#
hipblasLt extension preference for gemm problems.
Currently only supports setting max workspace size.
GemmInstance#
-
class GemmInstance#
hipblasLt extension instance for gemm problems.
Subclassed by hipblaslt_ext::Gemm, hipblaslt_ext::GroupedGemm
Gemm#
-
class Gemm : public hipblaslt_ext::GemmInstance#
hipblasLt extension instance for gemm.
The instance can be used to create arguments to compute the matrix multiplication of matrices A and B to produce the output matrix D, according to the following operation:
D = alpha*(A*B) + beta*(C), where A, B, and C are input matrices, and alpha and beta are input scalars.
GroupedGemm#
-
class GroupedGemm : public hipblaslt_ext::GemmInstance#
hipblasLt extension instance for grouped gemm.
The instance can be used to create arguments to compute the matrix multiplication of matrices A and B to produce the output matrix D, according to the following operation:
D = alpha*(A*B) + beta*(C), where A, B, and C are input matrices, and alpha and beta are input scalars.
Public Functions
-
hipblasStatus_t run(hipStream_t stream, hipEvent_t start = nullptr, hipEvent_t stop = nullptr)#
Execute the kernel arguments stored inside the hipblaslt_ext::GemmInstance.
- Parameters:
stream – [in] The HIP stream where all the GPU work will be submitted.
start – [in] The HIP event which will record the start of the kernel.
stop – [in] The HIP event which will record the end of the kernel.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
-
hipblasStatus_t run(hipStream_t stream, hipEvent_t start = nullptr, hipEvent_t stop = nullptr)#
hipBLASLtExt API reference#
getAllAlgos()#
Warning
doxygenfunction: Cannot find function “getAllAlgos” in doxygen xml output for project “hipBLASLt 1.1.0 Documentation” from directory: docs/doxygen/xml
getIndexFromAlgo()#
Warning
doxygenfunction: Cannot find function “getIndexFromAlgo” in doxygen xml output for project “hipBLASLt 1.1.0 Documentation” from directory: docs/doxygen/xml
getAlgosFromIndex()#
Warning
doxygenfunction: Cannot find function “getAlgosFromIndex” in doxygen xml output for project “hipBLASLt 1.1.0 Documentation” from directory: docs/doxygen/xml
matmulIsAlgoSupported()#
Warning
doxygenfunction: Cannot find function “matmulIsAlgoSupported” in doxygen xml output for project “hipBLASLt 1.1.0 Documentation” from directory: docs/doxygen/xml
hipblasLtExt usage#
Here are the three use cases supported by the hipBLASLtExt APIs.
GEMM#
hipblasLt has its own instance. You must assign the problem type when constructing or importing the problem from the hipBLAS API.
HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t handle,
hipblasOperation_t opA,
hipblasOperation_t opB,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
hipblasComputeType_t typeCompute);
HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t handle,
hipblasLtMatmulDesc_t matmul_descr,
const void* alpha,
const void* A,
hipblasLtMatrixLayout_t matA,
const void* B,
hipblasLtMatrixLayout_t matB,
const void* beta,
const void* C,
hipblasLtMatrixLayout_t matC,
void* D,
hipblasLtMatrixLayout_t matD);
After the instance is created, you can set the problem using the API. The API might require the following structures:
struct GemmEpilogue { hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT; hipDataType bias_data_type; int aux_ld; int aux_stride; };
setProblem APIs: HIPBLASLT_EXPORT hipblasStatus_t setProblem( int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue& epilogue, GemmInputs& inputs);
You can set the leading dimensions and strides and reassign the data type using the following API:
HIPBLASLT_EXPORT hipblasStatus_t setProblem(int64_t m,
int64_t n,
int64_t k,
int64_t batch_count,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t ldd,
int64_t strideA,
int64_t strideB,
int64_t strideC,
int64_t strideD,
GemmEpilogue& epilogue,
GemmInputs& inputs,
GemmProblemType& problemtype);
You can import problems from the hipblasLt APIs after the instance is created.
Note
This can overwrite the problem type of the instance.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(hipblasLtMatmulDesc_t matmul_descr,
const void* alpha,
const void* A,
hipblasLtMatrixLayout_t matA,
const void* B,
hipblasLtMatrixLayout_t matB,
const void* beta,
const void* C,
hipblasLtMatrixLayout_t matC,
void* D,
hipblasLtMatrixLayout_t matD);
You can retrieve heuristics and set kernel arguments with the instance. If the properties of the GEMM and the inputs don’t change, you can call the run API to launch the kernel directly.
// Pseudo code
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Default epilogue mode is HIPBLASLT_EPILOGUE_DEFAULT
hipblaslt_ext::GemmEpilogue epilogue;
hipblaslt_ext::GemmInputs inputs;
inputs.setA(d_a);
inputs.setB(d_b);
inputs.setC(d_c);
inputs.setD(d_d);
inputs.setAlpha(&alpha);
inputs.setBeta(&beta);
hipblaslt_ext::Gemm gemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic;
gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch
gemm.algoGetHeuristic(gemm, pref, heuristic);
gemm.initialize(heuristic[0].algo, d_workspace, stream);
for(int i = 0; i < 10; i++)
{
gemm.run(stream);
}
Grouped GEMM#
hipblasLtExt supports grouped GEMM. It shares the same class with normal GEMM.
After the problem is set, you can check the problem type using the function getGemmType().
enum class GemmType
{
HIPBLASLT_GEMM = 1,
HIPBLASLT_GROUPED_GEMM = 2
};
The grouped GEMM class also includes the setProblem APIs.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(
int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue& epilogue, GemmInputs& inputs);
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs);
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<int64_t>& lda,
std::vector<int64_t>& ldb,
std::vector<int64_t>& ldc,
std::vector<int64_t>& ldd,
std::vector<int64_t>& strideA,
std::vector<int64_t>& strideB,
std::vector<int64_t>& strideC,
std::vector<int64_t>& strideD,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs,
GemmProblemType& problemtype);
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<hipblasLtMatmulDesc_t>& matmul_descr,
std::vector<void*>& alpha,
std::vector<void*>& A,
std::vector<hipblasLtMatrixLayout_t>& matA,
std::vector<void*>& B,
std::vector<hipblasLtMatrixLayout_t>& matB,
std::vector<void*>& beta,
std::vector<void*>& C,
std::vector<hipblasLtMatrixLayout_t>& matC,
std::vector<void*>& D,
std::vector<hipblasLtMatrixLayout_t>& matD);
For the following API, the epilogue argument supports broadcasting to the length of the problem size
by duplicating the last element.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<int64_t>& lda,
std::vector<int64_t>& ldb,
std::vector<int64_t>& ldc,
std::vector<int64_t>& ldd,
std::vector<int64_t>& strideA,
std::vector<int64_t>& strideB,
std::vector<int64_t>& strideC,
std::vector<int64_t>& strideD,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs,
GemmProblemType& problemtype);
Note
Only a problemtype size equal to 1 is currently supported. (This means only one GemmProblemType for all problems.)
// Pseudo code
std::vector<int64_t> m, n, k;
// ...
for(size_t i = 0; i < problem_size; i++)
{
// ...
}
std::vector<GemmProblemType> problemtypes;
problemtypes.push_back(problemtype);
groupedgemm.setProblem(m, n, k, batch_count, lda, ldb, ldc, ldd, strideA, strideB, strideC, strideD, epilogue, inputs, problemtypes);
The UserArguments structure#
Grouped GEMM supports the use of external device memory to run the kernel.
This is helpful if some of the arguments are from the output of the previous kernel.
To change the size-related arguments m, n, k, and batch, see Fixed MK.
struct UserArguments
{
uint32_t m; //!< size m
uint32_t n; //!< size n
uint32_t batch; //!< size batch
uint32_t k; //!< size k
void* d; //!< The d matrix input pointer.
void* c; //!< The c matrix input pointer.
void* a; //!< The a matrix input pointer.
void* b; //!< The b matrix input pointer.
uint32_t strideD1; //!< The d leading dimension.
uint32_t strideD2; //!< The d batch stride
uint32_t strideC1; //!< The c leading dimension.
uint32_t strideC2; //!< The c batch stride
uint32_t strideA1; //!< The a leading dimension.
uint32_t strideA2; //!< The a batch stride
uint32_t strideB1; //!< The b leading dimension.
uint32_t strideB2; //!< The b batch stride
int8_t alpha[16]; //!< The alpha value.
int8_t beta[16]; //!< The beta value.
// Epilogue inputs
void* bias; //!< The bias input pointer.
int biasType; //!< The bias datatype. Only works if mode is set to bias related epilogues.
uint32_t reserved;
void* e; //!< The aux input pointer. Only works if mode is set to aux related epilogues.
uint32_t strideE1; //!< The aux leading dimension. Only works if mode is set to aux related epilogues.
uint32_t strideE2; //!< The aux batch stride. Only works if mode is set to aux related epilogues.
float act0; //!< The activation value 1. Some activations might use it.
float act1; //!< The activation value 2.
int activationType; //!< The activation type. Only works if mode is set to activation related epilogues.
} __attribute__((packed));
hipBLASLt adds two functions to the UserArguments-related API. The first API is a helper function that helps you initialize
the UserArguments structure from the saved problems inside the grouped GEMM object.
The second API is an overload function with an additional UserArguments device pointer input.
HIPBLASLT_EXPORT hipblasStatus_t getDefaultValueForDeviceUserArguments(void* hostDeviceUserArgs);
HIPBLASLT_EXPORT hipblasStatus_t run(void* deviceUserArgs, hipStream_t stream);
Here is a simple example that shows how this API works.
// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
m[i] = 1;
n[i] = 1;
k[i] = 1;
batch_count[i] = 1;
epilogue[i].setMode(HIPBLASLT_EPILOGUE_GELU);
inputs[i].setA(d_a[i]);
inputs[i].setB(d_b[i]);
inputs[i].setC(d_c[i]);
inputs[i].setD(d_d[i]);
inputs[i].setAlpha(&alpha[i]);
inputs[i].setBeta(&beta[i]);
}
// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host
// Next copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
}
}
// Step 6: Initialize and run
if(validIdx.size() > 1)
{
groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, stream);
for(int i = 0; i < 10; i++)
{
groupedGemm.run(userArgs, stream);
}
}
The base class (GemmInstance)#
This is the base class for Gemm and GroupedGemm.
// Gets heuristic from the instance.
HIPBLASLT_EXPORT hipblasStatus_t algoGetHeuristic(const int requestedAlgoCount,
const GemmPreference& pref,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
// Returns SUCCESS if the algo is supported, also returns the required workspace size in bytes.
HIPBLASLT_EXPORT hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t& algo, size_t& workspaceSizeInBytes);
// Initializes the instance before calling run. Required every time the problem is set.
HIPBLASLT_EXPORT hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t& algo, void* workspace, bool useUserArgs = true, hipStream_t stream = 0);
// Run the problem.
HIPBLASLT_EXPORT hipblasStatus_t run(hipStream_t stream);
Get all algorithms#
Get all algorithms allows you to get all the algorithms for a specific problem type. It requires the transpose of A, B, the data type of the inputs, and the compute type.
HIPBLASLT_EXPORT
hipblasStatus_t hipblaslt_ext::getAllAlgos(hipblasLtHandle_t handle,
hipblasLtExtGemmTypeEnum_t typeGemm,
hipblasOperation_t opA,
hipblasOperation_t opB,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
hipblasComputeType_t typeCompute,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
This API doesn’t require a problem size or epilogue as input. It uses another API named isAlgoSupported to check
if the algorithm supports a problem.
hipblaslt_ext::matmulIsAlgoSupported()
gemm.isAlgoSupported()
The API returns the required workspace size in bytes upon successful completion.
// Get all algorithms
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(hipblaslt_ext::matmulIsAlgoSupported(handle,
matmul,
&(alpha),
matA,
matB,
&(beta),
matC,
matD,
heuristicResult[j].algo,
workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
heuristicResult[j].workspaceSize = workspace_size;
}
else
{
heuristicResult[j].workspaceSize = 0;
}
}
Algorithm index#
This extension API lets you get the algorithm index using hipblasLtMatmulAlgo_t.
HIPBLASLT_EXPORT int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo);
You can use an index vector to retrieve the heuristic results.
HIPBLASLT_EXPORT
hipblasStatus_t
getAlgosFromIndex(hipblasLtHandle_t handle,
std::vector<int>& algoIndex,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
Sample code#
This section contains some code samples that demonstrate the use cases of the extension APIs.
GEMM#
// Pseudo code for gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
hipblaslt_ext::GemmEpilogue epilogue;
epilogue.setMode(HIPBLASLT_EPILOGUE_GELU);
hipblaslt_ext::GemmInputs inputs;
inputs.setA(d_a);
inputs.setB(d_b);
inputs.setC(d_c);
inputs.setD(d_d);
inputs.setAlpha(&alpha);
inputs.setBeta(&beta);
hipblaslt_ext::Gemm gemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(gemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
heuristicResult[j].workspaceSize = workspace_size;
}
else
{
heuristicResult[j].workspaceSize = 0;
}
}
if(validIdx.size() > 1)
{
gemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, stream);
for(int i = 0; i < 10; i++)
{
gemm.run(stream);
}
}
Grouped GEMM#
// Pseudo code for grouped gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
m[i] = 1;
n[i] = 1;
k[i] = 1;
batch_count[i] = 1;
epilogue[i].setMode(HIPBLASLT_EPILOGUE_GELU);
inputs[i].setA(d_a[i]);
inputs[i].setB(d_b[i]);
inputs[i].setC(d_c[i]);
inputs[i].setD(d_d[i]);
inputs[i].setAlpha(&alpha[i]);
inputs[i].setBeta(&beta[i]);
}
hipblaslt_ext::GroupedGemm groupedGemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
}
}
if(validIdx.size() > 1)
{
groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, stream);
for(int i = 0; i < 10; i++)
{
groupedGemm.run(stream);
}
}
Algorithm index#
int index = hipblaslt_ext::getIndexFromAlgo(testResults[i].algo);
// Save the index to disk or somewhere else for later use.
// Get the index from previous state.
std::vector<int> algoIndex{index};
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResults;
// If the index is out of the bound of solutions, getAlgosFromIndex will return HIPBLAS_STATUS_INVALID_VALUE
if(HIPBLAS_STATUS_INVALID_VALUE
== hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResults))
{
std::cout << "Indexes are all out of bound." << std::endl;
break;
}
[Grouped Gemm] Fixed MK#
The hipBLASLt extension supports changing the sizes (m, n, k, and batch) from the device memory UserArguments.
However, the setup is a bit different from the normal routing.
Sum of N#
A sum of N needs to be used as an input for the grouped GEMM instance.
{1000, 1, 1, 1}; // The array of N, the first element is the sum of N
// Below is the values stored in "UserArguments"
{256, 256, 1, 1}; // This is a valid configuration because 256 + 256 + 1 + 1 < 1000
{512, 512, 1, 1}; // This is NOT a valid configuration because 512 + 512 + 1 + 1 > 1000
For example, consider a grouped GEMM with gemm_count = 4. The sum of N must not exceed the “sum of N” set in the setProblem API.
In this mode, the first element is the “sum of N” in the array of Ns.
// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
// Step 2.1: Calculate sum of n
int64_t sum_of_n = 0;
for(int i = 0; i < gemm_count; i++)
{
sum_of_n += n_arr[i];
}
// {sum_of_n, 1, 1, 1, ...}; // The array of N, the first element is the sum of N
for(int i = 0; i < gemm_count; i++)
{
m[i] = m_arr[i];
if(i == 0)
n[i] = sum_of_n;
else
n[i] = 1;
k[i] = k_arr[i];
batch_count[i] = 1;
inputs[i].setA(d_a[i]);
inputs[i].setB(d_b[i]);
inputs[i].setC(d_c[i]);
inputs[i].setD(d_d[i]);
inputs[i].setAlpha(&alpha[i]);
inputs[i].setBeta(&beta[i]);
}
// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host
// Next Copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
}
}
int threads = 256;
int blocks = ceil((double)gemm_count / threads);
// Step 6: Initialize and run
if(validIdx.size() > 1)
{
groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace);
for(int i = 0; i < 10; i++)
{
hipLaunchKernelGGL(kernelUpdateN,
dim3(blocks),
dim3(threads),
0,
stream,
gemm_count,
d_dUAFloat,
d_n_vec); // d_n_vec is a device pointer with Ns
groupedGemm.run(userArgs, stream);
}
}
// .....
__global__ void kernelUpdateN(uint32_t gemm_count, void* userArgs, int32_t* sizes_n)
{
uint64_t id = hipBlockIdx_x * 256 + hipThreadIdx_x;
if(id >= gemm_count)
return;
hipblaslt_ext::UserArguments* dUAFloat = static_cast<hipblaslt_ext::UserArguments*>(userArgs);
dUAFloat[id].n = sizes_n[id];
}