[add]上传训练benchmark by z00560161
This commit is contained in:
@@ -0,0 +1,541 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
import torch.optim
|
||||
import torch.multiprocessing as mp
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
import models as models
|
||||
import numpy as np
|
||||
|
||||
from apex import amp
|
||||
from benchmark_log import hwlog
|
||||
from benchmark_log.basic_utils import get_environment_info
|
||||
from benchmark_log.basic_utils import get_model_parameter
|
||||
|
||||
# Default value of --batch-size, and the effective batch size that benchmark
# mode (--benchmark 1) accumulates gradients up to before each optimizer step.
BATCH_SIZE = 512
OPTIMIZER_BATCH_SIZE = 2048

# Every lowercase, non-dunder callable exported by the local ``models`` package
# is offered as a selectable architecture.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

# Help texts use argparse's ``%(default)s`` so the documented default can never
# drift away from the real one again.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', metavar='DIR', default='/opt/npu/dataset/imagenet',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: %(default)s)')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=BATCH_SIZE, type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('--workspace', type=str, default='./', metavar='DIR',
                    help='path to directory where checkpoints will be stored')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('-ef', '--eval-freq', default=5, type=int,
                    metavar='N', help='evaluate frequency (default: 5)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('-bm', '--benchmark', default=0, type=int,
                    metavar='N',
                    help='set benchmark status: 0 steps the optimizer on every '
                         'batch, 1 accumulates gradients and steps once per '
                         'OPTIMIZER_BATCH_SIZE samples (default: %(default)s)')
parser.add_argument('--device', default='npu', type=str,
                    help='npu or gpu')
parser.add_argument('--addr', default='10.136.181.115', type=str,
                    help='master addr')
parser.add_argument('--checkpoint-nameprefix', default='checkpoint', type=str,
                    help='checkpoint-nameprefix')
parser.add_argument('--checkpoint-freq', default=0, type=int,
                    metavar='N', help='checkpoint frequency (default: 0); '
                                      '0: save only one file per epoch; '
                                      'n: save a different file every n epochs; '
                                      '-1: no checkpoint (not supported)')
parser.add_argument('--device_num', default=-1, type=int,
                    help='device_num')
parser.add_argument('--warm_up_epochs', default=0, type=int,
                    help='warm up')

# apex
parser.add_argument('--amp', default=False, action='store_true',
                    help='use amp to train the model')
parser.add_argument('--loss-scale', default=64., type=float,
                    help='loss scale passed to amp.initialize (default: %(default)s)')
parser.add_argument('--opt-level', default='O2', type=str,
                    help='amp optimization level passed to amp.initialize '
                         '(default: %(default)s)')

# Silence every warning for the rest of the run.
warnings.filterwarnings('ignore')
# Best top-1 accuracy seen so far; updated by main_worker.
best_acc1 = 0
|
||||
|
||||
|
||||
def main():
    """Parse the command line, prepare the process environment and launch training.

    Either spawns one ``main_worker`` process per device (when
    ``--multiprocessing-distributed`` is given) or calls ``main_worker``
    directly in the current process.
    """
    args = parser.parse_args()
    banner = "===============main()================="
    print(banner)
    print(args)
    print(banner)

    # Placeholder value; every worker overwrites it with its own device id.
    os.environ['KERNEL_NAME_ID'] = str(0)
    print("+++++++++++++++++++++++++++KERNEL_NAME_ID:", os.environ['KERNEL_NAME_ID'])

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Rendezvous endpoint read by the default env:// initialisation.
    os.environ['MASTER_ADDR'] = args.addr
    os.environ['MASTER_PORT'] = '29688'

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    # An explicit --device_num wins over querying the runtime.
    if args.device_num != -1:
        ngpus_per_node = args.device_num
    else:
        runtime = torch.npu if args.device == 'npu' else torch.cuda
        ngpus_per_node = runtime.device_count()

    if not args.multiprocessing_distributed:
        # Single process: run the worker inline.
        main_worker(args.gpu, ngpus_per_node, args)
        return

    # One process per device, so the world size counts processes, not nodes.
    args.world_size = ngpus_per_node * args.world_size
    # Children inherit this process's environment; each worker sets its own
    # KERNEL_NAME_ID once it starts.
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
||||
|
||||
|
||||
def main_worker(gpu, ngpus_per_node, args):
    """Train (or only evaluate) the selected model on one device.

    Args:
        gpu: device index for this process (the process index when started
            through ``mp.spawn``, otherwise ``args.gpu``).
        ngpus_per_node: number of devices, i.e. worker processes, per node.
        args: parsed command-line namespace. It is modified in place:
            ``gpu``, ``rank``, ``batch_size``, ``workers`` and, when resuming,
            ``start_epoch`` are overwritten.

    Side effects:
        Updates the module-level ``best_acc1`` and writes checkpoint files via
        ``save_checkpoint``.
    """
    global best_acc1
    args.gpu = gpu
    print("[npu id:", args.gpu, "]", "+++++++++++++++++++++++++++ before set KERNEL_NAME_ID:",
          os.environ['KERNEL_NAME_ID'])
    os.environ['KERNEL_NAME_ID'] = str(gpu)
    print("[npu id:", args.gpu, "]", "+++++++++++++++++++++++++++KERNEL_NAME_ID:", os.environ['KERNEL_NAME_ID'])

    if args.gpu is not None:
        print("[npu id:", args.gpu, "]", "Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu

        if args.device == 'npu':
            # No init_method here: the default env:// rendezvous picks up the
            # MASTER_ADDR / MASTER_PORT variables exported by main().
            dist.init_process_group(backend=args.dist_backend,
                                    world_size=args.world_size, rank=args.rank)
        else:
            dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                    world_size=args.world_size, rank=args.rank)

    loc = 'npu:{}'.format(args.gpu)
    torch.npu.set_device(loc)

    # --batch-size and --workers are per-node totals; split them per process.
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

    print("[npu id:", args.gpu, "]", "===============main_worker()=================")
    print("[npu id:", args.gpu, "]", args)
    print("[npu id:", args.gpu, "]", "===============main_worker()=================")

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    # NOTE(review): drop_last=True discards the final partial validation batch,
    # so the reported accuracy may not cover every sample - confirm this is intended.
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    # create model
    print("[npu id:", args.gpu, "]", "=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch]()
    model = model.to(loc)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(loc)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale)
    # The model is wrapped unconditionally, exactly as in the original code.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if args.amp:
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        # validate() requires ngpus_per_node; omitting it raised a TypeError.
        validate(val_loader, model, criterion, args, ngpus_per_node)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node)

        is_last_epoch = epoch == args.epochs - 1
        if (epoch + 1) % (args.eval_freq) != 0 and not is_last_epoch:
            continue

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args, ngpus_per_node)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Only the first process of each node writes the checkpoint. The
        # previous condition ("A and B or C") let every rank write the same
        # file concurrently on the final epoch.
        if not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0:
            state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            if args.amp:
                # Needed by amp.load_state_dict() when resuming.
                state['amp'] = amp.state_dict()
            save_checkpoint(state, is_best)
