Files
mixed-precision-neural-netw…/mpq/mpq_quantize.py
T
2024-07-26 18:43:46 +03:00

541 lines
22 KiB
Python

import itertools
import sys
from collections import defaultdict

import numpy as np
# Import Torch
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn, optim
import brevitas.nn as qnn
from brevitas.quant import *
from brevitas.core.restrict_val import RestrictValueType
from torchinfo import summary
def net_input_size(X_train):
    """Return the shape of the first training example with leading unit axes added.

    Examples of rank 1 or rank 3 get one leading axis of size 1; examples of
    any other rank get two leading axes of size 1 (presumably batch and
    channel -- confirm against the model's expected input layout).
    """
    sample_shape = tuple(np.shape(X_train[0]))
    if len(sample_shape) in (1, 3):
        return (1,) + sample_shape
    return (1, 1) + sample_shape
def display_model_info(model, input_size):
    """Run torchinfo on ``model`` and extract per-layer mult-add counts.

    Args:
        model: Module passed to ``torchinfo.summary``.
        input_size: Input shape passed to ``torchinfo.summary``.

    Returns:
        ``(macc_per_layer, type_of_layer)``. ``type_of_layer`` holds one name
        per layer row of the summary table. ``macc_per_layer`` holds the
        integer mult-adds of only those rows whose last column is not ``--``,
        so the two lists can differ in length.

    NOTE(review): layer rows are recognised by the tree connectors torchinfo
    draws in front of them ('├─' for top-level rows, '│' for nested rows is
    assumed here) -- confirm against the installed torchinfo version.
    """
    report = str(summary(model, input_size,
                         col_names=("output_size", "num_params", "mult_adds")))
    macc_per_layer = []
    type_of_layer = []
    for row in report.split('\n'):
        # Header, separator, totals and blank rows carry no tree connector;
        # letting them through would feed non-numeric text to int() below.
        if not row.startswith(('├─', '│')):
            continue
        # Drop the connector characters so the first field is the layer name
        # for nested rows as well as for top-level ones.
        fields = row.lstrip('│├└─ ').split()
        if not fields:
            continue
        type_of_layer.append(fields[0].replace(':', ''))
        if fields[-1] != '--':
            macc_per_layer.append(int(fields[-1].replace(',', '')))
    return macc_per_layer, type_of_layer
def create_dataloaders(BATCH_SIZE, X_train, y_train, X_test, y_test,
                       X_val = None, y_val = None):
    """Wrap NumPy feature/target arrays in PyTorch ``DataLoader`` objects.

    Features are converted to float tensors and targets to long tensors.
    The train loader shuffles, the test loader does not, and the validation
    loader (built only when ``X_val`` is given) shuffles as well.

    Returns:
        ``(train_loader, val_loader, test_loader)``; ``val_loader`` is
        ``None`` when ``X_val`` is ``None``.
    """
    def _as_dataset(features, targets):
        # One place for the dtype conversion shared by every split.
        return torch.utils.data.TensorDataset(
            torch.from_numpy(features).type(torch.FloatTensor),
            torch.from_numpy(targets).type(torch.LongTensor))

    train_loader = torch.utils.data.DataLoader(
        _as_dataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        _as_dataset(X_test, y_test), batch_size=BATCH_SIZE, shuffle=False)

    val_loader = None
    if X_val is not None:
        val_loader = torch.utils.data.DataLoader(
            _as_dataset(X_val, y_val), batch_size=BATCH_SIZE, shuffle=True)

    return train_loader, val_loader, test_loader
# Function to calculate the minimum value of a DataLoader
def calculate_minimum(dataloader):
    """Return the smallest input value found in any batch of ``dataloader``.

    Every batch must unpack into ``(inputs, targets)``; only ``inputs`` is
    inspected. ``float('inf')`` is returned when the loader yields nothing.
    """
    # Seeding with +inf reproduces the running-minimum comparison order and
    # gives the empty-loader result for free.
    batch_minima = [float('inf')]
    batch_minima.extend(inputs.min().item() for inputs, _ in dataloader)
    return min(batch_minima)
