Fixing dse function
This commit is contained in:
+12
-12
@@ -474,10 +474,10 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
|
|||||||
w = weights_per_layer[mid]
|
w = weights_per_layer[mid]
|
||||||
|
|
||||||
f_w = []
|
f_w = []
|
||||||
for i in range(len(seq_counts)):
|
for j in range(len(seq_counts)):
|
||||||
t_w = w[i]
|
t_w = w[j]
|
||||||
c,l = seq_counts[i]
|
c,l = seq_counts[j]
|
||||||
for j in range(c+l):
|
for _ in range(c+l):
|
||||||
f_w.append(t_w)
|
f_w.append(t_w)
|
||||||
|
|
||||||
if(len(seq_counts) > 0):
|
if(len(seq_counts) > 0):
|
||||||
@@ -489,9 +489,9 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
|
|||||||
quant_net = quant_net.to(device)
|
quant_net = quant_net.to(device)
|
||||||
print(f'==========================\nEvaluating Configuration: {mid} --> Weights: {w}')
|
print(f'==========================\nEvaluating Configuration: {mid} --> Weights: {w}')
|
||||||
|
|
||||||
for i in range(len(epochs)):
|
for k in range(len(epochs)):
|
||||||
quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
|
quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
|
||||||
epochs = epochs[i], lr = lr[i])
|
epochs = epochs[k], lr = lr[k])
|
||||||
|
|
||||||
# Evaluate the trained quantized network
|
# Evaluate the trained quantized network
|
||||||
accuracy = quant_net_evaluation(quant_net, test_loader, device)
|
accuracy = quant_net_evaluation(quant_net, test_loader, device)
|
||||||
@@ -518,10 +518,10 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
|
|||||||
test_accuracy = []
|
test_accuracy = []
|
||||||
for i, w in enumerate(weights_per_layer):
|
for i, w in enumerate(weights_per_layer):
|
||||||
f_w = []
|
f_w = []
|
||||||
for i in range(len(seq_counts)):
|
for j in range(len(seq_counts)):
|
||||||
t_w = w[i]
|
t_w = w[j]
|
||||||
c,l = seq_counts[i]
|
c,l = seq_counts[j]
|
||||||
for j in range(c+l):
|
for _ in range(c+l):
|
||||||
f_w.append(t_w)
|
f_w.append(t_w)
|
||||||
|
|
||||||
if(len(seq_counts) > 0):
|
if(len(seq_counts) > 0):
|
||||||
@@ -531,9 +531,9 @@ def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, te
|
|||||||
quant_net = Quant_Model(og_model, w, layer_mapping, sign)
|
quant_net = Quant_Model(og_model, w, layer_mapping, sign)
|
||||||
quant_net = quant_net.to(device)
|
quant_net = quant_net.to(device)
|
||||||
print(f'===================================\nModel No {i} --> {w}')
|
print(f'===================================\nModel No {i} --> {w}')
|
||||||
for i in range(len(epochs)):
|
for k in range(len(epochs)):
|
||||||
quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
|
quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
|
||||||
epochs = epochs[i], lr = lr[i])
|
epochs = epochs[k], lr = lr[k])
|
||||||
accuracy = quant_net_evaluation(quant_net, test_loader, device)
|
accuracy = quant_net_evaluation(quant_net, test_loader, device)
|
||||||
test_accuracy.append(accuracy)
|
test_accuracy.append(accuracy)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user