@@ -47,6 +47,15 @@ ONNX_OPERATOR_KERNEL_EX(
4747 .TypeConstraint(" T" , WebGpuSupportedNumberTypes()),
4848 Transpose);
4949
// Emits the WGSL compute shader for OIHW2OHWIProgram (a tiled 4D transpose,
// permutation {0, 2, 3, 1}; see the applicability check in DoTranspose).
//
// @param shader  Helper used to declare shader I/O variables and to expand
//                the WGSL template into the final shader text.
// @return        Status from WGSL_TEMPLATE_APPLY (template expansion).
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Declare the program's single input ("src") and single output ("output").
  // These names must match the tensors registered via AddInput/AddOutput in
  // DoTranspose, and the order presumably mirrors that registration order —
  // NOTE(review): confirm against ShaderHelper's binding rules.
  // UseValueTypeAlias | UseElementTypeAlias makes the helper emit type
  // aliases so the template can refer to the tensors' value/element types
  // symbolically instead of hard-coding them.
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  // The actual kernel body lives in the checked-in WGSL template; bind the
  // two shader variables into it and return the expansion result.
  return WGSL_TEMPLATE_APPLY(shader, "tensor/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}
58+
5059auto SqueezeShape (const gsl::span<const int64_t >& shape,
5160 const gsl::span<const size_t >& adjusted_perm,
5261 TensorShapeVector& new_shape,
@@ -106,12 +115,52 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
106115 const auto & input_shape = input.Shape ();
107116 const auto & input_dims = input_shape.GetDims ();
108117 int32_t rank = static_cast <int32_t >(input_shape.NumDimensions ());
109-
110118 TensorShapeVector output_dims (rank);
111119
112120 for (int32_t i = 0 ; i < rank; i++) {
113121 output_dims[i] = input_dims[permutations[i]];
114122 }
123+ TensorShape output_shape (output_dims);
124+
125+ // Check if `OIHW2OHWIProgram` can be applied.
126+ //
127+ // `OIHW2OHWIProgram` was originally designed to transpose 4D weights from OIHW
128+ // to OHWI format, utilizing workgroup tiling to maximize bandwidth through
129+ // coalesced reads and writes. While variable names reflect this origin for
130+ // simplicity, the shader is now generalized for broader use, supporting any
131+ // permutation equivalent to {0, 2, 3, 1}.
132+ //
133+ // TODO: Extend support to 2D and 3D transpositions.
134+ if (permutations == gsl::span<const size_t >{{0 , 2 , 3 , 1 }}) {
135+ const uint32_t channel_output = onnxruntime::narrow<uint32_t >(input_shape[0 ]);
136+ const uint32_t channel_input = onnxruntime::narrow<uint32_t >(input_shape[1 ]);
137+ const uint32_t kernel_height = onnxruntime::narrow<uint32_t >(input_shape[2 ]);
138+ const uint32_t kernel_width = onnxruntime::narrow<uint32_t >(input_shape[3 ]);
139+
140+ // Calculate tiling for the input channel dimension (tiled by 64)
141+ const uint32_t input_channel_tiles = CeilDiv (channel_input, 64u );
142+ const uint32_t dispatch_size = channel_output * input_channel_tiles;
143+
144+ // Threshold check: Only apply if the workload is large enough to saturate
145+ // GPU compute units. For small tensors, the overhead of the transpose
146+ // outweighs the gain.
147+ if (dispatch_size >= 128u ) {
148+ OIHW2OHWIProgram transpose_program{};
149+ transpose_program.SetWorkgroupSize (64 );
150+ transpose_program.SetDispatchGroupSize (dispatch_size);
151+ transpose_program.AddInput ({&input,
152+ ProgramTensorMetadataDependency::TypeAndRank});
153+ transpose_program.AddOutput ({&output,
154+ ProgramTensorMetadataDependency::TypeAndRank});
155+ transpose_program.AddUniformVariables ({{channel_output},
156+ {channel_input},
157+ {kernel_height},
158+ {kernel_width},
159+ {input_channel_tiles},
160+ {CeilDiv (kernel_height * kernel_width, 4u )}});
161+ return context.RunProgram (transpose_program);
162+ }
163+ }
115164
116165 TensorShapeVector new_shape{};
117166 TensorShapeVector new_perm{};
@@ -120,15 +169,14 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
120169 const bool channels_first = new_perm == TensorShapeVector ({3 , 1 , 2 });
121170 const bool use_shared = (new_shape.size () == 2 && new_perm[0 ] > new_perm[1 ]) || channels_last || channels_first;
122171 auto new_input_shape = input_shape;
123- TensorShape new_output_shape (output_dims);
124172
125173 if (use_shared) {
126174 new_input_shape = channels_last
127175 ? TensorShape ({new_shape[0 ], new_shape[1 ] * new_shape[2 ]})
128176 : channels_first
129177 ? TensorShape ({new_shape[0 ] * new_shape[1 ], new_shape[2 ]})
130178 : new_shape;
131- new_output_shape = TensorShape ({new_input_shape[1 ], new_input_shape[0 ]});
179+ output_shape = TensorShape ({new_input_shape[1 ], new_input_shape[0 ]});
132180 }
133181
134182 uint32_t output_size = onnxruntime::narrow<uint32_t >(input_shape.Size ());
@@ -137,13 +185,13 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
137185 program
138186 .CacheHint (absl::StrJoin (permutations, " -" ))
139187 .AddInputs ({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1 }})
140- .AddOutputs ({{&output, ProgramTensorMetadataDependency::None, new_output_shape , 1 }})
188+ .AddOutputs ({{&output, ProgramTensorMetadataDependency::None, output_shape , 1 }})
141189 .AddUniformVariables ({{output_size}});
142190
143191 if (use_shared) {
144192 program.SetWorkgroupSize (TILE_SIZE, TILE_SIZE, 1 );
145- program.SetDispatchGroupSize (static_cast <uint32_t >((new_output_shape [1 ] + TILE_SIZE - 1 ) / TILE_SIZE),
146- static_cast <uint32_t >(((new_output_shape [0 ] + TILE_SIZE - 1 ) / TILE_SIZE)));
193+ program.SetDispatchGroupSize (static_cast <uint32_t >((output_shape [1 ] + TILE_SIZE - 1 ) / TILE_SIZE),
194+ static_cast <uint32_t >(((output_shape [0 ] + TILE_SIZE - 1 ) / TILE_SIZE)));
147195 } else {
148196 program.SetWorkgroupSize (64u );
149197
0 commit comments