def fp_train(net, train_loader, val_loader = None, device = 'cpu', epochs = 20, lr = 0.0001):
    """Train the full-precision network with Adam and cross-entropy loss.

    Args:
        net: Model to train; it is moved to ``device`` first.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Optional iterable of validation batches. When given,
            validation loss/accuracy are printed every epoch and training
            stops early after 10 consecutive epochs without a new best
            validation loss.
        device: Device the model and batches are moved to.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.

    Returns:
        The trained ``net`` (left in train mode after each epoch).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr = lr)
    patience = 10
    best_val_loss = float('inf')
    # Must exist before the first comparison: a first-epoch validation loss
    # that is not below inf (e.g. NaN) goes straight to the increment branch.
    counter = 0
    train_losses, val_losses = [], []
    net = net.to(device)

    for e in range(epochs):
        running_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            # Prevent accumulation of gradients
            optimizer.zero_grad()
            log_ps = net(images.float())
            loss = criterion(log_ps, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        val_loss = 0.0
        accuracy = 0.0
        # Gradients are not needed for validation
        with torch.no_grad():
            net.eval()
            if val_loader is not None:
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    log_ps = net(images.float())
                    val_loss += criterion(log_ps, labels).item()
                    # exp() is monotonic, so the top class can be read from
                    # the raw outputs directly.
                    _, top_class = log_ps.topk(1, dim=1)
                    equals = top_class == labels.view(*top_class.shape)
                    accuracy += equals.float().mean().item()
        net.train()

        train_losses.append(running_loss/len(train_loader))
        if val_loader is not None:
            avg_val_loss = val_loss/len(val_loader)
            val_losses.append(avg_val_loss)
            print(f"Epoch {e+1}/{epochs}.. "
                  f"Train loss: {train_losses[-1]:.3f}.. "
                  f"Validation loss: {val_losses[-1]:.3f}.. "
                  f"Validation accuracy: {accuracy/len(val_loader):.3f}")
            # Check for early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                counter = 0
            else:
                counter += 1
                if counter >= patience:
                    break
        else:
            print(f"Epoch {e+1}/{epochs}.. "
                  f"Train loss: {train_losses[-1]:.3f}.. ")
    return net
def fp_evaluate(net, test_loader, device):
    """Evaluate the full-precision model and return its top-1 accuracy in percent.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a float in the range 0-100.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    print('\nFULL PRECISION MODEL EVALUATION ...')
    correct = 0
    y_size = 0
    # Gradients are not needed for evaluation
    with torch.no_grad():
        net.eval()
        for test_imgs, test_labels in test_loader:
            test_imgs, test_labels = test_imgs.to(device), test_labels.to(device)
            # Plain tensors are enough; the autograd Variable wrapper is a
            # deprecated no-op.
            output = net(test_imgs.float())
            predicted = torch.max(output, 1)[1]
            correct += (predicted == test_labels).sum().item()
            y_size += len(test_labels)
    floating_acc = 100*float(correct)/y_size
    print("Test accuracy: {:.3f}% ".format(floating_acc))
    return floating_acc
def generate_sequences(length, values = (2, 4, 8)):
    """Enumerate every sequence of ``length`` entries drawn from ``values``.

    Sequences are produced in lexicographic order with respect to the order
    of ``values`` (the last position varies fastest). Each sequence is a
    fresh list, so callers may mutate them independently.

    Args:
        length: Number of entries per sequence; ``0`` yields ``[[]]``.
        values: Candidate entries (defaults to the bit widths 2, 4 and 8).

    Returns:
        A list of ``len(values) ** length`` lists.
    """
    return [list(combo) for combo in itertools.product(values, repeat=length)]