|
||||
|
||||
|
||||
def train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node):
    """Run one training epoch and report throughput.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches.
        model: network to train; switched to train mode here.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer stepped according to ``args.benchmark``.
        epoch: epoch index, used only for the progress prefix.
        args: parsed command-line namespace (reads ``gpu``, ``benchmark``,
            ``amp``, ``batch_size``, ``print_freq``, ``rank``,
            ``multiprocessing_distributed``).
        ngpus_per_node: devices per node; used to pick the reporting process
            and to scale the FPS figure.

    With ``args.benchmark == 0`` the optimizer steps on every batch. With
    ``args.benchmark == 1`` gradients are accumulated and averaged over
    ``OPTIMIZER_BATCH_SIZE // args.batch_size`` batches before each step.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e', start_count_index=0)
    top1 = AverageMeter('Acc@1', ':6.2f', start_count_index=0)
    top5 = AverageMeter('Acc@5', ':6.2f', start_count_index=0)
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Loop invariants.
    loc = 'npu:{}'.format(args.gpu)
    is_reporter = (not args.multiprocessing_distributed
                   or args.rank % ngpus_per_node == 0)
    accumulate = args.benchmark == 1
    # At least 1, so a per-device batch larger than OPTIMIZER_BATCH_SIZE does
    # not end in a modulo-by-zero below.
    accumulation_steps = max(1, OPTIMIZER_BATCH_SIZE // args.batch_size) if accumulate else 1

    # switch to train mode
    model.train()
    end = time.time()
    if accumulate:
        optimizer.zero_grad()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.to(torch.int32)
        images, target = images.to(loc, non_blocking=False), target.to(loc, non_blocking=False)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        if args.benchmark == 0:
            optimizer.zero_grad()

        if args.amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if args.benchmark == 0:
            optimizer.step()
        elif accumulate and (i + 1) % accumulation_steps == 0:
            # Average the accumulated gradients before stepping.
            for param_group in optimizer.param_groups:
                for param in param_group['params']:
                    # Parameters that received no gradient have grad None.
                    if param.grad is not None:
                        param.grad /= accumulation_steps
            optimizer.step()
            optimizer.zero_grad()

        if i % args.print_freq == 0 and is_reporter:
            progress.display(i)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    if is_reporter:
        # batch_time ignores its first start_count_index updates, so the
        # average is still 0 after a very short epoch; report 0 FPS instead of
        # dividing by zero.
        if batch_time.avg > 0:
            fps = ngpus_per_node * args.batch_size / batch_time.avg
        else:
            fps = 0.0
        print("[npu id:", args.gpu, "]", '* FPS@all {:.3f}'.format(fps))
        hwlog.remark_print(key=hwlog.FPS, value=' * FPS@all {:.3f}'.format(fps))
|
||||
|
||||
|
||||
def validate(val_loader, model, criterion, args, ngpus_per_node):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 accuracy.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss function called as ``criterion(output, target)``.
        args: parsed command-line namespace (reads ``gpu``, ``print_freq``,
            ``rank``, ``multiprocessing_distributed``).
        ngpus_per_node: devices per node; selects which process prints/logs.

    Returns:
        ``top1.avg``, the running average of the per-batch top-1 accuracy.
    """

    def _reports():
        # True for single-process runs and for the first process of a node.
        if not args.multiprocessing_distributed:
            return True
        return args.rank % ngpus_per_node == 0

    batch_time = AverageMeter('Time', ':6.3f', start_count_index=0)
    losses = AverageMeter('Loss', ':.4e', start_count_index=0)
    top1 = AverageMeter('Acc@1', ':6.2f', start_count_index=0)
    top5 = AverageMeter('Acc@5', ':6.2f', start_count_index=0)
    progress = ProgressMeter(len(val_loader),
                             [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            device = 'npu:{}'.format(args.gpu)
            target = target.to(torch.int32)
            images = images.to(device, non_blocking=False)
            target = target.to(device, non_blocking=False)

            # forward pass and loss
            output = model(images)
            loss = criterion(output, target)

            # accuracy / loss bookkeeping, weighted by the batch size
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            n_samples = images.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            # elapsed time for this batch
            now = time.time()
            batch_time.update(now - tick)
            tick = time.time()

            if step % args.print_freq == 0 and _reports():
                progress.display(step)

        if _reports():
            print("[npu id:", args.gpu, "]", '[AVG-ACC] * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                  .format(top1=top1, top5=top5))
            hwlog.remark_print(key=hwlog.EVAL_ACCURACY_TOP1, value="{top1.avg:.3f}".format(top1=top1))
            hwlog.remark_print(key=hwlog.EVAL_ACCURACY_TOP5, value="{top5.avg:.3f}".format(top5=top5))

    return top1.avg
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` and keep a copy when it is the best so far.

    Args:
        state: checkpoint dictionary; ``state['best_acc1']`` and
            ``state['epoch']`` are read to name the "best" copy.
        is_best: when true, the saved file is also copied to
            ``model_best_acc<acc>_epoch<epoch>.pth.tar``.
        filename: destination of the regular checkpoint.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_name = 'model_best_acc%.4f_epoch%d.pth.tar' % (state['best_acc1'], state['epoch'])
    shutil.copyfile(filename, best_name)
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Track the latest value and a running average of a metric.

    Updates made while ``count`` has not yet exceeded
    ``start_count_index * n`` are excluded from the average; they only set
    ``val`` and advance ``count``.
    """

    def __init__(self, name, fmt=':f', start_count_index=10):
        self.name = name
        self.fmt = fmt
        self.start_count_index = start_count_index
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples."""
        self.val = val
        self.count += n
        skipped = self.start_count_index * n
        if self.count <= skipped:
            # Still inside the warm-up window: nothing is accumulated.
            return
        self.sum += val * n
        self.avg = self.sum / (self.count - skipped)

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))
|
||||
|
||||
|
||||
class ProgressMeter(object):
    """Print a one-line progress report built from a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter.

        The line is tagged with the ``KERNEL_NAME_ID`` environment variable,
        which must be set.
        """
        header = self.prefix + self.batch_fmtstr.format(batch)
        fields = [header]
        for meter in self.meters:
            fields.append(str(meter))
        print("[npu id:", os.environ['KERNEL_NAME_ID'], "]", '\t'.join(fields))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``[{:Nd}/total]`` template, N being the width of the total."""
        width = len(str(num_batches // 1))
        slot = '{:%sd}' % str(width)
        return '[{}/{}]'.format(slot, slot.format(num_batches))
|
||||
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group for ``epoch``.

    During the first ``args.warm_up_epochs`` epochs the rate ramps linearly as
    ``args.lr * (epoch + 1) / (args.warm_up_epochs + 1)``. Afterwards it
    follows a half-cosine from ``args.lr`` down towards 0 over the remaining
    ``args.epochs - args.warm_up_epochs`` epochs.
    """
    warmup = args.warm_up_epochs
    in_warmup = warmup > 0 and epoch < warmup

    if in_warmup:
        lr = args.lr * ((epoch + 1) / (warmup + 1))
    else:
        # Fraction of args.lr the schedule never drops below.
        floor = 0
        angle = np.pi * (epoch - warmup) / (args.epochs - warmup)
        cosine_decay = 0.5 * (1 + np.cos(angle))
        lr = args.lr * ((1 - floor) * cosine_decay + floor)

    print("=> Epoch[%d] Setting lr: %.4f" % (epoch, lr))
    for group in optimizer.param_groups:
        group['lr'] = lr
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor; the top-k indices are taken along dimension 1.
        target: tensor of class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``, each
        holding ``100 * correct / batch_size``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # pred becomes (maxk, batch): row r holds every sample's r-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: the slice of the transposed tensor is not
            # contiguous, and view(-1) raises a RuntimeError on it.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # The benchmark logger is rooted at the directory containing this script.
    hwlog.ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
    cpu_info, npu_info, framework_info, os_info, benchmark_version = get_environment_info("pytorch")
    config_info = get_model_parameter("pytorch_config")
    # Static run description written to the benchmark log.
    initial_data = {"base_lr": 4, "dataset": "imagenet", "optimizer": "SGD", "loss_scale": 64}

    # Emit the records in a fixed order before training starts.
    remarks = (
        (hwlog.CPU_INFO, cpu_info),
        (hwlog.NPU_INFO, npu_info),
        (hwlog.OS_INFO, os_info),
        (hwlog.FRAMEWORK_INFO, framework_info),
        (hwlog.BENCHMARK_VERSION, benchmark_version),
        (hwlog.CONFIG_INFO, config_info),
        (hwlog.BASE_LR, initial_data.get("base_lr")),
        (hwlog.DATASET, initial_data.get("dataset")),
        (hwlog.OPT_NAME, initial_data.get("optimizer")),
        (hwlog.LOSS_SCALE, initial_data.get("loss_scale")),
    )
    for remark_key, remark_value in remarks:
        hwlog.remark_print(key=remark_key, value=remark_value)

    main()
|
||||
|
||||
+605
@@ -0,0 +1,605 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
import torch.optim
|
||||
import torch.multiprocessing as mp
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
import models as models
|
||||
import numpy as np
|
||||
|
||||
from apex import amp
|
||||
from multi_epochs_dataloader import MultiEpochsDataLoader
|
||||
from benchmark_log import hwlog
|
||||
from benchmark_log.basic_utils import get_environment_info
|
||||
from benchmark_log.basic_utils import get_model_parameter
|
||||
|
||||
# Default value of --batch-size, and the gradient-accumulation target batch
# size for benchmark mode.
# NOTE(review): OPTIMIZER_BATCH_SIZE is not used in the visible part of this
# file - confirm against the training loop.
BATCH_SIZE = 512
OPTIMIZER_BATCH_SIZE = 2048

# Every lowercase, non-dunder callable exported by the local ``models`` package
# is offered as a selectable architecture.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

# Help texts use argparse's ``%(default)s`` so the documented default can never
# drift away from the real one again.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', metavar='DIR', default='/opt/npu/dataset/imagenet',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: %(default)s)')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=BATCH_SIZE, type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('--workspace', type=str, default='./', metavar='DIR',
                    help='path to directory where checkpoints will be stored')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('-ef', '--eval-freq', default=5, type=int,
                    metavar='N', help='evaluate frequency (default: 5)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('-bm', '--benchmark', default=0, type=int,
                    metavar='N',
                    help='set benchmark status (default: %(default)s; '
                         '1 runs the benchmark)')
parser.add_argument('--device', default='npu', type=str, help='npu or gpu')
parser.add_argument('--addr', default='10.136.181.115', type=str, help='master addr')
parser.add_argument('--checkpoint-nameprefix', default='checkpoint', type=str, help='checkpoint-nameprefix')
parser.add_argument('--checkpoint-freq', default=0, type=int,
                    metavar='N', help='checkpoint frequency (default: 0); '
                                      '0: save only one file per epoch; '
                                      'n: save a different file every n epochs; '
                                      '-1: no checkpoint (not supported)')
parser.add_argument('--device_num', default=-1, type=int,
                    help='device_num')
parser.add_argument('--device-list', default='', type=str, help='device id list')
parser.add_argument('--warm_up_epochs', default=0, type=int,
                    help='warm up')

# apex
parser.add_argument('--amp', default=False, action='store_true',
                    help='use amp to train the model')
parser.add_argument('--loss-scale', default=64., type=float,
                    help='loss scale used by amp (default: %(default)s)')
parser.add_argument('--opt-level', default='O2', type=str,
                    help='amp optimization level (default: %(default)s)')

# Silence every warning for the rest of the run.
warnings.filterwarnings('ignore')
# Best top-1 accuracy seen so far; declared global by main_worker.
best_acc1 = 0
|
||||
|
||||
|
||||
def main():
    """Parse the command line, prepare the process environment and launch training.

    The number of worker processes per node is taken, in order of precedence,
    from the length of ``--device-list``, from ``--device_num``, or from the
    device count reported by the selected runtime.
    """
    args = parser.parse_args()
    print("===============main()=================")
    print(args)
    print("===============main()=================")

    # Placeholder value; every worker overwrites it with its own device id.
    os.environ['KERNEL_NAME_ID'] = str(0)
    print("+++++++++++++++++++++++++++KERNEL_NAME_ID:", os.environ['KERNEL_NAME_ID'])

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Rendezvous endpoint read by the default env:// initialisation.
    os.environ['MASTER_ADDR'] = args.addr
    os.environ['MASTER_PORT'] = '29688'

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    if args.device_list != '':
        # One worker per comma-separated device id.
        ngpus_per_node = len(args.device_list.split(','))
    elif args.device_num != -1:
        ngpus_per_node = args.device_num
    elif args.device == 'npu':
        ngpus_per_node = torch.npu.device_count()
    else:
        ngpus_per_node = torch.cuda.device_count()

    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Children inherit this process's environment; each worker sets its own
        # KERNEL_NAME_ID once it starts. The spawn call does not depend on
        # args.device, so the former npu / non-npu branches (which were
        # identical) are merged into one.
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
|
||||
|
||||
|
||||
def main_worker(gpu, ngpus_per_node, args):
    """Run the complete train/evaluate loop for one device process.

    Args:
        gpu: Process index passed by ``mp.spawn`` (or ``args.gpu`` when called
            directly). It indexes ``args.device_list`` when that list is
            non-empty; otherwise it is used as the device id itself.
        ngpus_per_node: Number of devices per node. Used to split the batch
            size and worker count and to derive the global rank.
        args: Parsed command-line namespace. Mutated in place (``gpu``,
            ``rank``, ``batch_size``, ``workers``, ``start_epoch``).
    """
    global best_acc1

    if args.device_list != '':
        args.gpu = int(args.device_list.split(',')[gpu])
    else:
        args.gpu = gpu

    # KERNEL_NAME_ID may not be exported in this process yet; read it with
    # .get() so this purely informational print cannot raise KeyError.
    print("[npu id:", args.gpu, "]", "++++++++++++++++ before set KERNEL_NAME_ID:",
          os.environ.get('KERNEL_NAME_ID'))
    os.environ['KERNEL_NAME_ID'] = str(args.gpu)
    print("[npu id:", args.gpu, "]", "++++++++++++++++ KERNEL_NAME_ID:", os.environ['KERNEL_NAME_ID'])

    if args.gpu is not None:
        print("[npu id:", args.gpu, "]", "Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu

        if args.device == 'npu':
            # The NPU path does not pass init_method (rendezvous settings
            # presumably come from MASTER_ADDR/MASTER_PORT -- TODO confirm).
            dist.init_process_group(backend=args.dist_backend,
                                    world_size=args.world_size, rank=args.rank)
        else:
            dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                    world_size=args.world_size, rank=args.rank)

    loc = 'npu:{}'.format(args.gpu)
    torch.npu.set_device(loc)

    # The command-line batch size and worker count are per node; split them
    # across the processes of this node.
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

    print("[npu id:", args.gpu, "]", "===============main_worker()=================")
    print("[npu id:", args.gpu, "]", args)
    print("[npu id:", args.gpu, "]", "===============main_worker()=================")

    train_loader, train_loader_len, train_sampler = get_pytorch_train_loader(args.data,
                                                                             args.batch_size,
                                                                             workers=args.workers,
                                                                             distributed=args.distributed)

    val_loader = get_pytorch_val_loader(args.data, args.batch_size, args.workers, distributed=False)

    # create model
    print("[npu id:", args.gpu, "]", "=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch]()
    model = model.to(loc)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(loc)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale)

    # NOTE(review): the model is wrapped unconditionally, so a process group
    # must exist even for single-device runs -- confirm this is intended.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if args.amp:
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        validate(val_loader, model, criterion, args, ngpus_per_node)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, train_loader_len, model, criterion, optimizer, epoch, args, ngpus_per_node)

        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1 or epoch > int(args.epochs * 0.9):
            # evaluate on validation set
            acc1 = validate(val_loader, model, criterion, args, ngpus_per_node)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            # Same truth table as the original expression, with the
            # and/or precedence spelled out.
            # NOTE(review): on the last epoch EVERY rank saves to the same
            # file name -- confirm that this is intended.
            should_save = (not args.multiprocessing_distributed
                           or args.rank % ngpus_per_node == 0
                           or epoch == args.epochs - 1)
            if should_save:
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }
                if args.amp:
                    # Store the amp state so a resumed run can reload it.
                    state['amp'] = amp.state_dict()
                save_checkpoint(state, is_best)
|
||||
|
||||
|
||||
def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, args, ngpus_per_node):
    """Train ``model`` for one epoch and report throughput.

    With ``args.benchmark == 0`` the optimizer steps after every batch. With
    ``args.benchmark == 1`` gradients are accumulated and the optimizer steps
    once every ``OPTIMIZER_BATCH_SIZE / args.batch_size`` batches, after the
    accumulated gradients have been divided by that factor.

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches.
        train_loader_len: Number of batches per epoch (used for display).
        model: Network to optimise; switched to train mode here.
        criterion: Loss callable taking ``(output, target)``.
        optimizer: Optimizer whose ``param_groups`` are updated.
        epoch: Current epoch index, used in the progress prefix.
        args: Parsed options (``gpu``, ``benchmark``, ``amp``, ``batch_size``,
            ``print_freq``, ``multiprocessing_distributed``, ``rank``).
        ngpus_per_node: Devices per node; scales the reported FPS.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e', start_count_index=0)
    top1 = AverageMeter('Acc@1', ':6.2f', start_count_index=0)
    top5 = AverageMeter('Acc@5', ':6.2f', start_count_index=0)
    progress = ProgressMeter(
        train_loader_len,
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Only one process per node prints progress and the throughput summary.
    is_reporter = (not args.multiprocessing_distributed
                   or args.rank % ngpus_per_node == 0)

    loc = 'npu:{}'.format(args.gpu)

    # Normalisation runs on the device; the constants are scaled by 255,
    # i.e. expressed on a 0-255 pixel scale.
    mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).view(1, 3, 1, 1)
    std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).view(1, 3, 1, 1)
    mean = mean.to(loc, non_blocking=True)
    std = std.to(loc, non_blocking=True)

    # switch to train mode
    model.train()
    end = time.time()
    if args.benchmark == 1:
        optimizer.zero_grad()
        # Loop-invariant accumulation factor. Clamped to 1 so that a
        # per-device batch larger than OPTIMIZER_BATCH_SIZE cannot produce a
        # zero modulus / zero divisor below.
        accumulation_steps = max(1, int(OPTIMIZER_BATCH_SIZE / args.batch_size))

    steps_per_epoch = train_loader_len
    print('==========step per epoch======================', steps_per_epoch)
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.to(torch.int32)
        images = images.to(loc, non_blocking=True).to(torch.float).sub(mean).div(std)
        target = target.to(loc, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        if args.benchmark == 0:
            optimizer.zero_grad()

        if args.amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        stream = torch.npu.current_stream()
        stream.synchronize()

        if args.benchmark == 0:
            optimizer.step()
        elif args.benchmark == 1:
            if ((i + 1) % accumulation_steps) == 0:
                # Average the accumulated gradients before stepping.
                for param_group in optimizer.param_groups:
                    for param in param_group['params']:
                        # Parameters that received no gradient have grad None.
                        if param.grad is not None:
                            param.grad /= accumulation_steps
                optimizer.step()
                optimizer.zero_grad()
        stream = torch.npu.current_stream()
        stream.synchronize()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and is_reporter:
            progress.display(i)

    if is_reporter:
        # batch_time skips its first updates, so its average is still 0 after
        # a very short epoch; report 0 FPS instead of dividing by zero.
        if batch_time.avg > 0:
            fps = ngpus_per_node * args.batch_size / batch_time.avg
        else:
            fps = 0.0
        print("[npu id:", args.gpu, "]", '* FPS@all {:.3f}, TIME@all {:.3f}'.format(fps, batch_time.avg))
        hwlog.remark_print(key=hwlog.FPS, value=' * FPS@all {:.3f}'.format(fps))
|
||||
|
||||
|
||||
def validate(val_loader, model, criterion, args, ngpus_per_node):
    """Evaluate ``model`` on ``val_loader`` and return the mean top-1 accuracy.

    Args:
        val_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss callable taking ``(output, target)``.
        args: Parsed options (``gpu``, ``print_freq``,
            ``multiprocessing_distributed``, ``rank``).
        ngpus_per_node: Devices per node; selects the reporting process.

    Returns:
        ``top1.avg`` -- the running average of the top-1 accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e', start_count_index=0)
    top1 = AverageMeter('Acc@1', ':6.2f', start_count_index=0)
    top5 = AverageMeter('Acc@5', ':6.2f', start_count_index=0)
    meters = [batch_time, losses, top1, top5]
    progress = ProgressMeter(len(val_loader), meters, prefix='Test: ')

    # One process per node does the printing/logging.
    reports = (not args.multiprocessing_distributed
               or args.rank % ngpus_per_node == 0)

    # switch to evaluate mode
    model.eval()
    device = 'npu:{}'.format(args.gpu)

    with torch.no_grad():
        # Per-channel constants on a 0-255 scale, broadcastable over
        # a (batch, 3, H, W) input.
        channel_mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).view(1, 3, 1, 1)
        channel_std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).view(1, 3, 1, 1)
        channel_mean = channel_mean.to(device, non_blocking=True)
        channel_std = channel_std.to(device, non_blocking=True)

        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            target = target.to(torch.int32)
            images = images.to(device, non_blocking=True).to(torch.float)
            images = images.sub(channel_mean).div(channel_std)
            target = target.to(device, non_blocking=True)

            # forward pass and loss
            output = model(images)
            loss = criterion(output, target)

            # accumulate accuracy and loss statistics
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            batch = images.size(0)
            losses.update(loss.item(), batch)
            top1.update(acc1[0], batch)
            top5.update(acc5[0], batch)

            # time spent on this batch
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0 and reports:
                progress.display(step)

        # TODO: this should also be done with the ProgressMeter
        if reports:
            print("[npu id:", args.gpu, "]", '[AVG-ACC] * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                  .format(top1=top1, top5=top5))
            hwlog.remark_print(key=hwlog.EVAL_ACCURACY_TOP1, value="{top1.avg:.3f}".format(top1=top1))
            hwlog.remark_print(key=hwlog.EVAL_ACCURACY_TOP5, value="{top5.avg:.3f}".format(top5=top5))
    return top1.avg
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename``; keep a named copy when it is the best so far.

    Args:
        state: Checkpoint dict; must contain ``'best_acc1'`` and ``'epoch'``
            when ``is_best`` is true (they form the copy's file name).
        is_best: Whether this checkpoint has the best accuracy so far.
        filename: Destination path of the regular checkpoint.
    """
    torch.save(state, filename)
    if not is_best:
        return
    # The "best" copy is written to the current working directory.
    best_name = 'model_best_acc%.4f_epoch%d.pth.tar' % (state['best_acc1'], state['epoch'])
    shutil.copyfile(filename, best_name)
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Tracks the latest value and a running average, skipping initial updates.

    The first ``start_count_index * N`` counted samples (``N`` being the ``n``
    of the first update) are excluded from the average.
    """

    def __init__(self, name, fmt=':f', start_count_index=10):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.start_count_index = start_count_index

    def reset(self):
        """Zero the latest value, the average, the sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples."""
        if self.count == 0:
            # Remember the size of the first update as the reference batch size.
            self.N = n

        self.val = val
        self.count += n
        skipped = self.start_count_index * self.N
        if self.count <= skipped:
            # Still inside the skipped prefix: keep sum/avg untouched.
            return
        self.sum += val * n
        self.avg = self.sum / (self.count - skipped)

    def __str__(self):
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))
|
||||
|
||||
|
||||
class ProgressMeter(object):
    """Prints one progress line made of a batch counter and a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and every meter, tab-separated.

        The line is tagged with the KERNEL_NAME_ID environment variable, which
        must be set.
        """
        header = self.prefix + self.batch_fmtstr.format(batch)
        columns = [header]
        columns.extend(str(meter) for meter in self.meters)
        print("[npu id:", os.environ['KERNEL_NAME_ID'], "]", '\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Build a ``[{:Nd}/total]`` template as wide as ``num_batches``."""
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[{}/{}]'.format(slot, slot.format(num_batches))
|
||||
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch, args):
    """Set the LR of every param group: linear warm-up, then cosine decay.

    During the first ``args.warm_up_epochs`` epochs the rate is
    ``args.lr * (epoch + 1) / (args.warm_up_epochs + 1)``; afterwards it
    follows half a cosine period over the remaining epochs.
    """
    warmup = args.warm_up_epochs

    if warmup > 0 and epoch < warmup:
        lr = args.lr * ((epoch + 1) / (warmup + 1))
    else:
        floor = 0  # fraction of args.lr the cosine decays towards
        angle = np.pi * (epoch - warmup) / (args.epochs - warmup)
        cosine_decay = 0.5 * (1 + np.cos(angle))
        lr = args.lr * ((1 - floor) * cosine_decay + floor)

    print("=> Epoch[%d] Setting lr: %.4f" % (epoch, lr))
    for group in optimizer.param_groups:
        group['lr'] = lr
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each requested ``k``.

    Args:
        output: 2-D score tensor; ``topk`` is taken along dimension 1.
        target: 1-D tensor of class indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` derives from a transposed tensor and may be
            # non-contiguous, in which case ``view(-1)`` raises; ``reshape``
            # copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
|
||||
|
||||
|
||||
def fast_collate(batch):
    """Collate ``(image, label)`` samples into a uint8 image tensor and int64 labels.

    The output image tensor is ``(len(batch), 3, h, w)``, with ``w`` and ``h``
    read from the first image's ``size``. Each image is converted to a uint8
    array, given a trailing channel axis if it has fewer than three
    dimensions, and laid out channel-first.

    Returns:
        ``(tensor, targets)``.
    """
    pictures = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    first = pictures[0]
    width, height = first.size[0], first.size[1]
    tensor = torch.zeros((len(pictures), 3, height, width), dtype=torch.uint8)

    for idx, picture in enumerate(pictures):
        pixels = np.asarray(picture, dtype=np.uint8)
        if pixels.ndim < 3:
            pixels = pixels[..., np.newaxis]
        # Move the channel axis to the front.
        pixels = np.moveaxis(pixels, 2, 0)
        # Adding into the zero-initialised slot lets a single-channel image
        # broadcast over the three output channels.
        tensor[idx] += torch.from_numpy(pixels)

    return tensor, targets
|
||||
|
||||
|
||||
def get_pytorch_train_loader(data_path, batch_size, workers=5, _worker_init_fn=None, distributed=False):
    """Build the training data loader over ``<data_path>/train``.

    Images are randomly resized/cropped to 224, colour-jittered and randomly
    flipped; no tensor conversion is applied here, batches are assembled by
    ``fast_collate``.

    Returns:
        ``(loader, number_of_batches, sampler)`` where ``sampler`` is a
        ``DistributedSampler`` when ``distributed`` is true, else ``None``.
    """
    augmentations = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomHorizontalFlip(),
    ])
    train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), augmentations)

    train_sampler = None
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

    # Shuffle in the loader only when no sampler is supplied.
    train_loader = MultiEpochsDataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        num_workers=workers,
        worker_init_fn=_worker_init_fn,
        pin_memory=True,
        sampler=train_sampler,
        collate_fn=fast_collate,
        drop_last=True)
    return train_loader, len(train_loader), train_sampler
