DeviceNormalizationBwdDataImpl< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDxFastestDimReduced, DXDstVectorSize > Struct Template Reference
Classes |
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDxFastestDimReduced, DXDstVectorSize > Struct Template Reference
#include <device_normalization_bwd_data_impl.hpp>
Inheritance diagram for ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDxFastestDimReduced, DXDstVectorSize >:
Classes | |
| struct | Argument |
| struct | Invoker |
Public Types | |
| using | GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1)) |
| using | GridwiseNormalizationBwdDataGeneric |
| using | GridwiseNormalizationBwdDataSweepOnce |
Public Member Functions | |
| template<index_t SrcVectorDim, index_t SrcVectorSize> | |
| bool | IsVectorDimSizeValid (const std::vector< index_t > &lengths, const std::vector< index_t > &strides) |
| bool | IsSupportedArgument (const BaseArgument *p_arg) override |
| std::unique_ptr< BaseArgument > | MakeArgumentPointer (const std::vector< index_t > lengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > dxStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_gamma, const void *p_mean, const void *p_invStd, void *p_dx) override |
| virtual std::unique_ptr< BaseInvoker > | MakeInvokerPointer () override |
| std::string | GetTypeString () const override |
| Public Member Functions inherited from ck::tensor_operation::device::BaseOperator | |
| BaseOperator ()=default | |
| BaseOperator (const BaseOperator &)=default | |
| BaseOperator & | operator= (const BaseOperator &)=default |
| virtual std::string | GetInstanceString () const |
| virtual std::string | GetTypeIdName () const |
| virtual std::optional< std::string > | GetObjectName () const |
| virtual std::optional< std::string > | GetTemplateInfo () const |
| virtual std::string | GetTypeIdHashCode () const |
| virtual size_t | GetWorkSpaceSize (const BaseArgument *) const |
| virtual void | SetWorkSpacePointer (BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const |
| virtual | ~BaseOperator () |
Static Public Member Functions | |
| static auto | Make2dDescriptor (const std::vector< index_t > &lengths, const std::vector< index_t > &strides, int numBlockTileIteration) |
Static Public Attributes | |
| static constexpr index_t | DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0 |
| static constexpr index_t | XSrcVectorDim = IsXFastestDimReduced ? 1 : 0 |
| static constexpr index_t | GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0 |
| static constexpr index_t | MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0 |
| static constexpr index_t | DXDstVectorDim = IsDxFastestDimReduced ? 1 : 0 |
| static constexpr index_t | NumInvariantDim = Rank - NumReduceDim |
| static constexpr index_t | M_BlockTileSize = MThreadClusterSize * MThreadSliceSize |
| static constexpr index_t | K_BlockTileSize = KThreadClusterSize * KThreadSliceSize |
| static constexpr bool | reduceAllDim = (NumInvariantDim == 0) |
Member Typedef Documentation
◆ GridDesc_M_K
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| using ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDxFastestDimReduced, DXDstVectorSize >::GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1)) |
◆ GridwiseNormalizationBwdDataGeneric
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| using ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDxFastestDimReduced, DXDstVectorSize >::GridwiseNormalizationBwdDataGeneric |
Initial value:
GridwiseNormalizationBwdData_mk_to_mk<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
DYSrcVectorDim,
DYSrcVectorSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
DXDstVectorDim,
DXDstVectorSize,
false>
Symbols referenced in the initial value above:
GridwiseNormalizationBwdData_mk_to_mk
Definition gridwise_normalization_bwd_data.hpp:49
static constexpr index_t MeanInvStdSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:92
static constexpr index_t XSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:90
static constexpr index_t DXDstVectorDim
Definition device_normalization_bwd_data_impl.hpp:93
static constexpr index_t DYSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:89
decltype(Make2dDescriptor({1}, {1}, 1)) GridDesc_M_K
Definition device_normalization_bwd_data_impl.hpp:169
static constexpr index_t GammaSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:91
◆ GridwiseNormalizationBwdDataSweepOnce
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| using ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDxFastestDimReduced, DXDstVectorSize >::GridwiseNormalizationBwdDataSweepOnce |
Initial value:
GridwiseNormalizationBwdData_mk_to_mk<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
DYSrcVectorDim,
DYSrcVectorSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
DXDstVectorDim,
DXDstVectorSize,
true>
Member Function Documentation
◆ GetTypeString()
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| std::string GetTypeString () const override |
inline override virtual |
Reimplemented from ck::tensor_operation::device::BaseOperator.
◆ IsSupportedArgument()
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| bool IsSupportedArgument (const BaseArgument *p_arg) override |
inline override virtual |
Reimplemented from ck::tensor_operation::device::BaseOperator.
◆ IsVectorDimSizeValid()
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
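| template<index_t SrcVectorDim, index_t SrcVectorSize> bool IsVectorDimSizeValid (const std::vector< index_t > &lengths, const std::vector< index_t > &strides) |
inline |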
◆ Make2dDescriptor()
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
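| static auto Make2dDescriptor (const std::vector< index_t > &lengths, const std::vector< index_t > &strides, int numBlockTileIteration) |
inline static |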
◆ MakeArgumentPointer()
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
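| std::unique_ptr< BaseArgument > MakeArgumentPointer (const std::vector< index_t > lengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > dxStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_gamma, const void *p_mean, const void *p_invStd, void *p_dx) override |
inline override virtual |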
◆ MakeInvokerPointer()
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer () override |
inline override virtual |
Member Data Documentation
◆ DXDstVectorDim
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 1 : 0 |
static constexpr |
◆ DYSrcVectorDim
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0 |
static constexpr |
◆ GammaSrcVectorDim
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0 |
static constexpr |
◆ K_BlockTileSize
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize |
static constexpr |
◆ M_BlockTileSize
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize |
static constexpr |
◆ MeanInvStdSrcVectorDim
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0 |
static constexpr |
◆ NumInvariantDim
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr index_t NumInvariantDim = Rank - NumReduceDim |
static constexpr |
◆ reduceAllDim
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr bool reduceAllDim = (NumInvariantDim == 0) |
static constexpr |
◆ XSrcVectorDim
template<typename DYDataType, typename XDataType, typename GammaDataType, typename MeanInvStdDataType, typename ComputeDataType, typename DXDataType, index_t Rank, index_t NumReduceDim, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, bool IsDYFastestDimReduced, index_t DYSrcVectorSize, bool IsXFastestDimReduced, index_t XSrcVectorSize, bool IsGammaFastestDimReduced, index_t GammaSrcVectorSize, bool IsMeanInvStdFastestDimReduced, index_t MeanInvStdSrcVectorSize, bool IsDxFastestDimReduced, index_t DXDstVectorSize>
| static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0 |
static constexpr |
The documentation for this struct was generated from the following file: