Fixing dse function

This commit is contained in:
Alexios Maras
2024-07-26 18:43:46 +03:00
committed by GitHub
parent 5224136830
commit 66c473fed5
+12 -12
View File
@@ -474,10 +474,10 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
w = weights_per_layer[mid] w = weights_per_layer[mid]
f_w = [] f_w = []
for i in range(len(seq_counts)): for j in range(len(seq_counts)):
t_w = w[i] t_w = w[j]
c,l = seq_counts[i] c,l = seq_counts[j]
for j in range(c+l): for _ in range(c+l):
f_w.append(t_w) f_w.append(t_w)
if(len(seq_counts) > 0): if(len(seq_counts) > 0):
@@ -489,9 +489,9 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
quant_net = quant_net.to(device) quant_net = quant_net.to(device)
print(f'==========================\nEvaluating Configuration: {mid} --> Weights: {w}') print(f'==========================\nEvaluating Configuration: {mid} --> Weights: {w}')
for i in range(len(epochs)): for k in range(len(epochs)):
quant_net = train_quant_model(quant_net, train_loader, val_loader, device, quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
epochs = epochs[i], lr = lr[i]) epochs = epochs[k], lr = lr[k])
# Evaluate the trained quantized network # Evaluate the trained quantized network
accuracy = quant_net_evaluation(quant_net, test_loader, device) accuracy = quant_net_evaluation(quant_net, test_loader, device)
@@ -518,10 +518,10 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
test_accuracy = [] test_accuracy = []
for i, w in enumerate(weights_per_layer): for i, w in enumerate(weights_per_layer):
f_w = [] f_w = []
for i in range(len(seq_counts)): for j in range(len(seq_counts)):
t_w = w[i] t_w = w[j]
c,l = seq_counts[i] c,l = seq_counts[j]
for j in range(c+l): for _ in range(c+l):
f_w.append(t_w) f_w.append(t_w)
if(len(seq_counts) > 0): if(len(seq_counts) > 0):
@@ -531,9 +531,9 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
quant_net = Quant_Model(og_model, w, layer_mapping, sign) quant_net = Quant_Model(og_model, w, layer_mapping, sign)
quant_net = quant_net.to(device) quant_net = quant_net.to(device)
print(f'===================================\nModel No {i} --> {w}') print(f'===================================\nModel No {i} --> {w}')
for i in range(len(epochs)): for k in range(len(epochs)):
quant_net = train_quant_model(quant_net, train_loader, val_loader, device, quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
epochs = epochs[i], lr = lr[i]) epochs = epochs[k], lr = lr[k])
accuracy = quant_net_evaluation(quant_net, test_loader, device) accuracy = quant_net_evaluation(quant_net, test_loader, device)
test_accuracy.append(accuracy) test_accuracy.append(accuracy)