|
||||
|
||||
|
||||
def get_pytorch_val_loader(data_path, batch_size, workers=5, _worker_init_fn=None, distributed=False):
    """Build the validation data loader over ``<data_path>/val``.

    Images are resized to 256 and centre-cropped to 224; batches are assembled
    by ``fast_collate``. A ``DistributedSampler`` is used when ``distributed``
    is true.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ])
    val_dataset = datasets.ImageFolder(os.path.join(data_path, 'val'), preprocessing)

    val_sampler = None
    if distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    # NOTE(review): without a sampler the validation set is shuffled too.
    val_loader = MultiEpochsDataLoader(
        val_dataset,
        sampler=val_sampler,
        batch_size=batch_size,
        shuffle=(val_sampler is None),
        num_workers=workers,
        worker_init_fn=_worker_init_fn,
        pin_memory=True,
        collate_fn=fast_collate)

    return val_loader
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Anchor the benchmark log root at this script's directory.
    hwlog.ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
    cpu_info, npu_info, framework_info, os_info, benchmark_version = get_environment_info("pytorch")
    config_info = get_model_parameter("pytorch_config")
    initinal_data = {"dataset": "imagenet", "optimizer": "SGD", "loss_scale": 64}

    # Emit the environment/config header, then the static run description.
    # NOTE(review): "base_lr" is not a key of initinal_data, so BASE_LR is
    # logged as None.
    header_records = (
        (hwlog.CPU_INFO, cpu_info),
        (hwlog.NPU_INFO, npu_info),
        (hwlog.OS_INFO, os_info),
        (hwlog.FRAMEWORK_INFO, framework_info),
        (hwlog.BENCHMARK_VERSION, benchmark_version),
        (hwlog.CONFIG_INFO, config_info),
        (hwlog.BASE_LR, initinal_data.get("base_lr")),
        (hwlog.DATASET, initinal_data.get("dataset")),
        (hwlog.OPT_NAME, initinal_data.get("optimizer")),
        (hwlog.LOSS_SCALE, initinal_data.get("loss_scale")),
    )
    for record_key, record_value in header_records:
        hwlog.remark_print(key=record_key, value=record_value)
    main()
|
||||
@@ -0,0 +1,60 @@
|
||||
# ImageNet training in PyTorch
|
||||
|
||||
This implements training of ShuffleNetV2 on the ImageNet dataset, mainly modified from [pytorch/examples](https://github.com/pytorch/examples/tree/master/imagenet).
|
||||
|
||||
## ShuffleNet Detail
|
||||
At the time of writing, Ascend-PyTorch is still inefficient for contiguous operations.
|
||||
Therefore, ShuffleNetV2 is re-implemented using techniques such as custom operators. For details, see `models/shufflenetv2_wock_op_woct.py`.
|
||||
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install PyTorch ([pytorch.org](http://pytorch.org))
|
||||
- `pip install -r requirements.txt`
|
||||
- Download the ImageNet dataset from http://www.image-net.org/
|
||||
- Then move the validation images into labeled subfolders, using [the following shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh)
|
||||
|
||||
## Training 1p
|
||||
|
||||
To train a model, run `main.py` with the desired model architecture and the path to the ImageNet dataset:
|
||||
|
||||
```bash
|
||||
# FP32 training
|
||||
bash 1p.sh
|
||||
|
||||
# O2 training
|
||||
bash 1p_amp.sh
|
||||
|
||||
# FP32 training for 20 epoch experiment
|
||||
bash 1p_20e.sh
|
||||
|
||||
# O2 training for 20 epoch experiment
|
||||
bash 1p_20e_amp.sh
|
||||
|
||||
```
|
||||
|
||||
## Training multi-cards
|
||||
```bash
|
||||
# O2 training 2p
|
||||
# Only Support device-list setting in [[0,1], [2,3], [4,5], [6,7]]
|
||||
bash 2p_amp_med.sh
|
||||
|
||||
# O2 training 4p
|
||||
# Only Support device-list setting in [[0,1,2,3], [4,5,6,7]]
|
||||
bash 4p_amp_med.sh
|
||||
|
||||
# O2 training 8p
|
||||
bash 8p_amp_med.sh
|
||||
|
||||
```
|
||||
|
||||
## ShufflenetV2 training result
|
||||
|
||||
| Acc@1 | FPS | Npu_nums| Epochs | Type |
|
||||
| :------: | :------: | :------ | :------: | :------: |
|
||||
| 61.5 | 1200 | 1 | 20 | O2 |
|
||||
| 68.5 | 2200 | 1 | 240 | O2 |
|
||||
| 66.3 | 14000 | 1 | 240 | O2 |
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,649 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
import torch.optim
|
||||
import torch.multiprocessing as mp
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
import models as models
|
||||
|
||||
# Apex
|
||||
import numpy as np
|
||||
from apex import amp
|
||||
|
||||
from benchmark_log import hwlog
|
||||
from benchmark_log.basic_utils import get_environment_info
|
||||
from benchmark_log.basic_utils import get_model_parameter
|
||||
|
||||
# from megvii repo
|
||||
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    Each one-hot target is replaced by
    ``(1 - epsilon) * one_hot + epsilon / num_classes``.

    Args:
        num_classes: Number of classes (the divisor of the smoothing mass).
        epsilon: Smoothing factor.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon

    def forward(self, inputs, targets):
        """Return the smoothed loss as a scalar CPU tensor.

        Args:
            inputs: Logits; softmax is taken over the last dimension.
            targets: Class indices, one per row of ``inputs``; any integer
                dtype and any device.
        """
        # The loss is evaluated on the CPU (presumably to avoid the scatter
        # on the accelerator -- TODO confirm).
        logprobs = torch.nn.functional.log_softmax(inputs, dim=-1).to("cpu")
        # scatter_ requires an int64 index on the same device as the tensor it
        # writes into, so normalise the targets instead of relying on the caller.
        index = targets.to(device=logprobs.device, dtype=torch.int64).unsqueeze(1)
        smoothed = torch.zeros_like(logprobs).scatter_(1, index, 1)
        smoothed = (1 - self.epsilon) * smoothed + self.epsilon / self.num_classes
        # Mean over the batch, summed over the classes.
        loss = (-smoothed * logprobs).mean(0).sum()
        return loss
|
||||
|
||||
|
||||
|
||||
# Names of every lowercase, public callable exported by ``models``; these are
# the values accepted by --arch.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

# NPU device selection
parser.add_argument('--npu', default=None, type=int,
                    help='NPU id to use.')

# Training-schedule extensions
parser.add_argument('--eval_between_epochs', default=1, type=int,
                    help='setting bigger interval to speed up training.')
parser.add_argument('--label_smooth', default=0, type=float,
                    help='label smoothing using in CE')
parser.add_argument('--lr_scheduler_type', default='step_epoch', type=str,
                    help='lr_scheduler type, such as linear,cosine')
parser.add_argument('--warm_up_epochs', default=0, type=int,
                    help='number of warm-up epochs (default: 0)')
parser.add_argument('--total_steps', default=-1, type=float,
                    help='total number of training steps; a negative value '
                         'means len(train_loader) * epochs (default: -1)')
parser.add_argument('--save_path', default='./training/save', type=str,
                    help='save model base path')
parser.add_argument('--tb_path', default='./training/events', type=str,
                    help='save tensorboard events path')

# apex (mixed precision)
parser.add_argument('--amp', default=False, action='store_true',
                    help='use amp to train the model')
parser.add_argument('--loss_scale', default='dynamic', type=str,
                    help="loss scale used by amp (default: 'dynamic')")
parser.add_argument('--opt_level', default='O1', type=str,
                    help='opt_level using in amp, default O1.')

# Best top-1 accuracy seen so far; updated by main_worker.
best_acc1 = 0
|
||||
|
||||
|
||||
def main():
    """Parse the command line, apply seeding, and launch the worker(s).

    With ``--multiprocessing-distributed`` one ``main_worker`` process is
    spawned per CUDA device reported by ``torch.cuda.device_count()``;
    otherwise ``main_worker`` runs in this process.
    """
    args = parser.parse_args()
    print(args)

    seed = args.seed
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # With env:// rendezvous the world size may come from the environment.
    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    devices_per_node = torch.cuda.device_count()
    if not args.multiprocessing_distributed:
        # Single-process path: run the worker directly.
        main_worker(args.gpu, devices_per_node, args)
        return

    # One process per device: the world size counts processes, not nodes.
    args.world_size = devices_per_node * args.world_size
    mp.spawn(main_worker, nprocs=devices_per_node, args=(devices_per_node, args))
|
||||
|
||||
|
||||
def main_worker(gpu, ngpus_per_node, args):
    """Build the model, data loaders and optimizer, then train and evaluate.

    Args:
        gpu: CUDA device index for this process, or ``None``.
        ngpus_per_node: Devices per node; used to derive the global rank.
        args: Parsed command-line namespace. Mutated in place (``gpu``,
            ``rank``, ``start_epoch``, ``total_steps``, ``warm_up_steps``).

    A checkpoint is written after every epoch; it is flagged as best only on
    epochs where validation ran and improved on ``best_acc1``.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Select the compute device once. Previously the criterion and the
    # end-of-epoch restore always assumed an NPU and failed when --npu was
    # not given.
    if args.gpu is not None:
        device = 'cuda:{}'.format(args.gpu)
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    elif args.npu is not None:
        device = "npu:%d" % args.npu
        torch.npu.set_device(device)
        model = model.to(device)
    else:
        device = "cpu"
    # NOTE: DataParallel / DistributedDataParallel wrapping is intentionally
    # not applied in this script (it was disabled in the original code).

    # apex
    if args.amp:
        # Initialization
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale)
        print("=> Using amp mode.")

    # define loss function (criterion)
    if args.label_smooth > 0:
        # NOTE(review): the smoothing factor is fixed at 0.1 and the class
        # count at 1000; the value of --label_smooth only acts as a switch.
        criterion = CrossEntropyLabelSmooth(1000, 0.1).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None and torch.is_tensor(best_acc1):
                # best_acc1 may be from a checkpoint from a different GPU;
                # a plain number needs no device move.
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler,
        drop_last=True,
    )

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        drop_last=True,
    )

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    # Step bookkeeping for the step-based LR schedules.
    global_step = args.start_epoch * len(train_loader)
    if args.total_steps < 0:
        args.total_steps = len(train_loader) * args.epochs
    if args.warm_up_epochs > 0:
        args.warm_up_steps = len(train_loader) * args.warm_up_epochs
    else:
        args.warm_up_steps = 0

    # Each process saves into its own sub-directory when spawned per device.
    if args.multiprocessing_distributed:
        save_path = os.path.join(args.save_path, str(args.gpu))
    else:
        save_path = args.save_path

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if 'epoch' in args.lr_scheduler_type:
            if args.lr_scheduler_type in ['', 'step_epoch']:
                adjust_learning_rate(optimizer, epoch, args)
            else:
                adjust_learning_rate_epoch(optimizer, args, epoch)

        # train for one epoch
        global_step = train(train_loader, model, criterion, optimizer, epoch, args, global_step=global_step)

        if (epoch + 1) % args.eval_between_epochs == 0 or epoch > int(args.epochs * 0.9):
            # evaluate on validation set
            acc1 = validate(val_loader, model, criterion, args)

            # remember best acc@1
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            if torch.is_tensor(is_best):
                # Same as before for tensor accuracies; a plain bool is
                # passed through instead of raising AttributeError.
                is_best = is_best.to("cpu")
        else:
            # No evaluation this epoch: still checkpoint, never as "best".
            is_best = False

        # The weights are moved to the CPU for saving and restored afterwards.
        model = model.to("cpu")
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
        }, is_best, save_path=save_path)

        model = model.to(device)