def create_weight_confs(macc_per_layer):
    """Build all candidate bit-width configurations, cheapest first.

    For models with at least six layers, any layer whose mult-adds are below
    5% of the heaviest layer is fixed at 8 bits and excluded from the search.
    Every remaining layer takes each of the values produced by
    ``generate_sequences``. A configuration's cost is the sum of each layer's
    mult-adds divided by 16 (2-bit), 8 (4-bit) or 4 (anything else), rounded
    with ``np.round``.

    Returns:
        ``(weights_per_layer, total_macc_opt_sorted)``: the configurations
        and their costs, both sorted by ascending cost.
    """
    pinned_positions = []
    if len(macc_per_layer) >= 6:
        heaviest = max(macc_per_layer)
        pinned_positions = [pos for pos, macc in enumerate(macc_per_layer)
                            if macc / heaviest < 0.05]

    candidates = generate_sequences(len(macc_per_layer) - len(pinned_positions))
    # Re-insert the pinned layers at their original positions (ascending
    # order keeps later indices valid).
    for conf in candidates:
        for pos in pinned_positions:
            conf.insert(pos, 8)

    divisor_for = {2: 16, 4: 8}
    costs = []
    for conf in candidates:
        cost = 0
        for pos, bits in enumerate(conf):
            cost += macc_per_layer[pos] / divisor_for.get(bits, 4)
        costs.append(np.round(cost))

    # Stable sort of positions by ascending cost
    order = sorted(range(len(costs)), key=lambda pos: costs[pos])
    weights_per_layer = [candidates[pos] for pos in order]
    total_macc_opt_sorted = [costs[pos] for pos in order]
    return weights_per_layer, total_macc_opt_sorted
# Define a mapping from PyTorch layers to Brevitas layers
def create_layer_mapping(bit_width):
    """Return a dict mapping PyTorch layer types to Brevitas layer factories.

    Each value is a callable ``(layer, bw)`` that builds the quantized
    counterpart of ``layer`` using ``bw`` as the bit width where applicable.

    Args:
        bit_width: Not used by this function; kept for caller compatibility.

    NOTE(review): the quantized Conv2d/Linear layers are always created with
    ``bias=True`` regardless of the source layer -- confirm this is intended.
    NOTE(review): a convolution counts as depthwise whenever
    ``groups == in_channels``, which is also true for an ordinary convolution
    with a single input channel -- confirm this is intended.
    """
    def quant_conv2d(layer, bw):
        # Depthwise convolutions keep a fixed 8-bit weight width.
        is_depthwise = layer.groups == layer.in_channels
        # Symmetric strides are passed as a single int, as before; an
        # asymmetric stride is forwarded as a tuple instead of being
        # truncated to its first entry.
        stride = layer.stride[0] if len(set(layer.stride)) == 1 else layer.stride
        return qnn.QuantConv2d(in_channels=layer.in_channels,
                               out_channels=layer.out_channels,
                               kernel_size=layer.kernel_size,
                               stride=stride,
                               padding=layer.padding,
                               # Carry dilation over so the quantized layer has
                               # the same output geometry as the source layer.
                               dilation=layer.dilation,
                               groups=layer.groups,
                               bias=True,
                               cache_inference_bias=True,
                               bias_quant=Int32Bias,
                               weight_bit_width=8 if is_depthwise else bw,
                               weight_quant=Int8WeightPerTensorFloat,
                               weight_scaling_min_val=2e-16,
                               restrict_scaling_type=RestrictValueType.LOG_FP,
                               return_quant_tensor=True)

    def quant_linear(layer, bw):
        return qnn.QuantLinear(in_features=layer.in_features,
                               out_features=layer.out_features,
                               cache_inference_bias=True,
                               bias_quant=Int32Bias,
                               bias=True,
                               weight_quant=Int8WeightPerTensorFloat,
                               weight_bit_width=bw,
                               return_quant_tensor=True)

    def quant_relu(_, bw):
        return qnn.QuantReLU(bit_width=bw,
                             return_quant_tensor=True)

    def quant_max_pool(layer, _):
        return qnn.QuantMaxPool2d(kernel_size=layer.kernel_size,
                                  stride=layer.stride,
                                  padding=layer.padding,
                                  return_quant_tensor=True)

    def quant_avg_pool(layer, _):
        return qnn.TruncAvgPool2d(kernel_size=layer.kernel_size,
                                  stride=layer.stride,
                                  padding=layer.padding,
                                  return_quant_tensor=True)

    mapping = {
        nn.Conv2d: quant_conv2d,
        nn.Linear: quant_linear,
        nn.ReLU: quant_relu,
        nn.MaxPool2d: quant_max_pool,
        nn.AvgPool2d: quant_avg_pool,
    }
    return mapping
