File tree Expand file tree Collapse file tree 5 files changed +26
-25
lines changed
onnxruntime/core/providers/webgpu Expand file tree Collapse file tree 5 files changed +26
-25
lines changed Original file line number Diff line number Diff line change 55#include < vector>
66
77#include " core/providers/webgpu/webgpu_utils.h"
8+ #include " core/providers/webgpu/nn/im2col_matmul.h"
89#include " core/providers/webgpu/nn/conv.h"
910#include " core/providers/webgpu/nn/activation_util.h"
1011
@@ -52,15 +53,6 @@ bool IsDeviceSupported(const ComputeContextBase& context) {
5253
5354} // namespace
5455
55- Status OIHW2OHWIProgram::GenerateShaderCode (ShaderHelper& shader) const {
56- const auto & src = shader.AddInput (" src" , ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
57- const auto & output = shader.AddOutput (" output" , ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
58-
59- return WGSL_TEMPLATE_APPLY (shader, " nn/oihw_to_ohwi.wgsl.template" ,
60- WGSL_TEMPLATE_VARIABLE (output, output),
61- WGSL_TEMPLATE_VARIABLE (src, src));
62- }
63-
6456Status Im2ColMatMulProgram::GenerateShaderCode (ShaderHelper& shader) const {
6557 const auto & src = shader.AddInput (" src" , ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
6658 const auto & weight = shader.AddInput (" weight" , ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
Original file line number Diff line number Diff line change 1818namespace onnxruntime {
1919namespace webgpu {
2020
21- // Transpose OIHW Weight to OHWI
22- class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
23- public:
24- OIHW2OHWIProgram () : Program(" OIHW2OHWI" ) {}
25-
26- Status GenerateShaderCode (ShaderHelper& shader) const override ;
27-
28- WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES (
29- {" O" , ProgramUniformVariableDataType::Uint32},
30- {" I" , ProgramUniformVariableDataType::Uint32},
31- {" H" , ProgramUniformVariableDataType::Uint32},
32- {" W" , ProgramUniformVariableDataType::Uint32},
33- {" Ci_tiles" , ProgramUniformVariableDataType::Uint32},
34- {" H_W_tiles" , ProgramUniformVariableDataType::Uint32});
35- };
36-
3721class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
3822 public:
3923 Im2ColMatMulProgram (bool has_bias,
File renamed without changes.
Original file line number Diff line number Diff line change @@ -48,6 +48,15 @@ ONNX_OPERATOR_KERNEL_EX(
4848 .TypeConstraint(" T" , WebGpuSupportedNumberTypes()),
4949 Transpose);
5050
51+ Status OIHW2OHWIProgram::GenerateShaderCode (ShaderHelper& shader) const {
52+ const auto & src = shader.AddInput (" src" , ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
53+ const auto & output = shader.AddOutput (" output" , ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
54+
55+ return WGSL_TEMPLATE_APPLY (shader, " tensor/oihw_to_ohwi.wgsl.template" ,
56+ WGSL_TEMPLATE_VARIABLE (output, output),
57+ WGSL_TEMPLATE_VARIABLE (src, src));
58+ }
59+
5160auto SqueezeShape (const gsl::span<const int64_t >& shape,
5261 const gsl::span<const size_t >& adjusted_perm,
5362 TensorShapeVector& new_shape,
Original file line number Diff line number Diff line change 1111namespace onnxruntime {
1212namespace webgpu {
1313
14+ // Transpose OIHW Weight to OHWI
15+ class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
16+ public:
17+ OIHW2OHWIProgram () : Program(" OIHW2OHWI" ) {}
18+
19+ Status GenerateShaderCode (ShaderHelper& shader) const override ;
20+
21+ WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES (
22+ {" O" , ProgramUniformVariableDataType::Uint32},
23+ {" I" , ProgramUniformVariableDataType::Uint32},
24+ {" H" , ProgramUniformVariableDataType::Uint32},
25+ {" W" , ProgramUniformVariableDataType::Uint32},
26+ {" Ci_tiles" , ProgramUniformVariableDataType::Uint32},
27+ {" H_W_tiles" , ProgramUniformVariableDataType::Uint32});
28+ };
29+
1430class Transpose final : public WebGpuKernel, public TransposeBase {
1531 public:
1632 Transpose (const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
You can’t perform that action at this time.
0 commit comments