|
||||
|
||||
def train(train_loader, model, criterion, optimizer, epoch, args, global_step):
    """Run one training epoch and return the updated global step counter.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches; must
            support ``len()``.
        model: network to optimise; switched to train mode here.
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer whose gradients are zeroed and stepped per batch.
        epoch: current epoch index, used only in the progress prefix.
        args: parsed options; reads ``lr_scheduler_type``, ``gpu``, ``npu``,
            ``label_smooth``, ``amp``, ``print_freq`` and ``batch_size``.
        global_step: number of optimizer steps taken before this epoch.

    Returns:
        ``global_step`` incremented once per processed batch.
    """
    # The first 10 updates of the timing meters are excluded from the average.
    batch_time = AverageMeter('Time', ':6.3f', start_count_index=10)
    data_time = AverageMeter('Data', ':6.3f', start_count_index=10)
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    print('==> enter train mode.')

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # Step-based schedulers update the learning rate before every batch;
        # epoch-based ones are handled by the caller.
        if 'epoch' not in args.lr_scheduler_type:
            adjust_learning_rate_step(optimizer, args, global_step)

        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
        if args.npu is not None:
            images = images.to("npu:%d" % args.npu, non_blocking=True)
            # NOTE(review): nesting reconstructed from a whitespace-stripped
            # source. With label smoothing the target appears to be moved to
            # the NPU only after the loss is computed (see below) - confirm.
            if not args.label_smooth > 0:
                target = target.to(torch.int32)
                target = target.to("npu:%d" % args.npu, non_blocking=True)

        # compute output
        output = model(images)
        if args.label_smooth > 0:
            loss = criterion(output, target).to("npu:%d" % args.npu, non_blocking=True)
        else:
            loss = criterion(output, target)

        # measure accuracy and record loss
        if args.label_smooth > 0:
            target = target.to(torch.int32)
            target = target.to("npu:%d" % args.npu, non_blocking=True)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        if args.amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

        global_step += 1

    # batch_time.avg stays 0 until more than ``start_count_index`` batches have
    # been timed; report 0 FPS for such short epochs instead of dividing by 0.
    if batch_time.avg > 0:
        fps = args.batch_size / batch_time.avg
    else:
        fps = 0.0
    # NOTE(review): the label says "npu id" but prints args.gpu - confirm.
    print("[npu id:", args.gpu, "]", '* FPS@all {:.3f}'.format(fps))
    hwlog.remark_print(key=hwlog.FPS, value=' * FPS@all {:.3f}'.format(fps))

    return global_step
