import numpy as np


class SimpleTransformer:
    """
    Simple helper for preprocessing and deprocessing images for caffe.

    preprocess() converts an image to float32, reverses its last axis
    (RGB -> BGR), subtracts the mean, applies the scale and moves the channel
    axis to the front. deprocess() applies the inverse steps in reverse order.
    """

    def __init__(self, mean=(128, 128, 128)):
        # The mean is subtracted *after* the channel flip in preprocess(), so
        # it is interpreted in BGR order.
        self.mean = np.array(mean, dtype=np.float32)
        self.scale = 1.0

    def set_mean(self, mean):
        """
        Set the mean to subtract for centering the data.

        The value is stored as a float32 array, exactly as __init__ does.
        """
        self.mean = np.array(mean, dtype=np.float32)

    def set_scale(self, scale):
        """
        Set the data scaling (multiplied in after mean subtraction).
        """
        self.scale = scale

    def preprocess(self, im):
        """
        Emulate the pre-processing occurring in the vgg16 caffe prototxt.

        im is indexed as (row, column, channel) with channels in RGB order.
        Returns a new float32 array indexed as (channel, row, column) in BGR
        order; the input array is left untouched.
        """
        # Explicit copy: the in-place arithmetic below must never reach the
        # caller's buffer.
        im = np.array(im, dtype=np.float32)
        im = im[:, :, ::-1]  # change to BGR
        im -= self.mean
        im *= self.scale
        return im.transpose((2, 0, 1))

    def deprocess(self, im):
        """
        Inverse of preprocess().

        im is indexed as (channel, row, column) in BGR order. Returns a uint8
        array indexed as (row, column, channel) in RGB order. The input array
        is not modified.
        """
        im = np.asarray(im).transpose(1, 2, 0)
        # Out-of-place arithmetic: `im` is still a view of the caller's data
        # here, so `/=` and `+=` would silently overwrite it.
        im = np.true_divide(im, self.scale)
        im = im + self.mean
        im = im[:, :, ::-1]  # change to RGB
        return np.uint8(im)


class CaffeSolver:
    """
    CaffeSolver is a class for creating a solver.prototxt file. It sets
    default values and can export a solver parameter file.

    Note that all parameters are stored as strings. Parameters that are
    themselves strings in the prototxt keep their double quotes inside the
    stored value, e.g. '"fixed"'.
    """

    def __init__(self, testnet_prototxt_path="testnet.prototxt",
                 trainnet_prototxt_path="trainnet.prototxt", debug=False):

        self.sp = {}

        # critical:
        self.sp['base_lr'] = '0.001'
        self.sp['momentum'] = '0.9'

        # speed:
        self.sp['test_iter'] = '100'
        self.sp['test_interval'] = '250'

        # looks:
        self.sp['display'] = '25'
        self.sp['snapshot'] = '2500'
        self.sp['snapshot_prefix'] = '"snapshot"'  # string within a string!

        # learning rate policy
        self.sp['lr_policy'] = '"fixed"'

        # important, but rare:
        self.sp['gamma'] = '0.1'
        self.sp['weight_decay'] = '0.0005'
        self.sp['train_net'] = '"' + trainnet_prototxt_path + '"'
        self.sp['test_net'] = '"' + testnet_prototxt_path + '"'

        # pretty much never change these.
        self.sp['max_iter'] = '100000'
        self.sp['test_initialization'] = 'false'
        self.sp['average_loss'] = '25'  # this has to do with the display.
        self.sp['iter_size'] = '1'  # this is for accumulating gradients

        if debug:
            # Tiny run lengths so a debug session finishes quickly.
            self.sp['max_iter'] = '12'
            self.sp['test_iter'] = '1'
            self.sp['test_interval'] = '4'
            self.sp['display'] = '1'

    def add_from_file(self, filepath):
        """
        Read a caffe solver prototxt file and update this instance's
        parameters from its "key: value" lines.

        Blank lines and lines whose first non-blank character is '#' are
        skipped. Only the first ':' separates key from value, so values that
        contain colons (paths, URLs) are kept whole. A key that occurs more
        than once keeps only its last value.

        Raises ValueError for a non-comment line that has no ':'.
        """
        with open(filepath, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue
                key, sep, value = line.partition(':')
                if not sep:
                    raise ValueError(
                        'Cannot parse solver line (missing ":"): %r' % line)
                self.sp[key.strip()] = value.strip()

    def write(self, filepath):
        """
        Export solver parameters to INPUT "filepath". Sorted alphabetically.

        Raises TypeError if any parameter value is not a string. Validation
        happens before the file is opened, so a failure never leaves a
        partially written file behind.
        """
        entries = sorted(self.sp.items())
        for _, value in entries:
            if not isinstance(value, str):
                raise TypeError('All solver parameters must be strings')

        with open(filepath, 'w') as f:
            f.writelines('%s: %s\n' % entry for entry in entries)