Skip to content

Commit be0ea7b

Browse files
committed
Optimize 4D Transpose
1 parent 2573346 commit be0ea7b

File tree

5 files changed

+75
-53
lines changed

5 files changed

+75
-53
lines changed

onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "core/providers/webgpu/webgpu_utils.h"
88
#include "core/providers/webgpu/nn/im2col_matmul.h"
9+
#include "core/providers/webgpu/nn/conv.h"
910
#include "core/providers/webgpu/nn/activation_util.h"
1011

1112
namespace onnxruntime {
@@ -52,15 +53,6 @@ bool IsDeviceSupported(const ComputeContextBase& context) {
5253

5354
} // namespace
5455

55-
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
56-
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
57-
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
58-
59-
return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
60-
WGSL_TEMPLATE_VARIABLE(output, output),
61-
WGSL_TEMPLATE_VARIABLE(src, src));
62-
}
63-
6456
Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
6557
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
6658
const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
@@ -93,34 +85,16 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context,
9385
const bool has_bias = context.InputCount() > 2;
9486
const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;
9587

96-
// Transpose OIHW Weight to OHWI
97-
// TODO: Move to `Transpose`
98-
// TODO: Use prepack
9988
TensorShape weight_shape = weight->Shape();
10089
const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]);
10190
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
10291
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
10392
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);
10493

105-
TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input};
106-
Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape);
107-
OIHW2OHWIProgram transpose_program{};
108-
transpose_program.SetWorkgroupSize(64);
109-
110-
const uint32_t Ci_tiles = CeilDiv(channel_input, 64u);
111-
transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);
112-
113-
transpose_program.AddInput({weight,
114-
ProgramTensorMetadataDependency::TypeAndRank});
115-
transpose_program.AddOutput({&ohwi_weight,
116-
ProgramTensorMetadataDependency::TypeAndRank});
117-
transpose_program.AddUniformVariables({{channel_output},
118-
{channel_input},
119-
{kernel_height},
120-
{kernel_width},
121-
{Ci_tiles},
122-
{CeilDiv(kernel_height * kernel_height, 4u)}});
123-
ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));
94+
// Transpose OIHW Weight to OHWI
95+
// TODO: Use prepack
96+
Tensor ohwi_weight;
97+
ORT_RETURN_IF_ERROR(TransposeKernel(context, weight, weight->Shape(), &ohwi_weight, {0, 2, 3, 1}));
12498

12599
// im2col-matmul
126100
const TensorShape src_shape = src->Shape();