|
||||
|
||||
def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 accuracy.

    The model is put in eval mode and the loop runs under ``torch.no_grad()``.
    Loss, top-1 and top-5 accuracy are accumulated per batch, progress is
    printed every ``args.print_freq`` batches, and the final averages are
    printed and forwarded to ``hwlog``.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate.
        criterion: loss callable invoked as ``criterion(output, target)``.
        args: parsed options; reads ``gpu``, ``npu``, ``label_smooth`` and
            ``print_freq``.

    Returns:
        ``top1.avg`` of the accumulated top-1 meter.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader),
                             [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # Evaluation mode: layers such as BatchNorm use their inference behaviour.
    model.eval()
    print('==> enter eval mode.')

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
            if args.npu is not None:
                # Targets are cast to int32 before being moved to the NPU.
                target = target.to(torch.int32)
                images = images.to("npu:%d" % args.npu, non_blocking=True)
                target = target.to("npu:%d" % args.npu, non_blocking=True)

            # Forward pass and loss; with label smoothing the loss tensor is
            # additionally moved to the NPU.
            output = model(images)
            loss = criterion(output, target)
            if args.label_smooth > 0:
                loss = loss.to("npu:%d" % args.npu, non_blocking=True)

            # Accumulate metrics weighted by the batch size.
            n_samples = images.size(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            # Per-batch wall-clock time.
            now = time.time()
            batch_time.update(now - tick)
            tick = time.time()

            if i_should_print(step, args.print_freq) if False else step % args.print_freq == 0:
                progress.display(step)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    hwlog.remark_print(key=hwlog.EVAL_ACCURACY_TOP1, value="{top1.avg:.3f}".format(top1=top1))
    hwlog.remark_print(key=hwlog.EVAL_ACCURACY_TOP5, value="{top5.avg:.3f}".format(top5=top5))
    return top1.avg
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_path='./'):
    """Serialize ``state`` under ``save_path`` and optionally keep a "best" copy.

    Args:
        state: mapping saved with ``torch.save``; must contain ``'best_acc1'``
            and ``'epoch'`` when ``is_best`` is truthy (used in the copy name).
        is_best: when truthy, the checkpoint is also copied to
            ``model_best_acc<best_acc1>_epoch<epoch>.pth.tar``.
        filename: file name of the rolling checkpoint inside ``save_path``.
        save_path: target directory; created if missing.
    """
    # exist_ok avoids the check-then-create race when several worker
    # processes save into the same directory at the same time.
    os.makedirs(save_path, exist_ok=True)

    checkpoint_path = os.path.join(save_path, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        best_name = 'model_best_acc%.4f_epoch%d.pth.tar' % (state['best_acc1'], state['epoch'])
        shutil.copyfile(checkpoint_path, os.path.join(save_path, best_name))
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Track the latest value and a running weighted average of a metric.

    ``start_count_index`` lets the average ignore the earliest updates: an
    update only contributes once ``count`` exceeds ``start_count_index * n``.
    """

    def __init__(self, name, fmt=':f', start_count_index=0):
        self.name = name
        self.fmt = fmt
        self.start_count_index = start_count_index
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the average."""
        self.val = val
        self.count += n
        skipped = self.start_count_index * n
        if self.count <= skipped:
            # Still inside the ignored leading window; the average is unchanged.
            return
        self.sum += val * n
        self.avg = self.sum / (self.count - skipped)

    def __str__(self):
        # ``fmt`` is a format spec such as ':6.3f' applied to both val and avg.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))
|
||||
|
||||
|
||||
class ProgressMeter(object):
    """Print one tab-separated progress line built from a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print ``prefix[batch/total]`` followed by ``str()`` of every meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        columns = [header]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``'[{:Nd}/total]'`` template, N being the width of total."""
        width = len(str(num_batches // 1))
        counter_fmt = '{:%dd}' % width
        return '[%s/%s]' % (counter_fmt, counter_fmt.format(num_batches))
|
||||
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch, args):
    """Apply step decay: ``args.lr`` scaled by 0.1 for every 30 full epochs.

    The resulting rate is written to every entry of
    ``optimizer.param_groups`` and returned.
    """
    completed_decays = epoch // 30
    lr = args.lr * (0.1 ** completed_decays)
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr
|
||||
|
||||
def adjust_learning_rate_step(optimizer, args, global_step):
    """Set the per-step learning rate: linear warm-up, then linear or cosine decay.

    While ``global_step < args.warm_up_steps`` the rate grows linearly from 0
    to ``args.lr``. Afterwards it decays over the remaining
    ``args.total_steps - args.warm_up_steps`` steps, following
    ``args.lr_scheduler_type`` (``'linear'`` or ``'cosine'``). The result is
    clamped at 0 and written to every entry of ``optimizer.param_groups``.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts).
        args: options; reads ``lr``, ``warm_up_steps``, ``total_steps`` and
            ``lr_scheduler_type``.
        global_step: zero-based count of optimizer steps taken so far.

    Returns:
        The learning rate that was applied.

    Raises:
        ValueError: if warm-up is over and ``args.lr_scheduler_type`` is
            neither ``'linear'`` nor ``'cosine'``.
    """
    warm_up_steps = args.warm_up_steps
    if warm_up_steps > 0 and global_step < warm_up_steps:
        lr = args.lr * (global_step / warm_up_steps)
    else:
        scheduler = args.lr_scheduler_type
        if scheduler not in ('linear', 'cosine'):
            raise ValueError(
                "unsupported step lr_scheduler_type: {!r}".format(scheduler))
        # Fraction of the decay phase completed. The decay span is
        # total_steps MINUS warm_up_steps, so the rate reaches 0 at total_steps.
        decay_progress = (global_step - warm_up_steps) / (args.total_steps - warm_up_steps)
        if scheduler == 'linear':
            lr = args.lr * (1 - decay_progress)
        else:
            alpha = 0  # floor of the decay, as a fraction of args.lr
            cosine_decay = 0.5 * (1 + np.cos(np.pi * decay_progress))
            decayed = (1 - alpha) * cosine_decay + alpha
            lr = args.lr * decayed

    # Steps beyond total_steps would otherwise yield a negative rate.
    lr = max(lr, 0)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
|
||||
|
||||
|
||||
def adjust_learning_rate_epoch(optimizer, args, global_epoch):
    """Set the per-epoch learning rate: linear warm-up, then linear or cosine decay.

    While ``global_epoch < args.warm_up_epochs`` the rate is
    ``args.lr * (global_epoch + 1) / (args.warm_up_epochs + 1)``. Afterwards
    it decays over the remaining ``args.epochs - args.warm_up_epochs`` epochs,
    following ``args.lr_scheduler_type`` (``'linear_epoch'`` or
    ``'cosine_epoch'``). The result is clamped at 0, printed, and written to
    every entry of ``optimizer.param_groups``.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts).
        args: options; reads ``lr``, ``warm_up_epochs``, ``epochs`` and
            ``lr_scheduler_type``.
        global_epoch: zero-based index of the epoch about to run.

    Returns:
        The learning rate that was applied.

    Raises:
        ValueError: if warm-up is over and ``args.lr_scheduler_type`` is
            neither ``'linear_epoch'`` nor ``'cosine_epoch'``.
    """
    warm_up_epochs = args.warm_up_epochs
    if warm_up_epochs > 0 and global_epoch < warm_up_epochs:
        lr = args.lr * ((global_epoch + 1) / (warm_up_epochs + 1))
    else:
        scheduler = args.lr_scheduler_type
        if scheduler not in ('linear_epoch', 'cosine_epoch'):
            raise ValueError(
                "unsupported epoch lr_scheduler_type: {!r}".format(scheduler))
        # Fraction of the decay phase completed. The decay span is
        # epochs MINUS warm_up_epochs, so the rate reaches 0 at args.epochs.
        decay_progress = (global_epoch - warm_up_epochs) / (args.epochs - warm_up_epochs)
        if scheduler == 'linear_epoch':
            lr = args.lr * (1 - decay_progress)
        else:
            alpha = 0  # floor of the decay, as a fraction of args.lr
            cosine_decay = 0.5 * (1 + np.cos(np.pi * decay_progress))
            decayed = (1 - alpha) * cosine_decay + alpha
            lr = args.lr * decayed

    # Epochs beyond args.epochs would otherwise yield a negative rate.
    lr = max(lr, 0)

    print("=> Epoch[%d] Setting lr: %.4f" % (global_epoch, lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: 2-D score tensor; ``topk`` is taken along dimension 1.
        target: 1-D tensor of class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage
        of rows whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # pred becomes (maxk, batch): row r holds every sample's r-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: the slice of the transposed tensor may be
            # non-contiguous, in which case view(-1) raises a RuntimeError.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Point the benchmark logger at the directory containing this script.
    hwlog.ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

    cpu_info, npu_info, framework_info, os_info, benchmark_version = get_environment_info("pytorch")
    config_info = get_model_parameter("pytorch_config")
    initinal_data = {
        "base_lr": 0.256,
        "dataset": "imagenet",
        "optimizer": "SGD",
        "loss_scale": 64,
    }

    # Emit the environment and run-configuration records in a fixed order
    # before training starts.
    for log_key, log_value in (
            (hwlog.CPU_INFO, cpu_info),
            (hwlog.NPU_INFO, npu_info),
            (hwlog.OS_INFO, os_info),
            (hwlog.FRAMEWORK_INFO, framework_info),
            (hwlog.BENCHMARK_VERSION, benchmark_version),
            (hwlog.CONFIG_INFO, config_info),
            (hwlog.BASE_LR, initinal_data.get("base_lr")),
            (hwlog.DATASET, initinal_data.get("dataset")),
            (hwlog.OPT_NAME, initinal_data.get("optimizer")),
            (hwlog.LOSS_SCALE, initinal_data.get("loss_scale")),
    ):
        hwlog.remark_print(key=log_key, value=log_value)

    main()
|
||||
+1
@@ -0,0 +1 @@
|
||||
from .shufflenetv2_wock_op_woct_8p import *
|
||||
+67
@@ -0,0 +1,67 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.jit.annotations import Dict
|
||||
|
||||
|
||||
class IntermediateLayerGetter(nn.ModuleDict):
    """Wrap a model so that a forward pass returns selected intermediate outputs.

    The direct children of ``model`` are copied, in registration order, up to
    and including the last child named in ``return_layers``; ``forward`` then
    runs them one after another. This relies on the children having been
    registered in the order they are used, and on no child being applied more
    than once in the original forward pass.

    Only direct children can be queried: ``model.feature1`` works,
    ``model.feature1.layer2`` does not.

    Arguments:
        model (nn.Module): model whose features are extracted.
        return_layers (Dict[name, new_name]): maps the name of a direct child
            to the key under which its output is returned.

    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, naming them `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
    """
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model, return_layers):
        child_names = [child_name for child_name, _ in model.named_children()]
        if not set(return_layers).issubset(child_names):
            raise ValueError("return_layers are not present in model")

        # Work on a stringified copy so the caller's mapping is kept intact.
        pending = {str(k): str(v) for k, v in return_layers.items()}
        kept = OrderedDict()
        for name, module in model.named_children():
            kept[name] = module
            pending.pop(name, None)
            if not pending:
                # Every requested layer has been collected; drop the rest.
                break

        super(IntermediateLayerGetter, self).__init__(kept)
        self.return_layers = return_layers

    def forward(self, x):
        """Run the kept children in order and collect the requested outputs."""
        collected = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                collected[self.return_layers[name]] = x
        return collected
|
||||
+208
@@ -0,0 +1,208 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .utils import load_state_dict_from_url
|
||||
|
||||
|
||||
# Public names exported by this module.
__all__ = [
    'ShuffleNetV2',
    'shufflenet_v2_x0_5',
    'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5',
    'shufflenet_v2_x2_0',
]

# Checkpoint URL per architecture key; None means no URL is registered.
model_urls = {
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}
|
||||
|
||||
|
||||
def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel dimension is viewed as ``(groups, channels // groups)``, the
    two axes are swapped, and the result is flattened back to 4-D.
    """
    n, c, h, w = x.data.size()
    per_group = c // groups

    # (n, groups, per_group, h, w) -> (n, per_group, groups, h, w)
    grouped = x.view(n, groups, per_group, h, w)
    swapped = torch.transpose(grouped, 1, 2).contiguous()

    # Collapse the two channel axes back into one.
    return swapped.view(n, -1, h, w)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """Two-branch ShuffleNetV2 unit followed by a 2-group channel shuffle.

    With ``stride == 1`` the input is split in half along channels: the first
    half passes through unchanged and the second half goes through
    ``branch2``. With ``stride > 1`` the whole input feeds both ``branch1``
    and ``branch2``. The two results are concatenated and shuffled.
    """

    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()

        if stride < 1 or stride > 3:
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # A stride-1 unit keeps the channel count, so inp must be 2 * branch_features.
        assert (self.stride != 1) or (inp == branch_features * 2)

        if self.stride > 1:
            branch1_layers = [
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            ]
        else:
            branch1_layers = []
        self.branch1 = nn.Sequential(*branch1_layers)

        # branch2 sees the full input when downsampling, half of it otherwise.
        branch2_in = inp if (self.stride > 1) else branch_features
        self.branch2 = nn.Sequential(
            nn.Conv2d(branch2_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        """Build a ``Conv2d`` with ``groups=i`` (one group per input channel)."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            untouched, processed = x.chunk(2, dim=1)
            merged = torch.cat((untouched, self.branch2(processed)), dim=1)
        else:
            merged = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        return channel_shuffle(merged, 2)
|
||||
|
||||
|
||||
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier: stem conv, max-pool, three stages, 1x1 conv, FC.

    Args:
        stages_repeats: three ints, the number of units in stages 2, 3 and 4.
        stages_out_channels: five ints, the output channels of the stem, the
            three stages and the final 1x1 convolution.
        num_classes: size of the classifier output.
        inverted_residual: callable ``(inp, oup, stride)`` building one unit.
    """

    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        # Stem: 3x3 stride-2 convolution on the 3-channel input.
        stem_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, stem_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Each stage opens with a stride-2 unit followed by stride-1 units.
        in_channels = stem_channels
        stage_specs = zip(stages_repeats, self._stage_out_channels[1:])
        for stage_index, (repeats, out_channels) in enumerate(stage_specs, 2):
            units = [inverted_residual(in_channels, out_channels, 2)]
            for _ in range(repeats - 1):
                units.append(inverted_residual(out_channels, out_channels, 1))
            setattr(self, 'stage{}'.format(stage_index), nn.Sequential(*units))
            in_channels = out_channels

        head_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels, head_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(head_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(head_channels, num_classes)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        features = self.conv1(x)
        features = self.maxpool(features)
        features = self.stage2(features)
        features = self.stage3(features)
        features = self.stage4(features)
        features = self.conv5(features)
        pooled = features.mean([2, 3])  # globalpool
        return self.fc(pooled)

    def forward(self, x):
        return self._forward_impl(x)
|
||||
|
||||
|
||||
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load the weights listed for ``arch``.

    Raises:
        NotImplementedError: if ``pretrained`` is set but ``model_urls[arch]``
            is ``None``.
    """
    model = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return model

    model_url = model_urls[arch]
    if model_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    model.load_state_dict(load_state_dict_from_url(model_url, progress=progress))
    return model
|
||||
|
||||
|
||||
def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
    """
    Build the 0.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the weights registered for this variant
        progress (bool): Forwarded to the weight download helper
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 48, 96, 192, 1024]
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
    """
    Build the 1.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the weights registered for this variant
        progress (bool): Forwarded to the weight download helper
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 116, 232, 464, 1024]
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
    """
    Build the 1.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the weights registered for this variant
        progress (bool): Forwarded to the weight download helper
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 176, 352, 704, 1024]
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
    """
    Build the 2.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the weights registered for this variant
        progress (bool): Forwarded to the weight download helper
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 244, 488, 976, 2048]
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
+256
@@ -0,0 +1,256 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
from .utils import load_state_dict_from_url
|
||||
except:
|
||||
pass
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Public names exported by this module.
__all__ = [
    'ShuffleNetV2',
    'shufflenet_v2_x0_5',
    'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5',
    'shufflenet_v2_x2_0',
]

# Checkpoint URL per architecture key; None means no URL is registered.
model_urls = {
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}
|
||||
|
||||
|
||||
class Channel_Shuffle(nn.Module):
    """Interleave channel chunks of two tensors using only ``chunk`` and ``cat``.

    Each input is cut into ``groups // 2`` chunks along dimension 1 and the
    chunks are interleaved as ``x1[0], x2[0], x1[1], x2[1], ...``. With
    ``split_shuffle`` the interleaved sequence is returned as two tensors
    (first half of the chunk pairs, second half); otherwise as one tensor.
    ``inp`` is accepted but not used.
    """

    def __init__(self, inp, groups=4, split_shuffle=True):
        super(Channel_Shuffle, self).__init__()

        self.split_shuffle = split_shuffle
        self.groups = groups

    def forward(self, x1, x2):
        chunks_per_input = self.groups // 2
        x1_list = x1.chunk(chunks_per_input, 1)
        x2_list = x2.chunk(chunks_per_input, 1)
        total = len(x1_list)

        def interleave(start, stop):
            # Pair up chunk ``idx`` of both inputs for idx in [start, stop).
            return [piece
                    for idx in range(start, stop)
                    for piece in (x1_list[idx], x2_list[idx])]

        if not self.split_shuffle:
            return torch.cat(interleave(0, total), 1)

        split_point = total // 2
        first_half = interleave(0, split_point)
        second_half = interleave(split_point, total)
        return torch.cat(first_half, 1), torch.cat(second_half, 1)
|
||||
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """Two-branch ShuffleNetV2 unit whose shuffle is done by ``Channel_Shuffle``.

    With ``stride == 1`` the input must be a pair ``(x1, x2)``: ``x1`` passes
    through unchanged and ``x2`` goes through ``branch2``. With ``stride > 1``
    a single tensor feeds both ``branch1`` and ``branch2``. The two results
    are handed to ``Channel_Shuffle``, so the return value is a pair when
    ``split_shuffle`` is true and a single tensor otherwise.
    """

    def __init__(self, inp, oup, stride, split_shuffle=True):
        super(InvertedResidual, self).__init__()

        if stride < 1 or stride > 3:
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # A stride-1 unit keeps the channel count, so inp must be 2 * branch_features.
        assert (self.stride != 1) or (inp == branch_features * 2)

        if self.stride > 1:
            branch1_layers = [
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            ]
        else:
            branch1_layers = []
        self.branch1 = nn.Sequential(*branch1_layers)

        # branch2 sees the full input when downsampling, half of it otherwise.
        branch2_in = inp if (self.stride > 1) else branch_features
        self.branch2 = nn.Sequential(
            nn.Conv2d(branch2_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

        # Total channel count entering the shuffle.
        shuffle_channels = (branch_features + branch_features) if self.stride > 1 else inp
        self.channel_shuffle = Channel_Shuffle(inp=shuffle_channels, split_shuffle=split_shuffle)

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        """Build a ``Conv2d`` with ``groups=i`` (one group per input channel)."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            # The previous unit already delivered the two halves separately.
            left, right = x
            right = self.branch2(right)
        else:
            left = self.branch1(x)
            right = self.branch2(x)

        return self.channel_shuffle(left, right)
|
||||
|
||||
|
||||
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier whose units exchange split tensor pairs.

    Inside a stage every unit except the last is built with the default
    ``split_shuffle`` (returning a pair); the last stride-1 unit is built
    with ``split_shuffle=False`` so the stage outputs a single tensor.

    Args:
        stages_repeats: three ints, the number of units in stages 2, 3 and 4.
        stages_out_channels: five ints, the output channels of the stem, the
            three stages and the final 1x1 convolution.
        num_classes: size of the classifier output.
        inverted_residual: callable ``(inp, oup, stride[, split_shuffle])``
            building one unit.
    """

    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        # Stem: 3x3 stride-2 convolution on the 3-channel input.
        stem_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, stem_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        in_channels = stem_channels
        stage_specs = zip(stages_repeats, self._stage_out_channels[1:])
        for stage_index, (repeats, out_channels) in enumerate(stage_specs, 2):
            units = [inverted_residual(in_channels, out_channels, 2)]
            last_unit = repeats - 2
            for unit_index in range(repeats - 1):
                # Only the closing unit of a stage merges back to one tensor.
                extra = {'split_shuffle': False} if unit_index == last_unit else {}
                units.append(inverted_residual(out_channels, out_channels, 1, **extra))
            setattr(self, 'stage{}'.format(stage_index), nn.Sequential(*units))
            in_channels = out_channels

        head_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels, head_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(head_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(head_channels, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        features = self.conv1(x)
        features = self.maxpool(features)
        features = self.stage2(features)
        features = self.stage3(features)
        features = self.stage4(features)
        features = self.conv5(features)
        # Global pooling via AdaptiveAvgPool2d + flatten instead of mean([2, 3]).
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)

    def forward(self, x):
        return self._forward_impl(x)
|
||||
|
||||
|
||||
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load the weights listed for ``arch``.

    Raises:
        NotImplementedError: if ``pretrained`` is set but ``model_urls[arch]``
            is ``None``.
    """
    model = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return model

    model_url = model_urls[arch]
    if model_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    model.load_state_dict(load_state_dict_from_url(model_url, progress=progress))
    return model
|
||||
|
||||
|
||||
def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
    """
    Build the 0.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the weights registered for this variant
        progress (bool): Forwarded to the weight download helper
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 48, 96, 192, 1024]
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
    """
    Build the 1.0x ShuffleNetV2 variant from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    This variant uses stage channels [16, 128, 256, 512, 1024] in place of
    [24, 116, 232, 464, 1024].
    NOTE(review): the checkpoint registered under 'shufflenetv2_x1.0'
    presumably targets the latter widths - confirm before using
    ``pretrained=True``.

    Args:
        pretrained (bool): If True, load the weights registered for this variant
        progress (bool): Forwarded to the weight download helper
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [16, 128, 256, 512, 1024]
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
    """
    Build the 1.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the weights registered for this variant
        progress (bool): Forwarded to the weight download helper
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 176, 352, 704, 1024]
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
    """
    Build the 2.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the weights registered for this variant
        progress (bool): Forwarded to the weight download helper
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 244, 488, 976, 2048]
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the x1.0 model, print it, and run one forward and
    # backward pass on a random 224x224 image.
    model = shufflenet_v2_x1_0()
    print(model)
    x = torch.randn(1, 3, 224, 224)
    y = model(x)
    loss = y.sum()
    loss.backward()
|
||||
|
||||
+330
@@ -0,0 +1,330 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
from .utils import load_state_dict_from_url
|
||||
except:
|
||||
pass
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Public names exported by this module.
__all__ = [
    'ShuffleNetV2',
    'shufflenet_v2_x0_5',
    'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5',
    'shufflenet_v2_x2_0',
]

# Checkpoint URL per architecture key; None means no URL is registered.
model_urls = {
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}
|
||||
|
||||
|
||||
class IndexSelectFullImplementation(torch.autograd.Function):
    """Concatenate two tensors on dim 1 and permute channels with a custom backward.

    Forward returns ``cat([x1, x2], 1).index_select(1, fp_index)``. Backward
    produces the gradients of ``x1`` and ``x2`` by selecting channels
    ``bp_index1`` and ``bp_index2`` from the incoming gradient; the index
    tensors receive no gradient.
    """

    @staticmethod
    def forward(ctx, x1, x2, fp_index, bp_index1, bp_index2):
        # Wait for the work queued on the current NPU stream before continuing.
        stream = torch.npu.current_stream()
        stream.synchronize()

        # Keep the backward indices on the context for use in backward().
        ctx.bp_index1 = bp_index1
        ctx.bp_index2 = bp_index2
        x = torch.cat([x1, x2], dim=1)
        result = x.index_select(1, fp_index)

        return result

    @staticmethod
    def backward(ctx, grad_output):
        # Wait for the work queued on the current NPU stream before continuing.
        stream = torch.npu.current_stream()
        stream.synchronize()

        # convert to NCHW to avoid extra 5HD --> 4D
        grad_output.data = grad_output.data.npu_format_cast(0)
        out1 = grad_output.index_select(1, ctx.bp_index1)
        out2 = grad_output.index_select(1, ctx.bp_index2)
        # One entry per forward input: x1, x2, fp_index, bp_index1, bp_index2.
        return out1, out2, None, None, None
|
||||
|
||||
|
||||
class IndexSelectHalfImplementation(torch.autograd.Function):
    """Concatenate two tensors on dim 1 and return two channel selections.

    Forward returns ``(x.index_select(1, fp_index1), x.index_select(1, fp_index2))``
    with ``x = cat([x1, x2], 1)``. Backward concatenates both incoming
    gradients and selects channels ``bp_index1`` / ``bp_index2`` as the
    gradients of ``x1`` / ``x2``; the index tensors receive no gradient.
    """

    @staticmethod
    def forward(ctx, x1, x2, fp_index1, fp_index2, bp_index1, bp_index2):
        # Keep the backward indices on the context for use in backward().
        ctx.bp_index1, ctx.bp_index2 = bp_index1, bp_index2
        stacked = torch.cat([x1, x2], dim=1)
        first_half = stacked.index_select(1, fp_index1)
        second_half = stacked.index_select(1, fp_index2)
        return first_half, second_half

    @staticmethod
    def backward(ctx, grad_output1, grad_output2):
        stacked_grad = torch.cat([grad_output1, grad_output2], 1)
        grad_x1 = stacked_grad.index_select(1, ctx.bp_index1)
        grad_x2 = stacked_grad.index_select(1, ctx.bp_index2)
        # One entry per forward input; the four index tensors get None.
        return grad_x1, grad_x2, None, None, None, None
|
||||
|
||||
|
||||
class Channel_Shuffle(nn.Module):
    """Channel shuffle of two concatenated tensors, done with precomputed
    index buffers.

    Args:
        inp: total channel count of the two inputs taken together.
        groups: number of groups the channel range is cut into before the
            groups are interleaved.
        split_shuffle: when true the shuffled result is returned as two
            tensors, otherwise as one.
    """

    def __init__(self, inp, groups=2, split_shuffle=True):
        super().__init__()

        self.split_shuffle = split_shuffle
        self.group_len = inp // groups

        # Lay the channel ids out as (groups, group_len) and read them back
        # column by column: that is the interleaved (shuffled) order.
        channel_grid = np.array(list(range(inp))).reshape(groups, self.group_len)
        self.out = channel_grid.T.flatten().tolist()

        if self.split_shuffle:
            forward_buffers = [
                ('fp_index1', self.out[:self.group_len]),
                ('fp_index2', self.out[self.group_len:]),
            ]
        else:
            forward_buffers = [('fp_index', self.out)]

        # Backward indices: even positions, then odd positions.
        backward_buffers = [
            ('bp_index1', list(range(0, inp, 2))),
            ('bp_index2', list(range(1, inp, 2))),
        ]

        for buffer_name, indices in forward_buffers + backward_buffers:
            self.register_buffer(buffer_name, torch.tensor(indices))

    def forward(self, x1, x2):
        """Shuffle the channels of ``x1`` and ``x2`` taken together."""
        if not self.split_shuffle:
            return IndexSelectFullImplementation.apply(
                x1, x2, self.fp_index, self.bp_index1, self.bp_index2)
        return IndexSelectHalfImplementation.apply(
            x1, x2, self.fp_index1, self.fp_index2, self.bp_index1, self.bp_index2)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """Two-branch ShuffleNetV2 unit followed by a channel shuffle.

    With ``stride == 1`` the input is a pair ``(x1, x2)``: ``x1`` is passed
    on untouched and ``x2`` goes through ``branch2``. With ``stride > 1``
    the input is a single tensor that is fed to both ``branch1`` and
    ``branch2``. The two results are then mixed by ``Channel_Shuffle``.

    Args:
        inp: channel count the unit is built for.
        oup: output channel count; each branch produces ``oup // 2``.
        stride: stride of the depthwise convolutions, between 1 and 3.
        split_shuffle: forwarded to ``Channel_Shuffle``.

    Raises:
        ValueError: if ``stride`` is not between 1 and 3.
    """

    def __init__(self, inp, oup, stride, split_shuffle=True):
        super().__init__()

        if not 1 <= stride <= 3:
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        downsample = self.stride > 1

        # branch1 only does work when the unit downsamples.
        branch1_layers = []
        if downsample:
            branch1_layers = [
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            ]
        self.branch1 = nn.Sequential(*branch1_layers)

        # pointwise -> depthwise -> pointwise
        branch2_in = inp if downsample else branch_features
        self.branch2 = nn.Sequential(
            nn.Conv2d(branch2_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

        shuffle_channels = branch_features + branch_features if downsample else inp
        self.channel_shuffle = Channel_Shuffle(inp=shuffle_channels, groups=2,
                                               split_shuffle=split_shuffle)

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        """Build a convolution with one group per input channel."""
        return nn.Conv2d(i, o, kernel_size, stride=stride, padding=padding, groups=i, bias=bias)

    def forward(self, x):
        """Run the branches and return the shuffled result."""
        if self.stride != 1:
            x1 = self.branch1(x)
            x2 = self.branch2(x)
        else:
            x1, x2 = x
            x2 = self.branch2(x2)

        return self.channel_shuffle(x1, x2)
|
||||
|
||||
|
||||
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier: a stem convolution and max-pool, three
    stages of ``inverted_residual`` units, a 1x1 convolution, global
    average pooling and a linear layer.

    Args:
        stages_repeats: three ints, the number of units in stage2..stage4.
        stages_out_channels: five ints, the output channels of conv1,
            stage2, stage3, stage4 and conv5.
        num_classes: output size of the final linear layer.
        inverted_residual: callable that builds one unit; it is called as
            ``inverted_residual(inp, oup, stride)`` and, for the unit that
            closes a stage, with ``split_shuffle=False``.

    Raises:
        ValueError: if ``stages_repeats`` does not have 3 entries or
            ``stages_out_channels`` does not have 5.
    """

    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
                stage_names, stages_repeats, self._stage_out_channels[1:]):
            if repeats > 1:
                seq = [inverted_residual(input_channels, output_channels, 2)]
            else:
                # The stride-2 unit is the only unit of this stage, so it is
                # also the one that closes it: it has to hand a single tensor
                # (not a split pair) to whatever follows the stage.
                seq = [inverted_residual(input_channels, output_channels, 2, split_shuffle=False)]
            for i in range(repeats - 1):
                if i == repeats - 2:
                    # Last unit of the stage: merge the pair back into one tensor.
                    seq.append(inverted_residual(output_channels, output_channels, 1, split_shuffle=False))
                else:
                    seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _forward_impl(self, x):
        """Run the full network on ``x`` and return the ``fc`` output."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        # Global pooling to 1x1, then drop the spatial dims.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.fc(x)
        return x

    def forward(self, x):
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)
|
||||
|
||||
|
||||
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load a checkpoint into it.

    Args:
        arch: key into ``model_urls``.
        pretrained: when true, download and load the checkpoint for ``arch``.
        progress: forwarded to ``load_state_dict_from_url``.
        *args, **kwargs: forwarded to the ``ShuffleNetV2`` constructor.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: ``pretrained`` is true but ``model_urls`` holds
            ``None`` for ``arch``.
        RuntimeError: the checkpoint does not match the model, apart from
            the shuffle index buffers described below.
    """
    model = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return model

    model_url = model_urls[arch]
    if model_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    state_dict = load_state_dict_from_url(model_url, progress=progress)

    # Channel_Shuffle registers its index tensors as buffers, so they are
    # part of this model's state_dict. The checkpoints in model_urls are
    # presumably saved from the stock architecture and do not carry them
    # (TODO confirm); a strict load would therefore always fail. Load
    # non-strictly and reject every mismatch except those buffers, which are
    # fully determined at construction time.
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = [key for key in getattr(incompatible, 'missing_keys', [])
               if '.channel_shuffle.' not in key]
    unexpected = list(getattr(incompatible, 'unexpected_keys', []))
    if missing or unexpected:
        raise RuntimeError(
            'checkpoint for {} does not match the model: missing keys {}, unexpected keys {}'.format(
                arch, missing, unexpected))

    return model
|
||||
|
||||
|
||||
def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
    """
    Build the 0.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the checkpoint registered for this
            variant in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 48, 96, 192, 1024]
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
    """
    Build the 1.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the checkpoint registered for this
            variant in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 116, 232, 464, 1024]
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
    """
    Build the 1.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, ask for the checkpoint of this variant;
            ``model_urls`` holds ``None`` for it, so ``_shufflenetv2``
            raises NotImplementedError
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 176, 352, 704, 1024]
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
    """
    Build the 2.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, ask for the checkpoint of this variant;
            ``model_urls`` holds ``None`` for it, so ``_shufflenetv2``
            raises NotImplementedError
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 244, 488, 976, 2048]
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Debug driver: load a stored input batch and stored initial weights,
    # run one forward/backward pass and dump what every module saw.
    import pickle


    def init():
        """Write the fixtures read below: a random input batch
        (``input_tensor.pkl``) and freshly initialised weights
        (``init_weight.pth``). Not called by this script.
        """
        batch = np.random.randn(32, 3, 224, 224).astype(np.float32)
        with open('input_tensor.pkl', 'wb') as handle:
            pickle.dump(batch, handle)
        fresh_model = shufflenet_v2_x1_0()
        with open('init_weight.pth', 'wb') as handle:
            torch.save(fresh_model.state_dict(), handle)


    # NOTE(review): pickle.load / torch.load execute whatever the file
    # contains; only use fixture files produced locally by init().
    with open('input_tensor.pkl', 'rb') as handle:
        input_tensor = torch.from_numpy(pickle.load(handle))
        input_tensor.requires_grad = True

    model = shufflenet_v2_x1_0()
    with open('init_weight.pth', 'rb') as handle:
        model.load_state_dict(torch.load(handle))

    inter_feature = {}
    inter_gradient = {}


    def make_hook(name, flag):
        """Return a hook that files its observation under ``name``.

        ``'forward'`` hooks record the module's input in ``inter_feature``;
        ``'backward'`` hooks record the third hook argument in
        ``inter_gradient``.
        """
        if flag == 'forward':
            def record_input(module, inputs, outputs):
                inter_feature[name] = inputs

            return record_input
        if flag == 'backward':
            def record_gradient(module, grad_inputs, grad_outputs):
                inter_gradient[name] = grad_outputs

            return record_gradient
        assert False


    for module_name, module in model.named_modules():
        module.register_forward_hook(make_hook(module_name, 'forward'))
        module.register_backward_hook(make_hook(module_name, 'backward'))

    out = model(input_tensor)
    loss = out.sum()
    loss.backward()

    print(inter_feature)
    print(inter_gradient)
|
||||
+328
@@ -0,0 +1,328 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Resolve ``load_state_dict_from_url``. The package-relative helper is tried
# first; when this file is executed directly there is no parent package and
# the relative import raises ImportError, so fall back to the torch
# implementations instead of silently leaving the name undefined.
try:
    from .utils import load_state_dict_from_url
except ImportError:
    try:
        from torch.hub import load_state_dict_from_url
    except ImportError:
        from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
|
||||
'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
|
||||
]
|
||||
|
||||
# Checkpoint location for each architecture key passed to ``_shufflenetv2``.
# ``None`` marks a variant for which no checkpoint URL is registered.
model_urls = {
    # variants with a downloadable checkpoint
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    # variants without one
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}
|
||||
|
||||
|
||||
class IndexSelectFullImplementation(torch.autograd.Function):
    """Concatenate two tensors on dim 1 and permute the channels, with a
    hand-written backward pass.

    ``forward`` returns ``cat([x1, x2], 1).index_select(1, fp_index)``.
    ``backward`` routes the incoming gradient back to ``x1`` and ``x2`` by
    selecting channels ``bp_index1`` and ``bp_index2`` from it; the index
    tensors themselves receive no gradient.
    """

    @staticmethod
    def _synchronize_npu():
        """Wait for the current NPU stream when the NPU API is present.

        Builds of torch without the ``torch.npu`` extension skip the wait
        instead of failing with AttributeError.
        """
        npu = getattr(torch, 'npu', None)
        if npu is not None:
            npu.current_stream().synchronize()

    @staticmethod
    def forward(ctx, x1, x2, fp_index, bp_index1, bp_index2):
        """Return the channel-permuted concatenation of ``x1`` and ``x2``.

        Args:
            ctx: autograd context; the backward indices are stashed on it.
            x1, x2: tensors concatenated along dim 1.
            fp_index: index tensor applied to dim 1 of the concatenation.
            bp_index1, bp_index2: index tensors used by ``backward`` to
                recover the gradients of ``x1`` and ``x2``.
        """
        IndexSelectFullImplementation._synchronize_npu()

        ctx.bp_index1 = bp_index1
        ctx.bp_index2 = bp_index2
        x = torch.cat([x1, x2], dim=1)
        return x.index_select(1, fp_index)

    @staticmethod
    def backward(ctx, grad_output):
        """Split ``grad_output`` into the gradients of ``x1`` and ``x2``.

        Returns one entry per ``forward`` input (five in total); the three
        index tensors get ``None``.
        """
        IndexSelectFullImplementation._synchronize_npu()

        # convert to NCHW to avoid extra 5HD --> 4D (original author's note;
        # only possible when the tensor exposes the NPU format-cast method).
        # NOTE: this rebinds grad_output.data, i.e. it mutates the incoming
        # gradient tensor.
        format_cast = getattr(grad_output.data, 'npu_format_cast', None)
        if format_cast is not None:
            grad_output.data = format_cast(0)

        out1 = grad_output.index_select(1, ctx.bp_index1)
        out2 = grad_output.index_select(1, ctx.bp_index2)
        return out1, out2, None, None, None
