diff --git a/mpq/configure_ibex.py b/mpq/configure_ibex.py index 876c821..4e77d06 100644 --- a/mpq/configure_ibex.py +++ b/mpq/configure_ibex.py @@ -866,7 +866,7 @@ def generate_opt_c_code_mlp(path, name, int_weights, optimal_config, type_of_lay f.write('#include "ibex_inputs.h"\n\n') f.write('#define IN_DIM ' + str((8//optimal_config[0]) * int_weights[0].shape[1])) for i in range(1, len(int_weights)): - f.write('\n#define HIDDEN_DIM' + str(i) + ' ' + str(4 * int_weights[i].shape[1])) + f.write('\n#define HIDDEN_DIM' + str(i) + ' ' + str(4 * int_weights[i-1].shape[0])) f.write('\n#define OUT_DIM ' + str(4 * int_weights[-1].shape[0])) f.write('\n#define SAMPLES 1\n\n')