Fixing dse function
This commit is contained in:
+12
-12
@@ -474,10 +474,10 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
|
||||
w = weights_per_layer[mid]
|
||||
|
||||
f_w = []
|
||||
for i in range(len(seq_counts)):
|
||||
t_w = w[i]
|
||||
c,l = seq_counts[i]
|
||||
for j in range(c+l):
|
||||
for j in range(len(seq_counts)):
|
||||
t_w = w[j]
|
||||
c,l = seq_counts[j]
|
||||
for _ in range(c+l):
|
||||
f_w.append(t_w)
|
||||
|
||||
if(len(seq_counts) > 0):
|
||||
@@ -489,9 +489,9 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
|
||||
quant_net = quant_net.to(device)
|
||||
print(f'==========================\nEvaluating Configuration: {mid} --> Weights: {w}')
|
||||
|
||||
for i in range(len(epochs)):
|
||||
for k in range(len(epochs)):
|
||||
quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
|
||||
epochs = epochs[i], lr = lr[i])
|
||||
epochs = epochs[k], lr = lr[k])
|
||||
|
||||
# Evaluate the trained quantized network
|
||||
accuracy = quant_net_evaluation(quant_net, test_loader, device)
|
||||
@@ -518,10 +518,10 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
|
||||
test_accuracy = []
|
||||
for i, w in enumerate(weights_per_layer):
|
||||
f_w = []
|
||||
for i in range(len(seq_counts)):
|
||||
t_w = w[i]
|
||||
c,l = seq_counts[i]
|
||||
for j in range(c+l):
|
||||
for j in range(len(seq_counts)):
|
||||
t_w = w[j]
|
||||
c,l = seq_counts[j]
|
||||
for _ in range(c+l):
|
||||
f_w.append(t_w)
|
||||
|
||||
if(len(seq_counts) > 0):
|
||||
@@ -531,9 +531,9 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
|
||||
quant_net = Quant_Model(og_model, w, layer_mapping, sign)
|
||||
quant_net = quant_net.to(device)
|
||||
print(f'===================================\nModel No {i} --> {w}')
|
||||
for i in range(len(epochs)):
|
||||
for k in range(len(epochs)):
|
||||
quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
|
||||
epochs = epochs[i], lr = lr[i])
|
||||
epochs = epochs[k], lr = lr[k])
|
||||
accuracy = quant_net_evaluation(quant_net, test_loader, device)
|
||||
test_accuracy.append(accuracy)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user