import torch
import copy
import sys
import time
import pickle
import numpy as np
import warnings
from scipy.interpolate import Rbf
from collections import OrderedDict
from constants import *


def update_progress(index, length, **kwargs):
    '''
        Display progress.
        
        Input:
            `index`: (int) the index of the current progress
            `length`: (int) total length of the progress
            `**kwargs`: info to display (e.g. accuracy)
    '''
    barLength = 10  # Modify this to change the length of the progress bar
    progress = min(float(index) / length, 1.0)
    block = int(round(barLength * progress))
    text = "\rPercent: [{0}] {1:.2f}% ({2}/{3}) ".format(
        "#" * block + "-" * (barLength - block), progress * 100, index, length)
    for key, value in kwargs.items():
        text = text + str(key) + ': ' + str(value) + ', '
    if len(kwargs) != 0:
        text = text[:-2]
    sys.stdout.write(text)
    sys.stdout.flush()


def get_layer_by_param_name(model, param_name):
    '''
        Get a certain layer (e.g. torch.nn.Conv2d) from a model by
        layer parameter name (e.g. models.conv_layers.0.weight).
        
        Input:
            `model`: model we want to get a certain layer from
            `param_name`: (string) layer parameter name
        
        Output:
            `layer`: (e.g. torch.nn.Conv2d)
    '''
    # Walk down the module path encoded in the parameter name.
    layer_name_str_split = param_name.split(STRING_SEPARATOR)[:-1]
    layer = model
    for s in layer_name_str_split:
        layer = getattr(layer, s)
    return layer


def get_keys_from_ordered_dict(ordered_dict):
    '''
        Get an ordered list of keys from an OrderedDict.
        
        Input:
            `ordered_dict`: (OrderedDict)
        
        Output:
            `dict_keys`: (list) keys in their original order
    '''
    dict_keys = []
    for key, _ in ordered_dict.items():
        dict_keys.append(key)  # get key from (key, value) pair
    return dict_keys


def extract_feature_map_sizes(model, input_data_shape):
    '''
        Get the layerwise feature map sizes of conv and fc layers.
        
        Input:
            `model`: model whose layerwise feature map sizes we want.
            `input_data_shape`: (list) [C, H, W].
        
        Output:
            `fmap_sizes_dict`: (dict) layerwise feature map sizes, keyed by module id.
    '''
    fmap_sizes_dict = {}
    hooks = []
    model = model.cuda()
    model.eval()

    def _register_hook(module):
        def _hook(module, input, output):
            type_str = module.__class__.__name__
            if type_str in (CONV_LAYER_TYPES + FC_LAYER_TYPES):
                module_id = id(module)
                in_fmap_size = list(input[0].size())
                out_fmap_size = list(output.size())
                fmap_sizes_dict[module_id] = {KEY_INPUT_FEATURE_MAP_SIZE: in_fmap_size,
                                              KEY_OUTPUT_FEATURE_MAP_SIZE: out_fmap_size}

        if (not isinstance(module, torch.nn.Sequential) and
                not isinstance(module, torch.nn.ModuleList) and
                module is not model):
            hooks.append(module.register_forward_hook(_hook))

    model.apply(_register_hook)
    # Run one forward pass with random input to trigger the hooks.
    _ = model(torch.randn([1, *input_data_shape]).cuda())
    for hook in hooks:
        hook.remove()
    return fmap_sizes_dict
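
# Minimal usage sketch (illustrative, not part of the original API). It assumes
# a CUDA device is available and that constants.py defines STRING_SEPARATOR
# (presumably '.') as well as CONV_LAYER_TYPES/FC_LAYER_TYPES covering 'Conv2d'
# and 'Linear'. The torchvision model is a hypothetical stand-in.
def _demo_extract_feature_map_sizes():
    import torchvision.models
    model = torchvision.models.alexnet()
    fmap_sizes = extract_feature_map_sizes(model, input_data_shape=[3, 224, 224])
    # Look up the recorded input/output sizes of the first conv layer
    # via its parameter name.
    first_conv = get_layer_by_param_name(model, 'features.0.weight')
    print(fmap_sizes[id(first_conv)])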
def get_network_def_from_model(model, input_data_shape):
    '''
        Return the network def (OrderedDict) of the input model.
        
        The network def only contains information about FC, Conv2d, and
        ConvTranspose2d layers; it does not include batch norm etc.
        
        Input:
            `model`: model we want to get the network def from
            `input_data_shape`: (list) [C, H, W].
        
        Output:
            `network_def`: (OrderedDict)
                keys(): layer names (e.g. model.0.1, features.2 ...)
                values(): layer properties (dict)
    '''
    network_def = OrderedDict()
    state_dict = model.state_dict()
    # Extract model keys in an ordered manner from the state dict.
    state_dict_keys = get_keys_from_ordered_dict(state_dict)
    # Extract the feature map sizes.
    fmap_sizes_dict = extract_feature_map_sizes(model, input_data_shape)
    
    # For pixel shuffle.
    previous_layer_name_str = None
    previous_out_channels = None
    before_squared_pixel_shuffle_factor = 1
    
    for layer_param_name in state_dict_keys:
        layer = get_layer_by_param_name(model, layer_param_name)
        layer_id = id(layer)
        layer_name_str = STRING_SEPARATOR.join(layer_param_name.split(STRING_SEPARATOR)[:-1])
        layer_type_str = layer.__class__.__name__
        
        # If conv or fc layer, populate the network def.
        # WARNING: maxpool and upsampling layers are ignored.
        if layer_type_str in (CONV_LAYER_TYPES + FC_LAYER_TYPES) and WEIGHTSTRING in layer_param_name:
            # Populate the network def.
            if layer_type_str in FC_LAYER_TYPES:
                network_def[layer_name_str] = {
                    KEY_IS_DEPTHWISE: False,
                    KEY_NUM_IN_CHANNELS: layer.in_features,
                    KEY_NUM_OUT_CHANNELS: layer.out_features,
                    KEY_KERNEL_SIZE: (1, 1),
                    KEY_STRIDE: (1, 1),
                    KEY_PADDING: (0, 0),
                    KEY_GROUPS: 1,
                    KEY_INPUT_FEATURE_MAP_SIZE:
                        [1, fmap_sizes_dict[layer_id][KEY_INPUT_FEATURE_MAP_SIZE][1], 1, 1],
                    KEY_OUTPUT_FEATURE_MAP_SIZE:
                        [1, fmap_sizes_dict[layer_id][KEY_OUTPUT_FEATURE_MAP_SIZE][1], 1, 1]}
            else:  # layer_type_str is in CONV_LAYER_TYPES
                # Note: need to handle the special case when there is only one filter
                # in a depthwise layer, because the number of groups will also be 1,
                # which is the same as that of a pointwise layer.
                is_depthwise = (layer.groups != 1)
                network_def[layer_name_str] = {
                    KEY_IS_DEPTHWISE: is_depthwise,
                    KEY_NUM_IN_CHANNELS: layer.in_channels,
                    KEY_NUM_OUT_CHANNELS: layer.out_channels,
                    KEY_KERNEL_SIZE: layer.kernel_size,
                    KEY_STRIDE: layer.stride,
                    KEY_PADDING: layer.padding,
                    KEY_GROUPS: layer.groups,
                    # (1, C, H, W)
                    KEY_INPUT_FEATURE_MAP_SIZE: fmap_sizes_dict[layer_id][KEY_INPUT_FEATURE_MAP_SIZE],
                    KEY_OUTPUT_FEATURE_MAP_SIZE: fmap_sizes_dict[layer_id][KEY_OUTPUT_FEATURE_MAP_SIZE]}
            network_def[layer_name_str][KEY_LAYER_TYPE_STR] = layer_type_str
            
            # Support pixel shuffle.
            if layer_type_str in FC_LAYER_TYPES:
                before_squared_pixel_shuffle_factor = 1
            else:
                if previous_out_channels is None:
                    before_squared_pixel_shuffle_factor = 1
                else:
                    if previous_out_channels % layer.in_channels != 0:
                        raise ValueError('previous_out_channels is not divisible by layer.in_channels.')
                    before_squared_pixel_shuffle_factor = previous_out_channels // layer.in_channels
            previous_out_channels = layer.out_channels
            if previous_layer_name_str is not None:
                network_def[previous_layer_name_str][KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR] = \
                    before_squared_pixel_shuffle_factor
            network_def[layer_name_str][KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR] = \
                before_squared_pixel_shuffle_factor
            previous_layer_name_str = layer_name_str
    if previous_layer_name_str:
        network_def[previous_layer_name_str][KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR] = \
            before_squared_pixel_shuffle_factor
    return network_def
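
# Illustrative sketch (an assumption, not from the original file): extract a
# network def and inspect the channel counts per layer. The same CUDA/constants
# requirements as above apply; the torchvision model is again a stand-in.
def _demo_get_network_def_from_model():
    import torchvision.models
    model = torchvision.models.alexnet()
    network_def = get_network_def_from_model(model, input_data_shape=[3, 224, 224])
    for layer_name, layer_properties in network_def.items():
        print(layer_name, layer_properties[KEY_NUM_IN_CHANNELS],
              layer_properties[KEY_NUM_OUT_CHANNELS])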
def compute_weights_and_macs(network_def):
    '''
        Compute the number of weights and MACs of a whole network.
        
        Input:
            `network_def`: defined in get_network_def_from_model()
        
        Output:
            `layer_weights_dict`: (OrderedDict) records layerwise num of weights.
            `total_num_weights`: (int) total num of weights.
            `layer_macs_dict`: (OrderedDict) records layerwise num of MACs.
            `total_num_macs`: (int) total num of MACs.
    '''
    total_num_weights, total_num_macs = 0, 0
    # Extract layer names from the ordered network def.
    network_def_keys = get_keys_from_ordered_dict(network_def)
    
    # Init dicts to store the resource consumption of each layer.
    layer_weights_dict = OrderedDict()
    layer_macs_dict = OrderedDict()
    
    # Iterate over the layers in the network def.
    for layer_name in network_def_keys:
        # Take the product of the filter dimensions to get the layer's weight count.
        layer_num_weights = (network_def[layer_name][KEY_NUM_OUT_CHANNELS] //
                             network_def[layer_name][KEY_GROUPS]) * \
                            network_def[layer_name][KEY_NUM_IN_CHANNELS] * \
                            network_def[layer_name][KEY_KERNEL_SIZE][0] * \
                            network_def[layer_name][KEY_KERNEL_SIZE][1]
        # Store num weights in the layer dict and add to the total.
        layer_weights_dict[layer_name] = layer_num_weights
        total_num_weights += layer_num_weights
        # Determine num MACs for the layer using the output size.
        output_size = network_def[layer_name][KEY_OUTPUT_FEATURE_MAP_SIZE]
        output_height, output_width = output_size[2], output_size[3]
        layer_num_macs = layer_num_weights * output_width * output_height
        # Store num MACs in the layer dict and add to the total.
        layer_macs_dict[layer_name] = layer_num_macs
        total_num_macs += layer_num_macs
    return layer_weights_dict, total_num_weights, layer_macs_dict, total_num_macs


def measure_latency(model, input_data_shape, runtimes=500):
    '''
        Measure the latency of `model`.
        
        Randomly sample `runtimes` normally distributed inputs and average
        the measured forward times.
        
        Input:
            `model`: model to be measured (e.g. torch.nn.Conv2d)
            `input_data_shape`: (list) input shape of the model (e.g. (B, C, H, W))
        
        Output:
            average time (float)
    '''
    total_time = .0
    is_cuda = next(model.parameters()).is_cuda
    if is_cuda:
        cuda_num = next(model.parameters()).get_device()
    for i in range(runtimes):
        if is_cuda:
            input = torch.randn(input_data_shape).cuda(cuda_num)
            with torch.no_grad():
                start = time.time()
                model(input)
                # Wait for the asynchronous CUDA kernels to finish before
                # stopping the timer.
                torch.cuda.synchronize()
                finish = time.time()
        else:
            input = torch.randn(input_data_shape)
            with torch.no_grad():
                start = time.time()
                model(input)
                finish = time.time()
        total_time += (finish - start)
    return total_time / float(runtimes)
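
# Illustrative sketch: analytic resource counting plus direct latency
# measurement of a single layer. The layer and shapes are hypothetical;
# measure_latency() accepts any module plus a full (B, C, H, W) input shape and
# assumes the module's parameters are already on the target device (CUDA here).
def _demo_resource_counting_and_timing(network_def):
    _, total_num_weights, _, total_num_macs = compute_weights_and_macs(network_def)
    print('total weights:', total_num_weights, ', total MACs:', total_num_macs)
    layer = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1).cuda()
    print('latency (s):', measure_latency(layer, [4, 64, 56, 56], runtimes=100))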
def compute_latency_from_lookup_table(network_def, lookup_table_path):
    '''
        Compute the latency of all layers defined in `network_def`
        (only Conv and FC layers are included).
        
        When a latency value is not in the lookup table, it is interpolated.
        
        Input:
            `network_def`: defined in get_network_def_from_model()
            `lookup_table_path`: (string) path to the lookup table
        
        Output:
            `latency`: (float) latency
    '''
    latency = .0
    with open(lookup_table_path, 'rb') as file_id:
        lookup_table = pickle.load(file_id)
    for layer_name, layer_properties in network_def.items():
        if layer_name not in lookup_table.keys():
            raise ValueError('Layer name {} in network def not found in lookup table'.format(layer_name))
        num_in_channels = layer_properties[KEY_NUM_IN_CHANNELS]
        num_out_channels = layer_properties[KEY_NUM_OUT_CHANNELS]
        if (num_in_channels, num_out_channels) in lookup_table[layer_name][KEY_LATENCY].keys():
            latency += lookup_table[layer_name][KEY_LATENCY][(num_in_channels, num_out_channels)]
        else:
            # Not found in the lookup table, so interpolate the latency with a
            # radial basis function fitted to the measured (in, out) samples.
            feature_samples = np.array(list(lookup_table[layer_name][KEY_LATENCY].keys()))
            feature_samples_in = feature_samples[:, 0]
            feature_samples_out = feature_samples[:, 1]
            measurement = np.array(list(lookup_table[layer_name][KEY_LATENCY].values()))
            assert feature_samples_in.shape == feature_samples_out.shape
            assert feature_samples_in.shape == measurement.shape
            rbf = Rbf(feature_samples_in, feature_samples_out,
                      measurement, function='cubic')
            num_in_channels = np.array([num_in_channels])
            num_out_channels = np.array([num_out_channels])
            estimated_latency = rbf(num_in_channels, num_out_channels)
            latency += estimated_latency[0]
    return latency


def compute_resource(network_def, resource_type, lookup_table_path=None):
    '''
        Compute resource consumption based on the resource type.
        
        Input:
            `network_def`: defined in get_network_def_from_model()
            `resource_type`: (string) FLOPS/WEIGHTS/LATENCY
            `lookup_table_path`: (string) path to the lookup table
        
        Output:
            `resource`: (float)
    '''
    if resource_type == 'FLOPS':
        _, _, _, resource = compute_weights_and_macs(network_def)
    elif resource_type == 'WEIGHTS':
        _, resource, _, _ = compute_weights_and_macs(network_def)
    elif resource_type == 'LATENCY':
        resource = compute_latency_from_lookup_table(network_def, lookup_table_path)
    else:
        raise ValueError('Only the resource types `FLOPS`, `WEIGHTS`, and `LATENCY` are supported.')
    return resource
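
# Illustrative sketch of the dispatch above: FLOPS and WEIGHTS come from the
# analytic counters, while LATENCY needs a pre-built lookup table file.
# The path 'latency_lut.pkl' is hypothetical.
def _demo_compute_resource(network_def):
    flops = compute_resource(network_def, 'FLOPS')
    weights = compute_resource(network_def, 'WEIGHTS')
    latency = compute_resource(network_def, 'LATENCY', lookup_table_path='latency_lut.pkl')
    return flops, weights, latency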
def build_latency_lookup_table(network_def_full, lookup_table_path,
                               min_conv_feature_size=8,
                               min_fc_feature_size=128,
                               measure_latency_batch_size=4,
                               measure_latency_sample_times=500,
                               verbose=False):
    '''
        Build a lookup table for the latencies of the layers defined by
        `network_def_full`.
        
        Supported layers: Conv2d, Linear, ConvTranspose2d.
        Modify get_network_def_from_model() and this function to include
        more layer types.
        
        Input:
            `network_def_full`: defined in get_network_def_from_model()
            `lookup_table_path`: (string) path to save the lookup table file
            `min_conv_feature_size`: (int) the sizes of feature maps of simplified
                conv layers along the channel dimension are multiples of
                `min_conv_feature_size`. The reason is that on mobile devices the
                computation of (B, 7, H, W) tensors can take longer than that of
                (B, 8, H, W) tensors.
            `min_fc_feature_size`: (int) the feature sizes of simplified FC layers
                are multiples of `min_fc_feature_size`.
            `measure_latency_batch_size`: (int) the batch size of the input data
                used when running forward passes to measure latency.
            `measure_latency_sample_times`: (int) the number of times to run a
                layer's forward function in order to get its latency.
            `verbose`: (bool) set True to display detailed information.
    '''
    resource_type = 'LATENCY'
    
    # Generate the lookup table.
    lookup_table = OrderedDict()
    for layer_name, layer_properties in network_def_full.items():
        if verbose:
            print('-------------------------------------------')
            print('Measuring layer', layer_name, ':')
        
        # If the layer has the same properties as a previous layer, reuse the
        # previous layer's measurements. The comparison deliberately ignores the
        # pixel shuffle factors.
        for layer_name_pre, layer_properties_pre in network_def_full.items():
            if layer_name_pre == layer_name:
                break
            # Do not consider pixel shuffling.
            layer_properties_pre[KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR] = layer_properties[
                KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR]
            layer_properties_pre[KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR] = layer_properties[
                KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR]
            if layer_properties_pre == layer_properties:
                lookup_table[layer_name] = lookup_table[layer_name_pre]
                if verbose:
                    print('    Found previous layer', layer_name_pre, 'with the same properties')
                break
        if layer_name in lookup_table:
            continue
        
        is_depthwise = layer_properties[KEY_IS_DEPTHWISE]
        num_in_channels = layer_properties[KEY_NUM_IN_CHANNELS]
        num_out_channels = layer_properties[KEY_NUM_OUT_CHANNELS]
        kernel_size = layer_properties[KEY_KERNEL_SIZE]
        stride = layer_properties[KEY_STRIDE]
        padding = layer_properties[KEY_PADDING]
        groups = layer_properties[KEY_GROUPS]
        layer_type_str = layer_properties[KEY_LAYER_TYPE_STR]
        input_data_shape = layer_properties[KEY_INPUT_FEATURE_MAP_SIZE]
        
        lookup_table[layer_name] = {}
        lookup_table[layer_name][KEY_IS_DEPTHWISE] = is_depthwise
        lookup_table[layer_name][KEY_NUM_IN_CHANNELS] = num_in_channels
        lookup_table[layer_name][KEY_NUM_OUT_CHANNELS] = num_out_channels
        lookup_table[layer_name][KEY_KERNEL_SIZE] = kernel_size
        lookup_table[layer_name][KEY_STRIDE] = stride
        lookup_table[layer_name][KEY_PADDING] = padding
        lookup_table[layer_name][KEY_GROUPS] = groups
        lookup_table[layer_name][KEY_LAYER_TYPE_STR] = layer_type_str
        lookup_table[layer_name][KEY_INPUT_FEATURE_MAP_SIZE] = input_data_shape
        lookup_table[layer_name][KEY_LATENCY] = {}
        
        if verbose:
            print('Is depthwise:', is_depthwise)
            print('Num in channels:', num_in_channels)
            print('Num out channels:', num_out_channels)
            print('Kernel size:', kernel_size)
            print('Stride:', stride)
            print('Padding:', padding)
            print('Groups:', groups)
            print('Input feature map size:', input_data_shape)
            print('Layer type:', layer_type_str)
        
        '''
        if num_in_channels >= min_feature_size and \
            (num_in_channels % min_feature_size != 0 or num_out_channels % min_feature_size != 0):
            raise ValueError('The number of channels is not divisible by {}.'.format(str(min_feature_size)))
        '''
        
        if layer_type_str in CONV_LAYER_TYPES:
            min_feature_size = min_conv_feature_size
        elif layer_type_str in FC_LAYER_TYPES:
            min_feature_size = min_fc_feature_size
        else:
            raise ValueError('Layer type {} not supported'.format(layer_type_str))
        
        for reduced_num_in_channels in range(num_in_channels, 0, -min_feature_size):
            if verbose:
                index = 1
                print('    Start measuring num_in_channels =', reduced_num_in_channels)
            if is_depthwise:
                reduced_num_out_channels_list = [reduced_num_in_channels]
            else:
                reduced_num_out_channels_list = list(range(num_out_channels, 0, -min_feature_size))
            for reduced_num_out_channels in reduced_num_out_channels_list:
                if resource_type == 'LATENCY':
                    if layer_type_str == 'Conv2d':
                        if is_depthwise:
                            layer_test = torch.nn.Conv2d(reduced_num_in_channels, reduced_num_out_channels,
                                                         kernel_size, stride, padding,
                                                         groups=reduced_num_in_channels)
                        else:
                            layer_test = torch.nn.Conv2d(reduced_num_in_channels, reduced_num_out_channels,
                                                         kernel_size, stride, padding, groups=groups)
                        input_data_shape = layer_properties[KEY_INPUT_FEATURE_MAP_SIZE]
                        input_data_shape = (measure_latency_batch_size, reduced_num_in_channels,
                                            *input_data_shape[2:])
                    elif layer_type_str == 'Linear':
                        layer_test = torch.nn.Linear(reduced_num_in_channels, reduced_num_out_channels)
                        input_data_shape = (measure_latency_batch_size, reduced_num_in_channels)
                    elif layer_type_str == 'ConvTranspose2d':
                        if is_depthwise:
                            layer_test = torch.nn.ConvTranspose2d(reduced_num_in_channels,
                                                                  reduced_num_out_channels,
                                                                  kernel_size, stride, padding,
                                                                  groups=reduced_num_in_channels)
                        else:
                            layer_test = torch.nn.ConvTranspose2d(reduced_num_in_channels,
                                                                  reduced_num_out_channels,
                                                                  kernel_size, stride, padding,
                                                                  groups=groups)
                        input_data_shape = layer_properties[KEY_INPUT_FEATURE_MAP_SIZE]
                        input_data_shape = (measure_latency_batch_size, reduced_num_in_channels,
                                            *input_data_shape[2:])
                    else:
                        raise ValueError('This type of layer is not supported.')
                    if torch.cuda.is_available():
                        layer_test = layer_test.cuda()
                    measurement = measure_latency(layer_test, input_data_shape,
                                                  measure_latency_sample_times)
                else:
                    raise ValueError('Only building the lookup table for `LATENCY` is supported.')
                # Add the measurement to the lookup table.
                lookup_table[layer_name][KEY_LATENCY][
                    (reduced_num_in_channels, reduced_num_out_channels)] = measurement
                if verbose:
                    update_progress(index, len(reduced_num_out_channels_list), latency=str(measurement))
                    index = index + 1
            if verbose:
                print(' ')
                print('    Finish measuring num_in_channels =', reduced_num_in_channels)
    
    # Save the lookup table.
    with open(lookup_table_path, 'wb') as file_id:
        pickle.dump(lookup_table, file_id)
    return
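
# Illustrative sketch: build a latency lookup table once, then reuse it.
# The path is hypothetical; building can take a while because every sampled
# (in_channels, out_channels) pair of every layer is measured
# `measure_latency_sample_times` times.
def _demo_build_latency_lookup_table(model):
    network_def = get_network_def_from_model(model, input_data_shape=[3, 224, 224])
    build_latency_lookup_table(network_def, 'latency_lut.pkl',
                               measure_latency_sample_times=100, verbose=True)
    print('estimated latency:',
          compute_latency_from_lookup_table(network_def, 'latency_lut.pkl'))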
def simplify_network_def_based_on_constraint(network_def, block, constraint,
                                             resource_type, lookup_table_path=None,
                                             skip_connection_block_sets=[],
                                             min_feature_size=8):
    '''
        Derive how much a certain block of layers (`block`) should be simplified
        based on a resource constraint. Here we treat one block as one layer,
        although a block can contain several layers.
        
        Input:
            `network_def`: simplifiable network def (conv & fc),
                defined in get_network_def_from_model()
            `block`: (int) index of the block to simplify
            `constraint`: (float) the FLOPs/weights/latency constraint the
                simplified model should satisfy
            `resource_type`: (string) `FLOPS`, `WEIGHTS`, or `LATENCY`
            `lookup_table_path`: (string) path to the latency lookup table.
                Needed only when resource_type == 'LATENCY'
            `skip_connection_block_sets`: (list or tuple) a list of sets of blocks.
                Blocks in the same set will have the same number of output channels,
                as their feature maps will be summed later. (default: [])
                For example, if the outputs of block 0 and block 4 are summed and
                the outputs of block 1 and block 5 are summed, then
                skip_connection_block_sets = [(0, 4), (1, 5)] or ((0, 4), (1, 5)).
                Note that only addition is currently supported.
            `min_feature_size`: (int) the number of output channels of a simplified
                (pruned) layer will be a multiple of min_feature_size. (default: 8)
        
        Output:
            `simplified_network_def`: simplified network def. Indicates how much
                the network should be simplified/pruned.
            `simplified_resource`: (float) the estimated resource consumption of
                the simplified model.
    '''
    # Check whether the block has a skip connection.
    block = [block]
    for skip_connection_block_set in skip_connection_block_sets:
        if block[0] in skip_connection_block_set:
            block = list(skip_connection_block_set)
            block.sort()
            break
    print('    simplify_def> constraint: ', constraint)
    print('    simplify_def> target block:', block)
    
    # Find the target layer and the other layers whose outputs will later be
    # added to that of the target layer (i.e. skip connections).
    # (contains layer indices)
    target_layer_indices = []
    max_num_out_channels = None
    block_counter = 0
    for layer_idx, (layer_name, layer_properties) in enumerate(network_def.items()):
        # Skip depthwise layers.
        if layer_properties[KEY_IS_DEPTHWISE]:
            continue
        if block_counter == block[0]:
            target_layer_indices.append(layer_idx)
            if max_num_out_channels is not None:
                if max_num_out_channels != layer_properties[KEY_NUM_OUT_CHANNELS]:
                    print('The blocks involved in this skip connection do not have '
                          'compatible numbers of output channels.')
                    sys.stdout.flush()
            max_num_out_channels = layer_properties[KEY_NUM_OUT_CHANNELS]
            print('    simplify_def> target layer: {}, layer index: {}'.format(layer_name, layer_idx))
            del block[0]
            if not block:
                break
        block_counter += 1
    
    # Check target_layer_indices.
    if not target_layer_indices:
        raise ValueError('`block` seems out of bounds.')
    
    # Determine the number of filters and the resulting resource consumption.
    simplified_network_def = copy.deepcopy(network_def)
    simplified_resource = None
    return_with_constraint_satisfied = False
    if max_num_out_channels >= min_feature_size:
        # Try numbers of channels that are multiples of `min_feature_size`.
        num_out_channels_try = list(range(max_num_out_channels // min_feature_size * min_feature_size,
                                          min_feature_size - 1, -min_feature_size))
    else:
        num_out_channels_try = [max_num_out_channels]
    
    # Update the number of output channels of the target layers.
    # Update the number of input/output channels of all depthwise layers between
    # the target layers and the subsequent non-depthwise layers (assuming the
    # number of groups equals the number of input channels).
    # Update the number of input channels of the one non-depthwise layer
    # following the target layers.
    for current_num_out_channels in num_out_channels_try:
        # Only allow multiples of `min_feature_size`.
        for target_layer_index in target_layer_indices:
            update_num_out_channels = True
            current_num_out_channels_after_pixel_shuffle = current_num_out_channels
            for layer_idx, (layer_name, layer_properties) in enumerate(simplified_network_def.items()):
                if layer_idx < target_layer_index:
                    continue
                # For the block to be simplified (its number of output channels is reduced).
                if update_num_out_channels:
                    if not layer_properties[KEY_IS_DEPTHWISE]:
                        layer_properties[KEY_NUM_OUT_CHANNELS] = current_num_out_channels
                        update_num_out_channels = False
                        print('    simplify_def> layer {}: num of output channels changed to {}'.format(
                            layer_name, str(current_num_out_channels)))
                    else:
                        raise ValueError('Expected a non-depthwise layer but got a depthwise layer.')
                # For the blocks following the target block (their numbers of input
                # channels are reduced).
                else:
                    if current_num_out_channels_after_pixel_shuffle % layer_properties[
                            KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR] != 0:
                        raise ValueError('current_num_out_channels or '
                                         'current_num_out_channels_after_pixel_shuffle is not '
                                         'divisible by the scaling factor of pixel shuffling.')
                    current_num_out_channels_after_pixel_shuffle = (
                        current_num_out_channels_after_pixel_shuffle //
                        layer_properties[KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR])
                    layer_properties[KEY_NUM_IN_CHANNELS] = current_num_out_channels_after_pixel_shuffle
                    print('    simplify_def> layer {}: num of input channels changed to {}'.format(
                        layer_name, str(current_num_out_channels_after_pixel_shuffle)))
                    
                    # Consider the case where an FC layer is placed after a Conv and a Flatten:
                    #     FC:   input feature size: Cin, output feature size: Cout
                    #     Conv: output feature map size: H x W x C
                    # so Cin = H x W x C. If C -> C' based on the constraint,
                    # then Cin -> H x W x C'.
                    if layer_properties[KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR] == 1:
                        if network_def[layer_name][KEY_NUM_IN_CHANNELS] > max_num_out_channels:
                            assert network_def[layer_name][KEY_NUM_IN_CHANNELS] % max_num_out_channels == 0
                            # H x W here.
                            spatial_factor = network_def[layer_name][KEY_NUM_IN_CHANNELS] // max_num_out_channels
                            layer_properties[KEY_NUM_IN_CHANNELS] = spatial_factor * current_num_out_channels
                            print('    simplify_def> [Update] layer {}: num of input channels '
                                  'changed to {}'.format(layer_name,
                                                         str(spatial_factor * current_num_out_channels)))
                    if not layer_properties[KEY_IS_DEPTHWISE]:
                        break
                    else:
                        layer_properties[KEY_NUM_OUT_CHANNELS] = current_num_out_channels_after_pixel_shuffle
                        layer_properties[KEY_GROUPS] = current_num_out_channels_after_pixel_shuffle
                        print('    simplify_def> depthwise layer {}: num of output channels '
                              'changed to {}'.format(layer_name,
                                                     str(current_num_out_channels_after_pixel_shuffle)))
        
        # Get the current resource consumption.
        simplified_resource = compute_resource(simplified_network_def, resource_type, lookup_table_path)
        print('    simplify_def> finished trying num of output channels: {}, resource: {}'.format(
            current_num_out_channels, simplified_resource))
        
        # Terminate the simplification when the constraint has been satisfied.
        if simplified_resource < constraint:
            return_with_constraint_satisfied = True
            print('    simplify_def> constraint {} met when trying num of output channels: {}'.format(
                constraint, current_num_out_channels))
            break
    if not return_with_constraint_satisfied:
        warnings.warn('Constraint not satisfied: constraint = {}, simplified_resource = {}'.format(
            constraint, simplified_resource))
    return simplified_network_def, simplified_resource
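
# Illustrative sketch: shrink block 0 until the network def fits within 80% of
# its current MAC count. The 0.8 ratio is hypothetical; for networks with skip
# connections, tied blocks would be listed in `skip_connection_block_sets`.
def _demo_simplify_network_def(network_def):
    constraint = 0.8 * compute_resource(network_def, 'FLOPS')
    simplified_network_def, simplified_resource = simplify_network_def_based_on_constraint(
        network_def, block=0, constraint=constraint, resource_type='FLOPS')
    return simplified_network_def, simplified_resource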
def simplify_model_based_on_network_def(simplified_network_def, model):
    '''
        Choose which filters to preserve. Filters with the largest L2
        magnitudes are kept.
        
        Input:
            `simplified_network_def`: network def indicating how the model will be
                pruned, defined in get_network_def_from_model()
            `model`: model to be simplified.
        
        Output:
            `simplified_model`: simplified model.
    '''
    simplified_model = copy.deepcopy(model)
    simplified_state_dict = simplified_model.state_dict()
    kept_filter_idx = None
    for layer_param_full_name in simplified_state_dict.keys():
        layer = get_layer_by_param_name(simplified_model, layer_param_full_name)
        layer_param_full_name_split = layer_param_full_name.split(STRING_SEPARATOR)
        layer_name_str = STRING_SEPARATOR.join(layer_param_full_name_split[:-1])
        layer_param_name = layer_param_full_name_split[-1]
        layer_type_str = layer.__class__.__name__
        
        # Reduce the number of input channels based on the layer and parameter type.
        # Also reduce the number of biases of simplified layers.
        if kept_filter_idx is None:
            pass
        elif layer_type_str in CONV_LAYER_TYPES:
            # Support pixel shuffle.
            before_squared_pixel_shuffle_factor = simplified_network_def[layer_name_str][
                KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR]
            kept_filter_idx = (kept_filter_idx[::before_squared_pixel_shuffle_factor] //
                               before_squared_pixel_shuffle_factor)
            if layer_param_name == WEIGHTSTRING:
                if layer.groups == 1:
                    # Pointwise layer or depthwise layer with only one filter.
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[:, kept_filter_idx, :, :]))
                    layer.in_channels = len(kept_filter_idx)
                    print('    simplify_model> simplify Conv layer {}: input channel weights {}'.format(
                        layer_name_str, len(kept_filter_idx)))
                else:
                    # Depthwise layer.
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx, :, :, :]))
                    layer.in_channels = len(kept_filter_idx)
                    layer.out_channels = len(kept_filter_idx)
                    layer.groups = len(kept_filter_idx)
                    print('    simplify_model> simplify Conv layer {}: input/output channel weights {} '
                          'and groups {}'.format(layer_name_str, len(kept_filter_idx),
                                                 len(kept_filter_idx)))
            elif layer_param_name == BIASSTRING:
                print('    simplify_model> simplify Conv layer {}: output channel biases {}'.format(
                    layer_name_str, len(kept_filter_idx)))
                setattr(layer, layer_param_name,
                        torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx]))
            else:
                raise ValueError('The layer_param_name `{}` is not supported.'.format(layer_param_name))
        elif layer_type_str in FC_LAYER_TYPES:
            if layer_param_name == BIASSTRING:
                print('    simplify_model> simplify FC layer {}: output channel biases {}'.format(
                    layer_name_str, len(kept_filter_idx)))
                # The weights of this layer have already been reduced.
                setattr(layer, layer_param_name,
                        torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx]))
            else:
                # The input features must be modified as the previous layer now
                # has fewer output features.
                # Consider the case where an FC layer is placed after a Conv and a Flatten:
                #     FC:   input feature size: Cin, output feature size: Cout
                #     Conv: output feature map size: H x W x C
                # so Cin = H x W x C. If C -> C' based on the constraint,
                # then Cin -> H x W x C'.
                num_in_features = simplified_network_def[layer_name_str][KEY_NUM_IN_CHANNELS]
                if num_in_features > len(kept_filter_idx):
                    assert num_in_features % len(kept_filter_idx) == 0
                    # H x W here.
                    spatial_ratio = num_in_features // len(kept_filter_idx)
                    kept_filter_idx_fc = kept_filter_idx.clone()
                    kept_filter_idx_fc_element = kept_filter_idx_fc * spatial_ratio
                    kept_filter_idx_fc = kept_filter_idx_fc_element.clone()
                    for i in range(1, spatial_ratio):
                        kept_filter_idx_fc = torch.cat(
                            (kept_filter_idx_fc, kept_filter_idx_fc_element + i), dim=0)
                    kept_filter_idx_fc, _ = kept_filter_idx_fc.sort()
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[:, kept_filter_idx_fc]))
                    layer.in_features = len(kept_filter_idx_fc)
                    assert len(kept_filter_idx_fc) == num_in_features
                else:
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[:, kept_filter_idx]))
                    layer.in_features = len(kept_filter_idx)
                print('    simplify_model> simplify FC layer {}: input channel weights {}'.format(
                    layer_name_str, layer.in_features))
        elif layer_type_str in BNORM_LAYER_TYPES:
            if layer_param_name in (WEIGHTSTRING, BIASSTRING):
                setattr(layer, layer_param_name,
                        torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx],
                                           requires_grad=True))
                layer.num_features = len(kept_filter_idx)
                print('    simplify_model> simplify {} layer {}: {} {}'.format(
                    layer_type_str, layer_name_str, layer_param_name, layer.num_features))
            elif layer_param_name in (RUNNING_MEANSTRING, RUNNING_VARSTRING):
                setattr(layer, layer_param_name,
                        torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx],
                                           requires_grad=False))
                layer.num_features = len(kept_filter_idx)
                print('    simplify_model> simplify {} layer {}: {} {}'.format(
                    layer_type_str, layer_name_str, layer_param_name, layer.num_features))
            elif NUM_BATCHES_TRACKED == layer_param_name:
                getattr(layer, layer_param_name).zero_()
                print('    simplify_model> simplify {} layer {}: {} set to 0'.format(
                    layer_type_str, layer_name_str, layer_param_name))
            else:
                raise ValueError('The layer_param_name `{}` is not supported.'.format(layer_param_name))
        else:
            raise ValueError('The layer type `{}` is not supported.'.format(type(layer)))
        
        # Reduce the number of filters and update kept_filter_idx if the layer is
        # in the network def and is not a depthwise layer.
        # (This reduces the number of output feature maps of simplified layers.)
        if (layer_param_name == WEIGHTSTRING and layer_name_str in simplified_network_def and
                not simplified_network_def[layer_name_str][KEY_IS_DEPTHWISE]):
            num_filters = simplified_network_def[layer_name_str][KEY_NUM_OUT_CHANNELS]
            weight = layer.weight.data
            if num_filters == weight.shape[0]:
                # Not a target layer, so do not simplify: the current model and the
                # simplified network def have the same number of output channels here.
                kept_filter_idx = None
            else:
                # Based on the L2 norm, determine `kept_filter_idx`, which is used to
                # simplify the current layer (conv & fc) and related following layers.
                after_squared_pixel_shuffle_factor = simplified_network_def[layer_name_str][
                    KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR]
                if num_filters % after_squared_pixel_shuffle_factor != 0:
                    raise ValueError('num_filters is not divisible by after_squared_pixel_shuffle_factor.')
                num_filters //= after_squared_pixel_shuffle_factor
                if layer_type_str in CONV_LAYER_TYPES:
                    filter_norm = (weight * weight).sum((1, 2, 3))
                    filter_norm = filter_norm.view(-1, after_squared_pixel_shuffle_factor).sum(1)
                elif layer_type_str in FC_LAYER_TYPES:
                    filter_norm = (weight * weight).sum(1)
                _, kept_filter_idx = filter_norm.topk(num_filters, sorted=False)
                # Consider pixel shuffle.
                kept_filter_idx_element = kept_filter_idx * after_squared_pixel_shuffle_factor
                kept_filter_idx = kept_filter_idx_element.clone()
                for pixel_shuffle_factor_counter in range(1, after_squared_pixel_shuffle_factor):
                    kept_filter_idx = torch.cat(
                        (kept_filter_idx, kept_filter_idx_element + pixel_shuffle_factor_counter), dim=0)
                kept_filter_idx, _ = kept_filter_idx.sort()
                
                if layer_type_str in CONV_LAYER_TYPES:
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx, :, :, :]))
                    layer.out_channels = len(kept_filter_idx)
                    print('    simplify_model> simplify Conv layer {}: output channel weights {}'.format(
                        layer_name_str, len(kept_filter_idx)))
                elif layer_type_str in FC_LAYER_TYPES:
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx, :]))
                    layer.out_features = len(kept_filter_idx)
                    print('    simplify_model> simplify FC layer {}: output channel weights {}'.format(
                        layer_name_str, len(kept_filter_idx)))
    return simplified_model
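
# Illustrative end-to-end sketch of one simplification step: derive a pruned
# network def under a constraint, then materialize it by keeping the filters
# with the largest L2 norms. The input shape and the 0.8 ratio are hypothetical.
def _demo_one_simplification_step(model):
    network_def = get_network_def_from_model(model, input_data_shape=[3, 224, 224])
    constraint = 0.8 * compute_resource(network_def, 'FLOPS')
    simplified_network_def, _ = simplify_network_def_based_on_constraint(
        network_def, block=0, constraint=constraint, resource_type='FLOPS')
    return simplify_model_based_on_network_def(simplified_network_def, model)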