Image classification using Split-brain Autoencoder

5 minute read

Introduction

Deep learning algorithms have shown that, when given large collections of labelled data, they can achieve human-level performance on computer vision tasks. However, for many practical tasks, the availability of data is limited. Self-supervised pretraining is a method of training whereby a network predicts a part of its input using an another unseen part, which acts as the label. The objective is to learn useful representations of the data in order to fine-tune with supervision on downstream tasks such as image classification.

Implementation

Split-Brain Autoencoder method finds useful global features for classification by solving complementary prediction tasks and therefore utilizing all data in the input. The network is divided into two fully convolutional sub-networks and each is trained to predict one subset of channels of input from the other. For fine-tuning, a classifier is added as the last layer. Using a dataset of 96x96 images, with 512k unlabeled images, 64k labelled training images, and 64k labelled validation images, we perform 1000-class classification.

The approach consists of splitting an image into two subsets of input channels (2 to 1 for a 3-channel space), preferably using a color space that separates color and luminosity. It then passes each subset through a fully convolutional architecture in order to predict the other subset. To make this prediction, it takes the Cross Entropy loss between the network output and a downsampled, quantized version of the original image (acting as labels). To clarify this with our numbers, the input image has 96x96 input features and each sub-network has 12x12 output features, each of which corresponds to a pixel in a 12x12 downsampled ground truth of the input image. The number of output channels in each sub-network corresponds to the number of classes for each pixel, which is exactly the number of colors into which each channel was quantized into. Fine-tuning consists of adding a classifier on top of the concatenated output of the two sub-networks.

Pretraining models

class SplitBrain(nn.Module):

    def __init__(self, encoder="alex", num_ch2=25, num_ch1=100):

        super(SplitBrain, self).__init__()
        if encoder == "alex":
            self.ch2_net = AlexNetAE(in_channels=2, out_channels=num_ch1)
            self.ch1_net = AlexNetAE(in_channels=1, out_channels=num_ch2**2)

        elif encoder == "resnet":
            self.ch2_net = ResNetAE(in_channels=2, out_channels=num_ch1)
            self.ch1_net = ResNetAE(in_channels=1, out_channels=num_ch2**2)

        elif encoder == "googl":
            self.ch2_net = GoogLeNetAE(in_channels=2, out_channels=num_ch1)
            self.ch1_net = GoogLeNetAE(in_channels=1, out_channels=num_ch2**2)

        elif encoder == "simple":
            self.ch2_net = SimpleAE(in_channels=2, out_channels=num_ch1)
            self.ch1_net = SimpleAE(in_channels=1, out_channels=num_ch2**2)

        print("Split Brain Parameters- AB Net: ", sum(p.numel() for p in self.ch2_net.parameters() if p.requires_grad))
        print("Split Brain Parameters- ch1 Net: ", sum(p.numel() for p in self.ch1_net.parameters() if p.requires_grad))

    def forward(self, x):
        ch2, ch1 = x
        ch2_hat = self.ch1_net(ch1)
        ch1_hat = self.ch2_net(ch2)
        return ch2_hat, ch1_hat


class SimpleAE(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(SimpleAE, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.encoder = nn.Sequential(
            nn.Conv2d(self.in_channels, 12, 4, stride=2, padding=1),  # [batch, 12, 48, 48]
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),                # [batch, 24, 24, 24]
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.Conv2d(24, 48, 5),                                     # [batch, 48, 20, 20]
            nn.BatchNorm2d(48),
            nn.ReLU(),
            nn.Conv2d(48, 96, 5),                                     # [batch, 48, 16, 16]
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.Conv2d(96, self.out_channels, 5)                       # [batch, out, 12, 12]
        )

    def forward(self, x):
        encoded = self.encoder(x.view(x.shape[0], self.in_channels, 96, 96))
        return encoded

Finetuning models


class SBNetClassifier(nn.Module):

    def __init__(self, encoder="alex", classifier="mlp", num_ch2=10, num_ch1=100, downsample_size=12):
        super(SBNetClassifier, self).__init__()
        self.sp = SplitBrain(encoder=encoder, num_ch2=num_ch2, num_ch1=num_ch1)

        n_in = (num_ch2**2+num_ch1)*downsample_size**2
        if classifier == "mlp":
            self.classifier = MLPClassifier(n_in,1000)

        elif classifier == "shallow":
            self.classifier = ShallowClassifier(n_in,1000)

        print("Total Finetuning Params: ", sum(p.numel() for p in self.parameters() if p.requires_grad))

    def forward(self, x):
        ch2, ch1 = x
        ch2_hat, ch1_hat = self.sp((ch2.view(ch2.shape[0], self.sp.ch2_net.in_channels, 96, 96), ch1.view(ch1.shape[0], self.sp.ch1_net.in_channels, 96, 96)))
        full_emb = torch.cat((ch2_hat, ch1_hat), 1)
        linear = self.classifier(full_emb.view(full_emb.shape[0], -1))
        return linear

Finetuning code

# Create model

classifier = create_sb_model(type="classifier_"+args.model_type+"_shallow", ckpt=pretrained_weight_name, num_ch2=args.num_classes_ch2, num_ch1=args.num_classes_ch1, downsample_size=args.downsample_size)

''' Load data '''
loader_sup, loader_val_sup, loader_unsup = nyu_lab_loader("../ssl_data_96", args.batch_size, downsample_params=[args.downsample_size, args.num_classes_ch2, args.num_classes_ch1], image_space=args.image_space, num_samples=args.num_samples_per_class)

# Define an optimizer and criterion

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)