onnxruntime/core/providers/webgpu/nn/im2col_matmul.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,6 @@
1818
namespace onnxruntime {
1919
namespace webgpu {
2020

21-
// Transpose OIHW Weight to OHWI
22-
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
23-
public:
24-
OIHW2OHWIProgram() : Program("OIHW2OHWI") {}
25-
26-
Status GenerateShaderCode(ShaderHelper& shader) const override;
27-
28-
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
29-
{"O", ProgramUniformVariableDataType::Uint32},
30-
{"I", ProgramUniformVariableDataType::Uint32},
31-
{"H", ProgramUniformVariableDataType::Uint32},
32-
{"W", ProgramUniformVariableDataType::Uint32},
33-
{"Ci_tiles", ProgramUniformVariableDataType::Uint32},
34-
{"H_W_tiles", ProgramUniformVariableDataType::Uint32});
35-
};
36-
3721
class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
3822
public:
3923
Im2ColMatMulProgram(bool has_bias,

onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template renamed to onnxruntime/core/providers/webgpu/tensor/oihw_to_ohwi.wgsl.template

File renamed without changes.

onnxruntime/core/providers/webgpu/tensor/transpose.cc

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ ONNX_OPERATOR_KERNEL_EX(
4747
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
4848
Transpose);
4949

50+
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
51+
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
52+
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
53+
54+
return WGSL_TEMPLATE_APPLY(shader, "tensor/oihw_to_ohwi.wgsl.template",
55+
WGSL_TEMPLATE_VARIABLE(output, output),
56+
WGSL_TEMPLATE_VARIABLE(src, src));
57+
}
58+
5059
auto SqueezeShape(const gsl::span<const int64_t>& shape,
5160
const gsl::span<const size_t>& adjusted_perm,
5261
TensorShapeVector& new_shape,
@@ -106,12 +115,52 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
106115
const auto& input_shape = input.Shape();
107116
const auto& input_dims = input_shape.GetDims();
108117
int32_t rank = static_cast<int32_t>(input_shape.NumDimensions());
109-
110118
TensorShapeVector output_dims(rank);
111119

112120
for (int32_t i = 0; i < rank; i++) {
113121
output_dims[i] = input_dims[permutations[i]];
114122
}
123+
TensorShape output_shape(output_dims);
124+
125+
// Check if `OIHW2OHWIProgram` can be applied.
126+
//
127+
// `OIHW2OHWIProgram` was originally designed to transpose 4D weights from OIHW
128+
// to OHWI format, utilizing workgroup tiling to maximize bandwidth through
129+
// coalesced reads and writes. While variable names reflect this origin for
130+
// simplicity, the shader is now generalized for broader use, supporting any
131+
// permutation equivalent to {0, 2, 3, 1}.
132+
//
133+
// TODO: Extend support to 2D and 3D transpositions.
134+
if (permutations == gsl::span<const size_t>{{0, 2, 3, 1}}) {
135+
const uint32_t channel_output = onnxruntime::narrow<uint32_t>(input_shape[0]);
136+
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(input_shape[1]);
137+
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(input_shape[2]);
138+
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(input_shape[3]);
139+
140+
// Calculate tiling for the input channel dimension (tiled by 64)
141+
const uint32_t input_channel_tiles = CeilDiv(channel_input, 64u);
142+
const uint32_t dispatch_size = channel_output * input_channel_tiles;
143+
144+
// Threshold check: Only apply if the workload is large enough to saturate
145+
// GPU compute units. For small tensors, the overhead of the transpose
146+
// outweighs the gain.
147+
if (dispatch_size >= 128u) {
148+
OIHW2OHWIProgram transpose_program{};
149+
transpose_program.SetWorkgroupSize(64);
150+
transpose_program.SetDispatchGroupSize(dispatch_size);
151+
transpose_program.AddInput({&input,
152+
ProgramTensorMetadataDependency::TypeAndRank});
153+
transpose_program.AddOutput({&output,
154+
ProgramTensorMetadataDependency::TypeAndRank});
155+
transpose_program.AddUniformVariables({{channel_output},
156+
{channel_input},
157+
{kernel_height},
158+
{kernel_width},
159+
{input_channel_tiles},
160+
{CeilDiv(kernel_height * kernel_width, 4u)}});
161+
return context.RunProgram(transpose_program);
162+
}
163+
}
115164

116165
TensorShapeVector new_shape{};
117166
TensorShapeVector new_perm{};
@@ -120,15 +169,14 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
120169
const bool channels_first = new_perm == TensorShapeVector({3, 1, 2});
121170
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
122171
auto new_input_shape = input_shape;
123-
TensorShape new_output_shape(output_dims);
124172

125173
if (use_shared) {
126174
new_input_shape = channels_last
127175
? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
128176
: channels_first
129177
? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
130178
: new_shape;
131-
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
179+
output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
132180
}
133181

134182
uint32_t output_size = onnxruntime::narrow<uint32_t>(input_shape.Size());
@@ -137,13 +185,13 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
137185
program
138186
.CacheHint(absl::StrJoin(permutations, "-"))
139187
.AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
140-
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
188+
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, output_shape, 1}})
141189
.AddUniformVariables({{output_size}});
142190

143191
if (use_shared) {
144192
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
145-
program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
146-
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
193+
program.SetDispatchGroupSize(static_cast<uint32_t>((output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
194+
static_cast<uint32_t>(((output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
147195
} else {
148196
program.SetWorkgroupSize(64u);
149197

onnxruntime/core/providers/webgpu/tensor/transpose.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@
1111
namespace onnxruntime {
1212
namespace webgpu {
1313

14+
// Transpose OIHW Weight to OHWI
15+
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
16+
public:
17+
OIHW2OHWIProgram() : Program("OIHW2OHWI") {}
18+
19+
Status GenerateShaderCode(ShaderHelper& shader) const override;
20+
21+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
22+
{"O", ProgramUniformVariableDataType::Uint32},
23+
{"I", ProgramUniformVariableDataType::Uint32},
24+
{"H", ProgramUniformVariableDataType::Uint32},
25+
{"W", ProgramUniformVariableDataType::Uint32},
26+
{"Ci_tiles", ProgramUniformVariableDataType::Uint32},
27+
{"H_W_tiles", ProgramUniformVariableDataType::Uint32});
28+
};
29+
1430
class Transpose final : public WebGpuKernel, public TransposeBase {
1531
public:
1632
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {

0 commit comments

Comments
 (0)