|
||||
|
||||
|
||||
class IndexSelectHalfImplementation(torch.autograd.Function):
    """Concatenate two tensors on dim 1, permute the channels and hand the
    result back as two tensors, with a hand-written backward pass.
    """

    @staticmethod
    def forward(ctx, x1, x2, fp_index1, fp_index2, bp_index1, bp_index2):
        """Return ``(merged[:, fp_index1], merged[:, fp_index2])`` where
        ``merged`` is ``x1`` and ``x2`` concatenated along dim 1.

        ``bp_index1`` / ``bp_index2`` are kept on ``ctx`` for ``backward``.
        """
        ctx.bp_index1, ctx.bp_index2 = bp_index1, bp_index2
        merged = torch.cat((x1, x2), 1)
        first_part = merged.index_select(1, fp_index1)
        second_part = merged.index_select(1, fp_index2)
        return first_part, second_part

    @staticmethod
    def backward(ctx, grad_output1, grad_output2):
        """Re-join the two output gradients and select from them the
        gradients of ``x1`` and ``x2``; the four index inputs get ``None``.
        """
        merged_grad = torch.cat((grad_output1, grad_output2), dim=1)
        grad_x1 = merged_grad.index_select(1, ctx.bp_index1)
        grad_x2 = merged_grad.index_select(1, ctx.bp_index2)
        return grad_x1, grad_x2, None, None, None, None