prev_top1 = 0.



#### Train #####################################################################
for epoch in range(args.epochs):

    running_loss = 0.0
    classifier.train()

    for i, (inputs, labels, _) in enumerate(loader_sup, 0):

        inputs = get_torch_vars(inputs.type(torch.FloatTensor))
        ch1 = inputs[:, 0, :, :]  # one channel
        ch2 = inputs[:, 1:3, :, :]  # two channels
        labels = get_torch_vars(labels)
        optimizer.zero_grad()

        # ============ Forward ============

        out = classifier((ch2, ch1))

        # =========== Compute Loss =========

        loss = criterion(out, labels)
        running_loss += loss.data

        # ============ Backward ============

        loss.backward()
        optimizer.step()

        # Accuracy
        top_1_acc = n_correct_top_1 / n_samples
        top_k_acc = n_correct_top_k / n_samples

        # Early Stopping
        if(top_1_acc < prev_top1):
            print("Early stopping triggered.")
            exit(0)
        else:
            prev_top1 = top_1_acc

        print('Validation top 1 accuracy: %f' % top_1_acc)
        print('Validation top %d accuracy: %f'% (top_k, top_k_acc))

    ''' Save Trained Model '''

    print('Done Training Epoch ', epoch, '. Saving Model...')
    torch.save(classifier.state_dict(), finetuned_weight_name)

    ''' Update Learning Rate '''
    scheduler.step()

Pretraining code

# Create model

split_brain = create_sb_model(type=args.model_type, num_ch2=args.num_classes_ch2, num_ch1=args.num_classes_ch1)

split_brain.train() # set model to training mode (redundant)

# Size of model
pytorch_total_params = sum(p.numel() for p in split_brain.parameters() if p.requires_grad)

print("\n\nThe model has loaded: Total ", pytorch_total_params, " parameters.")



''' Load data '''

loader_sup, loader_val_sup, loader_unsup = nyu_lab_loader("../ssl_data_96", args.batch_size, downsample_params=[args.downsample_size, args.num_classes_ch2, args.num_classes_ch1], image_space=args.image_space)



# Define an optimizer (with LR rate decay) and criterion

criterion_ch2 = nn.CrossEntropyLoss()

criterion_ch1 = nn.CrossEntropyLoss()



optimizer = optim.Adam(split_brain.parameters(), lr=1e-4)

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)

#### Train #####################################################################

for epoch in range(args.epochs):
    running_loss_ch2 = 0.0
    running_loss_ch1 = 0.0
    
    for i, (inputs, _, downsample) in enumerate(loader_unsup, 0):
        inputs = get_torch_vars(inputs.type(torch.FloatTensor))
        ch1 = inputs[:,0,:,:] # one channel
        ch2 = inputs[:,1:3,:,:] # two channels

        # ============ Forward ============
        ch2_hat, ch1_hat = split_brain((ch2, ch1))

        #===== Additional Processing For Pixel CrossEntropy =====
        # Quantized labels from resized original image
        # Combine a and b dims to generate 625 classes

        ch2_labels = (downsample[:, 1, :, :] * args.num_classes_ch2 + downsample[:, 2, :, :]).type(torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor).view(args.batch_size, args.downsample_size**2)
        ch2_labels_unbind = torch.unbind(ch2_labels)

        ch1_labels = downsample[:, 0, :, :].type(torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor).view(args.batch_size, args.downsample_size**2)
        ch1_labels_unbind = torch.unbind(ch1_labels)

        # ==== Get predictions for each color class and channel  =====
        ch2_hat_4loss = ch2_hat.permute(0,2,3,1).contiguous().view(args.batch_size, args.downsample_size**2, args.num_classes_ch2**2) #[batch_size*16^2, n_classes_ch2]
        ch2_hat_unbind = torch.unbind(ch2_hat_4loss)

        ch1_hat_4loss = ch1_hat.permute(0,2,3,1).contiguous().view(args.batch_size, args.downsample_size**2, args.num_classes_ch1)    #[batch*256, n_classes_ch1]
        ch1_hat_unbind = torch.unbind(ch1_hat_4loss)

        # ============ Compute Loss ===========

        loss_ch2 = 0.
        loss_ch1 = 0.

        for idx in range(len(ch1_hat_unbind)):
            loss_ch2 += criterion_ch2(ch2_hat_unbind[idx], ch2_labels_unbind[idx])
            loss_ch1 += criterion_ch1(ch1_hat_unbind[idx], ch1_labels_unbind[idx])

        loss = loss_ch1 + loss_ch2

        # ============ Backward ===========
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    ''' Save Trained Model '''

    print('Saving Model after each epoch ', epoch)
    torch.save(split_brain.state_dict(), pretrained_weight_name)

    ''' Update Learning Rate '''
    scheduler.step()

Link to Project: