185 changes: 185 additions & 0 deletions src/subgraph.c
@@ -15,6 +15,7 @@

#include "include/experimental.h"
#include "include/xnnpack.h"
#include "src/subgraph/subgraph-utils.h"
#include "src/xnnpack/allocation-type.h"
#include "src/xnnpack/allocator.h"
#include "src/xnnpack/common.h"
@@ -2336,6 +2337,19 @@ static float get_value_as_float(const void* data, enum xnn_datatype datatype) {
}
}

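// Create a new internal value that copies the shape, datatype, and
// quantization of `value_id`, but carries no static data and is allocated in
// the workspace. Used below to hold a pre-transposed copy of a dynamic
// right-hand operand.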
static uint32_t xnn_subgraph_new_workspace_value_like(
struct xnn_subgraph* subgraph, uint32_t value_id) {
struct xnn_value* new_value = xnn_subgraph_new_internal_value(subgraph);
const struct xnn_value* value = &subgraph->values[value_id];
const uint32_t new_value_id = new_value->id;
*new_value = *value;
new_value->id = new_value_id;
new_value->data = NULL;
new_value->allocation_type = xnn_allocation_type_workspace;
new_value->flags = 0;
return new_value_id;
}

enum xnn_status xnn_subgraph_optimize_common_subgraphs(
xnn_subgraph_t subgraph, uint32_t optimization_flags) {
// If we shouldn't change the numerics, then don't do anything.
@@ -2553,6 +2567,176 @@ enum xnn_status xnn_subgraph_optimize_common_subgraphs(
}
break;

case xnn_node_type_batch_matrix_multiply:
do {
// Convert `batch-matrix-multiply` nodes with 2d right-hand operands
// to `fully-connected`.
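// A 2-d right-hand operand is shared across the whole batch, so the
// product is equivalent to a fully-connected layer applied to every row
// of the left-hand operand, which lets it reuse the fully-connected
// kernels.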
const uint32_t input_b_id = node->inputs[1];
struct xnn_value* input_b_value = &subgraph->values[input_b_id];
if (input_b_value->shape.num_dims != 2 ||
(input_b_value->datatype != xnn_datatype_fp32 &&
input_b_value->datatype != xnn_datatype_fp16)) {
break;
}
const uint32_t input_a_id = node->inputs[0];
const uint32_t output_id = node->outputs[0];
enum xnn_status status = xnn_define_fully_connected(
subgraph,
/*output_min=*/-INFINITY, /*output_max=*/INFINITY, input_a_id,
input_b_id, /*bias_id=*/XNN_INVALID_VALUE_ID, output_id,
node->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS);
if (status != xnn_status_success) {
xnn_log_error("Failed to create new `fully_connected` node.");
return status;
}
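// The new `fully_connected` node was appended at the end of the node
// list; move it into this node's slot and drop the now-duplicate tail
// entry.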
node = &subgraph->nodes[node_id];
*node = subgraph->nodes[--subgraph->num_nodes];
node->id = node_id;

xnn_log_info(
"Converted batch_matrix_multiply[#%u](v%03u, v%03u) to "
"fully_connected[#%u](v%03u, v%03u).",
node_id, input_a_id, input_b_id, node_id, input_a_id, input_b_id);
changes++;
} while (false);

XNN_FALLTHROUGH

case xnn_node_type_fully_connected:
do {
// Avoid (slow) dynamic `gio` packing by pre-transposing the
// right-hand operand.
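// (`gio` here means the weights are stored with the output channels as
// the innermost dimension, i.e. the layout selected by
// XNN_FLAG_TRANSPOSE_WEIGHTS on a fully-connected node.)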
uint32_t input_b_id = node->inputs[1];
struct xnn_value* input_b_value = &subgraph->values[input_b_id];
// Are the weights dynamic?
if (xnn_value_is_static(input_b_value->allocation_type)) {
break;
}

// Is this a `gio` packing?
if ((node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) !=
(node->type == xnn_node_type_fully_connected
? XNN_FLAG_TRANSPOSE_WEIGHTS
: 0)) {
break;
}

// If the weights are produced by a `transpose` node, swap the last
// two permutation values.
const uint32_t input_a_id = node->inputs[0];
struct xnn_node* transpose_node = NULL;
if (input_b_value->producer != XNN_INVALID_NODE_ID &&
subgraph->nodes[input_b_value->producer].type ==
xnn_node_type_static_transpose) {
transpose_node = &subgraph->nodes[input_b_value->producer];

// Swap the two last dimensions of the transpose.
const uint32_t num_dims = transpose_node->params.transpose.num_dims;
size_t* perm = transpose_node->params.transpose.perm;
size_t temp = perm[num_dims - 1];
perm[num_dims - 1] = perm[num_dims - 2];
perm[num_dims - 2] = temp;

// Check if the transpose is now a no-op.
bool is_noop = true;
for (int k = 0; k < num_dims && is_noop; k++) {
is_noop &= (perm[k] == k);
}

// Skip the transpose if it is now a no-op.
if (is_noop) {
input_b_id = transpose_node->inputs[0];
xnn_log_info(
"Skipping static_transpose[#%u](v%03u) for second input to "
"%s[#%u](v%03u, v%03u).",
transpose_node->id, transpose_node->inputs[0],
node->type == xnn_node_type_fully_connected
? "fully_connected"
: "batch_matrix_multiply",
node_id, input_a_id, input_b_id);
} else {
xnn_log_info(
"Reusing static_transpose[#%u](v%03u) to transpose second "
"input to %s[#%u](v%03u, v%03u).",
transpose_node->id, transpose_node->inputs[0],
node->type == xnn_node_type_fully_connected
? "fully_connected"
: "batch_matrix_multiply",
node_id, input_a_id, input_b_id);
}
}

// Otherwise, add a `transpose` node.
else {
// Create a workspace value for the transposed input_b.
const uint32_t transposed_input_b_id =
xnn_subgraph_new_workspace_value_like(subgraph, input_b_id);

// Create the permutation of the last two dimensions.
const uint32_t num_dims = input_b_value->shape.num_dims;
size_t perm[XNN_MAX_TENSOR_DIMS];
for (int k = 0; k + 2 < num_dims; k++) {
perm[k] = k;
}
perm[num_dims - 2] = num_dims - 1;
perm[num_dims - 1] = num_dims - 2;

// Create the `static_transpose` node.
enum xnn_status status = xnn_define_static_transpose(
subgraph, num_dims, perm, input_b_id, transposed_input_b_id,
/*flags=*/0);
if (status != xnn_status_success) {
xnn_log_error("Failed to create new `static_transpose` node.");
return status;
}
node = &subgraph->nodes[node_id];
transpose_node = &subgraph->nodes[subgraph->num_nodes - 1];

// Set the new input_b value id.
input_b_id = transposed_input_b_id;
}

// Defining the replacement node below appends to (and may reallocate) the
// node list, so remember the transpose node's id before that can happen.
const uint32_t transpose_node_id = transpose_node->id;

// Flip the "transpose" flag.
const uint32_t bias_id =
node->num_inputs > 2 ? node->inputs[2] : XNN_INVALID_VALUE_ID;
const uint32_t output_id = node->outputs[0];
enum xnn_status status =
node->type == xnn_node_type_fully_connected
? xnn_define_fully_connected(
subgraph,
/*output_min=*/node->activation.output_min,
/*output_max=*/node->activation.output_max, input_a_id,
input_b_id, bias_id, output_id,
node->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS)
: xnn_define_batch_matrix_multiply(
subgraph, input_a_id, input_b_id, output_id,
node->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS);
if (status != xnn_status_success) {
xnn_log_error("Failed to create new `%s` node.",
node->type == xnn_node_type_fully_connected
? "fully_connected"
: "batch_matrix_multiply");
return status;
}
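// As above, the replacement node was appended at the end of the node
// list; move it into this node's slot and drop the tail entry.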
node = &subgraph->nodes[node_id];
*node = subgraph->nodes[--subgraph->num_nodes];
node->id = node_id;

xnn_log_info(
"Converted %s[#%u](v%03u, v%03u) to "
"%s[#%u](v%03u, static_transpose[#%i](v%03u)).",
node->type == xnn_node_type_fully_connected
? "fully_connected"
: "batch_matrix_multiply",
node_id, input_a_id, input_b_id,
node->type == xnn_node_type_fully_connected
? "fully_connected"
: "batch_matrix_multiply",
node_id, input_a_id, transpose_node_id, input_b_id);
changes++;
} while (false);
break;

default:
break;
}
@@ -2561,6 +2745,7 @@ enum xnn_status xnn_subgraph_optimize_common_subgraphs(
// Clean up after ourselves.
if (changes) {
xnn_subgraph_clean_up(subgraph);
xnn_subgraph_log_info(subgraph);
}

return xnn_status_success;
7 changes: 7 additions & 0 deletions src/subgraph/subgraph-utils.c
@@ -84,6 +84,13 @@ void xnn_subgraph_log_impl(const char* filename, size_t line_number,
xnn_datatype_to_string(subgraph->values[node->outputs[0]].datatype),
xnn_datatype_to_string(subgraph->values[node->inputs[1]].datatype));
break;
case xnn_node_type_static_transpose:
fprintf(out, " (perm=[%zu", node->params.transpose.perm[0]);
for (int i = 1; i < node->params.transpose.num_dims; i++) {
fprintf(out, ", %zu", node->params.transpose.perm[i]);
}
fprintf(out, "])");
break;
default:
break;
}