|
||||
|
||||
|
||||
class Channel_Shuffle(nn.Module):
    """Channel shuffle of two concatenated tensors, done with precomputed
    int32 index buffers.

    Args:
        inp: total channel count of the two inputs taken together.
        groups: number of groups the channel range is cut into before the
            groups are interleaved.
        split_shuffle: when true the shuffled result is returned as two
            tensors, otherwise as one.
    """

    def __init__(self, inp, groups=2, split_shuffle=True):
        super().__init__()

        self.split_shuffle = split_shuffle
        self.group_len = inp // groups

        # Lay the channel ids out as (groups, group_len) and read them back
        # column by column: that is the interleaved (shuffled) order.
        channel_grid = np.array(list(range(inp))).reshape(groups, self.group_len)
        self.out = channel_grid.T.flatten().tolist()

        if self.split_shuffle:
            forward_buffers = [
                ('fp_index1', self.out[:self.group_len]),
                ('fp_index2', self.out[self.group_len:]),
            ]
        else:
            forward_buffers = [('fp_index', self.out)]

        # Backward indices: even positions, then odd positions.
        backward_buffers = [
            ('bp_index1', list(range(0, inp, 2))),
            ('bp_index2', list(range(1, inp, 2))),
        ]

        for buffer_name, indices in forward_buffers + backward_buffers:
            self.register_buffer(buffer_name, torch.tensor(indices, dtype=torch.int32))

    def forward(self, x1, x2):
        """Shuffle the channels of ``x1`` and ``x2`` taken together."""
        if not self.split_shuffle:
            return IndexSelectFullImplementation.apply(
                x1, x2, self.fp_index, self.bp_index1, self.bp_index2)
        return IndexSelectHalfImplementation.apply(
            x1, x2, self.fp_index1, self.fp_index2, self.bp_index1, self.bp_index2)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """Two-branch ShuffleNetV2 unit followed by a channel shuffle.

    With ``stride == 1`` the input is a pair ``(x1, x2)``: ``x1`` is passed
    on untouched and ``x2`` goes through ``branch2``. With ``stride > 1``
    the input is a single tensor that is fed to both ``branch1`` and
    ``branch2``. The two results are then mixed by ``Channel_Shuffle``.

    Args:
        inp: channel count the unit is built for.
        oup: output channel count; each branch produces ``oup // 2``.
        stride: stride of the depthwise convolutions, between 1 and 3.
        split_shuffle: forwarded to ``Channel_Shuffle``.

    Raises:
        ValueError: if ``stride`` is not between 1 and 3.
    """

    def __init__(self, inp, oup, stride, split_shuffle=True):
        super().__init__()

        if not 1 <= stride <= 3:
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        downsample = self.stride > 1

        # branch1 only does work when the unit downsamples.
        branch1_layers = []
        if downsample:
            branch1_layers = [
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            ]
        self.branch1 = nn.Sequential(*branch1_layers)

        # pointwise -> depthwise -> pointwise
        branch2_in = inp if downsample else branch_features
        self.branch2 = nn.Sequential(
            nn.Conv2d(branch2_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

        shuffle_channels = branch_features + branch_features if downsample else inp
        self.channel_shuffle = Channel_Shuffle(inp=shuffle_channels, groups=2,
                                               split_shuffle=split_shuffle)

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        """Build a convolution with one group per input channel."""
        return nn.Conv2d(i, o, kernel_size, stride=stride, padding=padding, groups=i, bias=bias)

    def forward(self, x):
        """Run the branches and return the shuffled result."""
        if self.stride != 1:
            x1 = self.branch1(x)
            x2 = self.branch2(x)
        else:
            x1, x2 = x
            x2 = self.branch2(x2)

        return self.channel_shuffle(x1, x2)
|
||||
|
||||
|
||||
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier: a stem convolution and max-pool, three
    stages of ``inverted_residual`` units, a 1x1 convolution, global
    average pooling and a linear layer.

    Args:
        stages_repeats: three ints, the number of units in stage2..stage4.
        stages_out_channels: five ints, the output channels of conv1,
            stage2, stage3, stage4 and conv5.
        num_classes: output size of the final linear layer.
        inverted_residual: callable that builds one unit; it is called as
            ``inverted_residual(inp, oup, stride)`` and, for the unit that
            closes a stage, with ``split_shuffle=False``.

    Raises:
        ValueError: if ``stages_repeats`` does not have 3 entries or
            ``stages_out_channels`` does not have 5.
    """

    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
                stage_names, stages_repeats, self._stage_out_channels[1:]):
            if repeats > 1:
                seq = [inverted_residual(input_channels, output_channels, 2)]
            else:
                # The stride-2 unit is the only unit of this stage, so it is
                # also the one that closes it: it has to hand a single tensor
                # (not a split pair) to whatever follows the stage.
                seq = [inverted_residual(input_channels, output_channels, 2, split_shuffle=False)]
            for i in range(repeats - 1):
                if i == repeats - 2:
                    # Last unit of the stage: merge the pair back into one tensor.
                    seq.append(inverted_residual(output_channels, output_channels, 1, split_shuffle=False))
                else:
                    seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _forward_impl(self, x):
        """Run the full network on ``x`` and return the ``fc`` output."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        # Global pooling to 1x1, then drop the spatial dims.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.fc(x)
        return x

    def forward(self, x):
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)
|
||||
|
||||
|
||||
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load a checkpoint into it.

    Args:
        arch: key into ``model_urls``.
        pretrained: when true, download and load the checkpoint for ``arch``.
        progress: forwarded to ``load_state_dict_from_url``.
        *args, **kwargs: forwarded to the ``ShuffleNetV2`` constructor.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: ``pretrained`` is true but ``model_urls`` holds
            ``None`` for ``arch``.
        RuntimeError: the checkpoint does not match the model, apart from
            the shuffle index buffers described below.
    """
    model = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return model

    model_url = model_urls[arch]
    if model_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    state_dict = load_state_dict_from_url(model_url, progress=progress)

    # Channel_Shuffle registers its index tensors as buffers, so they are
    # part of this model's state_dict. The checkpoints in model_urls are
    # presumably saved from the stock architecture and do not carry them
    # (TODO confirm); a strict load would therefore always fail. Load
    # non-strictly and reject every mismatch except those buffers, which are
    # fully determined at construction time.
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = [key for key in getattr(incompatible, 'missing_keys', [])
               if '.channel_shuffle.' not in key]
    unexpected = list(getattr(incompatible, 'unexpected_keys', []))
    if missing or unexpected:
        raise RuntimeError(
            'checkpoint for {} does not match the model: missing keys {}, unexpected keys {}'.format(
                arch, missing, unexpected))

    return model
|
||||
|
||||
|
||||
def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
    """
    Build the 0.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the checkpoint registered for this
            variant in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 48, 96, 192, 1024]
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
    """
    Build the 1.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, load the checkpoint registered for this
            variant in ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 116, 232, 464, 1024]
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
    """
    Build the 1.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, ask for the checkpoint of this variant;
            ``model_urls`` holds ``None`` for it, so ``_shufflenetv2``
            raises NotImplementedError
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 176, 352, 704, 1024]
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
    """
    Build the 2.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, ask for the checkpoint of this variant;
            ``model_urls`` holds ``None`` for it, so ``_shufflenetv2``
            raises NotImplementedError
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 244, 488, 976, 2048]
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Debug driver: load a stored input batch and stored initial weights,
    # run one forward/backward pass and dump what every module saw.
    import pickle


    def init():
        """Write the fixtures read below: a random input batch
        (``input_tensor.pkl``) and freshly initialised weights
        (``init_weight.pth``). Not called by this script.
        """
        batch = np.random.randn(32, 3, 224, 224).astype(np.float32)
        with open('input_tensor.pkl', 'wb') as handle:
            pickle.dump(batch, handle)
        fresh_model = shufflenet_v2_x1_0()
        with open('init_weight.pth', 'wb') as handle:
            torch.save(fresh_model.state_dict(), handle)


    # NOTE(review): pickle.load / torch.load execute whatever the file
    # contains; only use fixture files produced locally by init().
    with open('input_tensor.pkl', 'rb') as handle:
        input_tensor = torch.from_numpy(pickle.load(handle))
        input_tensor.requires_grad = True

    model = shufflenet_v2_x1_0()
    with open('init_weight.pth', 'rb') as handle:
        model.load_state_dict(torch.load(handle))

    inter_feature = {}
    inter_gradient = {}


    def make_hook(name, flag):
        """Return a hook that files its observation under ``name``.

        ``'forward'`` hooks record the module's input in ``inter_feature``;
        ``'backward'`` hooks record the third hook argument in
        ``inter_gradient``.
        """
        if flag == 'forward':
            def record_input(module, inputs, outputs):
                inter_feature[name] = inputs

            return record_input
        if flag == 'backward':
            def record_gradient(module, grad_inputs, grad_outputs):
                inter_gradient[name] = grad_outputs

            return record_gradient
        assert False


    for module_name, module in model.named_modules():
        module.register_forward_hook(make_hook(module_name, 'forward'))
        module.register_backward_hook(make_hook(module_name, 'backward'))

    out = model(input_tensor)
    loss = out.sum()
    loss.backward()

    print(inter_feature)
    print(inter_gradient)
|
||||
+4
@@ -0,0 +1,4 @@
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
+31
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
|
||||
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that builds its underlying iterator once, at construction,
    and keeps drawing from that same iterator on every pass.

    The batch sampler is wrapped in ``_RepeatSampler`` so the shared iterator
    does not run dry; each ``__iter__`` call yields ``len(self)`` batches
    from it.
    """

    def __init__(self, *args, **kwargs):
        super(MultiEpochsDataLoader, self).__init__(*args, **kwargs)
        # The base-class flag is lowered while batch_sampler is replaced and
        # raised again afterwards -- presumably because DataLoader rejects
        # reassignment of that attribute once initialised (confirm against
        # the torch version in use).
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Single long-lived iterator shared by all passes.
        self.iterator = super(MultiEpochsDataLoader, self).__iter__()

    def __len__(self):
        """Length of the wrapped (original) batch sampler."""
        wrapped = self.batch_sampler.sampler
        return len(wrapped)

    def __iter__(self):
        """Yield ``len(self)`` batches from the shared iterator."""
        remaining = len(self)
        while remaining > 0:
            yield next(self.iterator)
            remaining -= 1
|
||||
|
||||
|
||||
class _RepeatSampler(object):
|
||||
""" Sampler that repeats forever.
|
||||
Args:
|
||||
sampler (Sampler)
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
yield from iter(self.sampler)
|
||||
Reference in New Issue
Block a user