Skip to content

Commit 1b8e605

Browse files
committed
Update
1 parent 6a3823f commit 1b8e605

File tree

5 files changed

+26
-25
lines changed

5 files changed

+26
-25
lines changed

onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <vector>
66

77
#include "core/providers/webgpu/webgpu_utils.h"
8+
#include "core/providers/webgpu/nn/im2col_matmul.h"
89
#include "core/providers/webgpu/nn/conv.h"
910
#include "core/providers/webgpu/nn/activation_util.h"
1011

@@ -52,15 +53,6 @@ bool IsDeviceSupported(const ComputeContextBase& context) {
5253

5354
} // namespace
5455

55-
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
56-
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
57-
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
58-
59-
return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
60-
WGSL_TEMPLATE_VARIABLE(output, output),
61-
WGSL_TEMPLATE_VARIABLE(src, src));
62-
}
63-
6456
Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
6557
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
6658
const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

onnxruntime/core/providers/webgpu/nn/im2col_matmul.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,6 @@
1818
namespace onnxruntime {
1919
namespace webgpu {
2020

21-
// Transpose OIHW Weight to OHWI
22-
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
23-
public:
24-
OIHW2OHWIProgram() : Program("OIHW2OHWI") {}
25-
26-
Status GenerateShaderCode(ShaderHelper& shader) const override;
27-
28-
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
29-
{"O", ProgramUniformVariableDataType::Uint32},
30-
{"I", ProgramUniformVariableDataType::Uint32},
31-
{"H", ProgramUniformVariableDataType::Uint32},
32-
{"W", ProgramUniformVariableDataType::Uint32},
33-
{"Ci_tiles", ProgramUniformVariableDataType::Uint32},
34-
{"H_W_tiles", ProgramUniformVariableDataType::Uint32});
35-
};
36-
3721
class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
3822
public:
3923
Im2ColMatMulProgram(bool has_bias,

onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template renamed to onnxruntime/core/providers/webgpu/tensor/oihw_to_ohwi.wgsl.template

File renamed without changes.

onnxruntime/core/providers/webgpu/tensor/transpose.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ ONNX_OPERATOR_KERNEL_EX(
4848
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
4949
Transpose);
5050

51+
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
52+
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
53+
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
54+
55+
return WGSL_TEMPLATE_APPLY(shader, "tensor/oihw_to_ohwi.wgsl.template",
56+
WGSL_TEMPLATE_VARIABLE(output, output),
57+
WGSL_TEMPLATE_VARIABLE(src, src));
58+
}
59+
5160
auto SqueezeShape(const gsl::span<const int64_t>& shape,
5261
const gsl::span<const size_t>& adjusted_perm,
5362
TensorShapeVector& new_shape,

onnxruntime/core/providers/webgpu/tensor/transpose.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@
1111
namespace onnxruntime {
1212
namespace webgpu {
1313

14+
// Transpose OIHW Weight to OHWI
15+
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
16+
public:
17+
OIHW2OHWIProgram() : Program("OIHW2OHWI") {}
18+
19+
Status GenerateShaderCode(ShaderHelper& shader) const override;
20+
21+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
22+
{"O", ProgramUniformVariableDataType::Uint32},
23+
{"I", ProgramUniformVariableDataType::Uint32},
24+
{"H", ProgramUniformVariableDataType::Uint32},
25+
{"W", ProgramUniformVariableDataType::Uint32},
26+
{"Ci_tiles", ProgramUniformVariableDataType::Uint32},
27+
{"H_W_tiles", ProgramUniformVariableDataType::Uint32});
28+
};
29+
1430
class Transpose final : public WebGpuKernel, public TransposeBase {
1531
public:
1632
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {

0 commit comments

Comments
 (0)