# Function to convert a PyTorch layer to a Brevitas layer with a specified bit width
def convert_layer(layer, bit_width, layer_mapping):
    """Return the mapped replacement for ``layer``, or ``layer`` itself.

    The lookup uses the exact type of ``layer`` (subclasses of a mapped type
    do not match). A matching factory is called as
    ``factory(layer, bit_width)``.
    """
    kind = type(layer)
    if kind not in layer_mapping:
        return layer
    build = layer_mapping[kind]
    return build(layer, bit_width)
# Function to convert a PyTorch model to a Brevitas model
def convert_model(module, bit_widths, layer_mapping, layer_idx = None):
    """Recursively rebuild ``module`` as an ``nn.Sequential`` of converted layers.

    Children that themselves have children are converted recursively; leaf
    children go through ``convert_layer``. Leaves whose exact type is
    ``nn.Conv2d`` or ``nn.Linear`` consume the next entry of ``bit_widths``;
    every other leaf is converted with a bit width of 8.

    Args:
        module: Source module whose named children are converted in order.
        bit_widths: One bit width per Conv2d/Linear leaf, in traversal order.
        layer_mapping: Type-to-factory mapping passed to ``convert_layer``.
        layer_idx: Single-element list holding the index of the next entry of
            ``bit_widths`` to use. It is shared across the recursion and
            mutated in place. Defaults to a fresh ``[0]`` for every top-level
            call.

    Returns:
        A new ``nn.Sequential`` with the same child names as ``module``.
    """
    # A fresh counter per top-level call; a list literal as the default would
    # be shared between calls and keep counting from the previous model.
    if layer_idx is None:
        layer_idx = [0]
    brevitas_module = nn.Sequential()
    for name, layer in module.named_children():
        if list(layer.children()):
            # Container: recurse with the shared counter
            brevitas_module.add_module(
                name, convert_model(layer, bit_widths, layer_mapping, layer_idx))
            continue
        if type(layer) in (nn.Conv2d, nn.Linear):
            bit_width = bit_widths[layer_idx[0]]
            layer_idx[0] += 1
        else:
            bit_width = 8
        brevitas_module.add_module(name, convert_layer(layer, bit_width, layer_mapping))
    return brevitas_module
class Quant_Model(nn.Module):
    """Quantized counterpart of a model: input quantizer, converted body, output quantizer.

    Args:
        og_model: Source model whose children are converted by ``convert_model``.
        w: Bit widths passed to ``convert_model``.
        layer_mapping: Type-to-factory mapping passed to ``convert_model``.
        input_sign: When truthy the input is quantized with
            ``Uint8ActPerTensorFloat``; otherwise with ``Int8ActPerTensorFloat``.
    """

    def __init__(self, og_model, w, layer_mapping, input_sign = True):
        super().__init__()
        input_act_quant = Uint8ActPerTensorFloat if input_sign else Int8ActPerTensorFloat
        self.quant_inp = qnn.QuantIdentity(bit_width = 8, return_quant_tensor = True,
                                           act_quant = input_act_quant, scaling_min_val = 2e-16,
                                           restrict_scaling_type = RestrictValueType.LOG_FP)
        # An explicit fresh counter is passed for the layer index
        self.sequential = convert_model(og_model, w, layer_mapping, [0])
        self.o_quant = qnn.QuantIdentity(bit_width = 8, return_quant_tensor = True)

    def forward(self, X):
        """Quantize the input, run the converted body, and return log-softmax over dim 1."""
        quantized_in = self.quant_inp(X)
        features = self.sequential(quantized_in)
        quantized_out = self.o_quant(features)
        return F.log_softmax(quantized_out, dim = 1)
def count_layers_in_sequential(module):
    """Count direct Conv2d/Linear children of every ``nn.Sequential`` in ``module``.

    The module tree is walked in pre-order. Each ``nn.Sequential`` encountered
    contributes one ``(conv_count, linear_count)`` tuple covering only its
    direct children. The first tuple recorded (the first Sequential reached,
    which is ``module`` itself when it is a Sequential) is dropped from the
    result.
    """
    per_sequential = []

    def _walk(node):
        if isinstance(node, nn.Sequential):
            members = list(node.children())
            n_conv = sum(1 for m in members if isinstance(m, nn.Conv2d))
            n_linear = sum(1 for m in members
                           if not isinstance(m, nn.Conv2d) and isinstance(m, nn.Linear))
            per_sequential.append((n_conv, n_linear))
        # Descend regardless of the node's own type
        for member in node.children():
            _walk(member)

    _walk(module)
    return per_sequential[1:]
def train_quant_model(quant_net, train_loader, val_loader = None, device = 'cpu',
                      epochs = 20, lr = 0.0001):
    """Train a quantized network with Adam and cross-entropy loss.

    Args:
        quant_net: Model to train (it is not moved to ``device`` here).
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Optional iterable of validation batches. When given,
            training stops early after 10 consecutive epochs without a new
            best validation loss.
        device: Device each batch is moved to.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.

    Returns:
        ``quant_net``; after at least one epoch it is in train mode, including
        when early stopping ends the loop.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(quant_net.parameters(), lr = lr)
    patience = 10
    best_val_loss = float('inf')
    # Must exist before the first comparison: a first-epoch validation loss
    # that is not below inf (e.g. NaN) goes straight to the increment branch.
    counter = 0

    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            # Prevent accumulation of gradients
            optimizer.zero_grad()
            log_ps = quant_net(images.float())
            loss = criterion(log_ps, labels)
            loss.backward()
            optimizer.step()

        val_loss = 0.0
        # Gradients are not needed for validation
        with torch.no_grad():
            quant_net.eval()
            if val_loader is not None:
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    val_loss += criterion(quant_net(images.float()), labels).item()
        # Back to train mode BEFORE the early-stopping check, so that a break
        # never hands an eval-mode model to the next training stage.
        quant_net.train()

        if val_loader is not None:
            # Check for early stopping
            avg_val_loss = val_loss/len(val_loader)
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                counter = 0
            else:
                counter += 1
                if counter >= patience:
                    break
    return quant_net
def quant_net_evaluation(quant_net, test_loader, device = 'cpu'):
    """Evaluate a quantized model and return its top-1 accuracy in percent.

    Args:
        quant_net: Model to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a float in the range 0-100.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    correct = 0
    y_size = 0
    with torch.no_grad():
        quant_net.eval()
        for test_imgs, test_labels in test_loader:
            test_imgs, test_labels = test_imgs.to(device), test_labels.to(device)
            # Plain tensors are enough; the autograd Variable wrapper is a
            # deprecated no-op.
            output = quant_net(test_imgs.float())
            predicted = torch.max(output, 1)[1]
            correct += (predicted == test_labels).sum().item()
            y_size += len(test_labels)
    accuracy = 100 * float(correct)/y_size
    print("Test accuracy: {:.3f}% ".format(accuracy))
    return accuracy
def _expand_config(config, seq_counts):
    """Repeat each per-Sequential bit width once per Conv2d/Linear layer it covers."""
    if len(seq_counts) == 0:
        # No nested Sequential blocks: the configuration is used as given.
        return config
    expanded = []
    for j, (conv_count, linear_count) in enumerate(seq_counts):
        expanded.extend([config[j]] * (conv_count + linear_count))
    return expanded


def _as_schedule(epochs, lr):
    """Turn scalar or sequence ``epochs``/``lr`` into per-stage lists."""
    epoch_schedule = [epochs] if np.isscalar(epochs) else list(epochs)
    # A single learning rate applies to every training stage.
    lr_schedule = [lr] * len(epoch_schedule) if np.isscalar(lr) else list(lr)
    return epoch_schedule, lr_schedule


def _train_and_score(og_model, config, input_sign, banner, train_loader, test_loader,
                     val_loader, device, epoch_schedule, lr_schedule):
    """Build, train (stage by stage) and evaluate one quantized configuration."""
    layer_mapping = create_layer_mapping(config)
    quant_net = Quant_Model(og_model, config, layer_mapping, input_sign)
    quant_net = quant_net.to(device)
    print(banner)
    for stage, stage_epochs in enumerate(epoch_schedule):
        quant_net = train_quant_model(quant_net, train_loader, val_loader, device,
                                      epochs = stage_epochs, lr = lr_schedule[stage])
    accuracy = quant_net_evaluation(quant_net, test_loader, device)
    return quant_net, accuracy


def dse(og_model, max_acc_drop, weights_per_layer, fp_accuracy, train_loader, test_loader, val_loader = None,
        device = 'cpu', epochs = 5, lr = 0.0001):
    """Design-space exploration over bit-width configurations.

    Args:
        og_model: Full-precision model to quantize.
        max_acc_drop: Largest accepted accuracy drop (percentage points)
            relative to ``fp_accuracy``. When not ``None`` a binary search over
            ``weights_per_layer`` is run; when ``None`` every configuration is
            trained and evaluated.
        weights_per_layer: Candidate configurations. The binary search moves
            to lower indices after an acceptable result, so it presumably
            expects them ordered by ascending cost -- verify against caller.
            When the model contains nested ``nn.Sequential`` blocks, each
            configuration is read as one bit width per block and expanded to
            one per Conv2d/Linear layer.
        fp_accuracy: Accuracy of the full-precision model, in percent.
        train_loader, test_loader, val_loader: Data loaders for training,
            evaluation and optional validation.
        device: Device the quantized models are moved to.
        epochs: Number of epochs, or a sequence with one entry per training
            stage.
        lr: Learning rate, or a sequence with one entry per training stage.
            A scalar is reused for every stage.

    Returns:
        Binary search: ``(quant_net, config)`` for the lowest-index acceptable
        configuration found, or for the last configuration tried when none is
        acceptable. Exhaustive search: ``(quant_net, test_accuracy)`` with the
        last model trained and the list of all accuracies.

    Raises:
        ValueError: If ``weights_per_layer`` is empty.
    """
    if len(weights_per_layer) == 0:
        raise ValueError("weights_per_layer must contain at least one configuration")

    # Non-negative inputs select the unsigned input quantizer in Quant_Model.
    sign = calculate_minimum(train_loader) >= 0
    seq_counts = count_layers_in_sequential(og_model)
    epoch_schedule, lr_schedule = _as_schedule(epochs, lr)

    if max_acc_drop is not None:
        print('\nDSE STARTING ... BINARY SEARCH')
        optimal_quant_net = None
        optimal_config = None
        low = 0
        high = len(weights_per_layer) - 1
        while low <= high:
            mid = (low + high) // 2
            w = _expand_config(weights_per_layer[mid], seq_counts)
            banner = f'==========================\nEvaluating Configuration: {mid} --> Weights: {w}'
            quant_net, accuracy = _train_and_score(
                og_model, w, sign, banner, train_loader, test_loader,
                val_loader, device, epoch_schedule, lr_schedule)
            if fp_accuracy - accuracy <= max_acc_drop:
                optimal_quant_net = quant_net
                optimal_config = w
                high = mid - 1  # Acceptable: try a lower-index configuration
            else:
                low = mid + 1  # Too much accuracy loss: try a higher-index one
        if optimal_quant_net is None:
            print("No solution that meets user's criteria was found !!")
            # Fall back to the last configuration that was trained.
            return quant_net, w
        return optimal_quant_net, optimal_config

    # Exhaustive search: evaluate every configuration (e.g. to build a Pareto space)
    print('\nDSE STARTING ... EXHAUSTIVE SEARCH')
    test_accuracy = []
    for i, config in enumerate(weights_per_layer):
        w = _expand_config(config, seq_counts)
        banner = f'===================================\nModel No {i} --> {w}'
        quant_net, accuracy = _train_and_score(
            og_model, w, sign, banner, train_loader, test_loader,
            val_loader, device, epoch_schedule, lr_schedule)
        test_accuracy.append(accuracy)
    return quant_net, test_accuracy