{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IN5400, 2021 | Lab on object detection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "This exercise is about building an *image classification and object localization tool*, much like the one described in the first part of the lecture slides.  That is, we will build and train a CNN that produces both an image class label and, if a suitable class is found, a set of regression coefficients describing a bounding box.\n",
    "\n",
    "In total we have 16451 images, of which 11788 contain a single bird.  None of the other images contain a bird.  The images with a bird also come with a ground-truth bounding box.\n",
    "\n",
    "Our model will *detect if there is a bird* in the image, and if so, *produce a bounding box* surrounding the bird.\n",
    "\n",
    "Below we have provided a (quite rich) skeleton for you to use as a starting point.  What we want you to do is to **1)** modify a pre-trained VGG16 network/model to allow both image classification and bounding-box regression, **2)** implement the two-part loss function (both image classification and regression loss), **3)** calculate the average IoU as well as precision and recall as training progresses.\n",
    "\n",
    "The first two tasks must be completed before the code can train/fine-tune the model at all.\n",
    "\n",
    "The places where we suggest you edit the code are marked with \"**TASK!**\".\n",
    "\n",
    "Do not expect to see optimized code or clever programming.  The key here is simplicity; we hope you find the code easy to understand.\n",
    "\n",
    "Note that, depending on your setup, it might be a good idea to copy the dataset to a local disk.  The dataset is about 1.5 GB."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset and show example images"
   ]
  },
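  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import models\n",
    "from tqdm import tqdm\n",
    "\n",
    "from additionals import *  # Course-provided helpers used below (split_cub_set, CUBTypeDataset, imshow, get_lr)\n",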
    "\n",
    "# Specify data and output folders\n",
    "pth_data = '/itf-fi-ml/shared/courses/IN5400/tranfers/object_detection_data/'\n",
    "pth_work = None #  [optional] '/path-to-somewhere-you-can-store-intermediate-results/'\n",
    "\n",
    "# Split the data into a training and validation set, and create data-loaders for each [reduce batch_size if memory issues]\n",
    "random.seed(42)  # Create \"deterministic\" train/val split\n",
    "train_id, val_id = split_cub_set(0.2, pth_data)\n",
    "datasets = {'train': CUBTypeDataset(train_id, pth_data), 'val': CUBTypeDataset(val_id, pth_data)}\n",
    "dataloaders = {split: torch.utils.data.DataLoader(datasets[split], batch_size=16, shuffle=(split == 'train'), pin_memory=True) for split in ('train', 'val')}\n",
    "\n",
    "threshold_IoU = 0.75  # For use in calculating precision and recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Visualize an example image from the training data [rerun cell to see more examples]\n",
    "ind = random.choice(range(len(datasets['train'])))\n",
    "im, box, im_size = datasets['train'][ind]\n",
    "\n",
    "im = np.transpose(im.detach().numpy(), [1, 2, 0])\n",
    "\n",
    "imshow(im, box)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct a model based on a pretrained VGG16 network\n",
    "As the \"backbone\", or starting-point model, we have chosen the rather \"simple\" VGG16 network.  In this model, the *forward* function (which is the one called with the images as input) does the following:\n",
    "\n",
    "```python\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = self.avgpool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.classifier(x)\n",
    "        return x\n",
    "```\n",
    "\n",
    "The \"avgpool(x)\" call produces a 7x7x512 cube per image no matter the input size of the images.  (This fits well with the network illustrated in the lecture notes.)  The \"torch.flatten(x, 1)\" call turns this cube into one long vector of 7x7x512 = 25088 values, which \"self.classifier(x)\" takes as input to produce a 1000-long output (one value for each of the 1000 classes); this is the final output of the model.\n",
    "\n",
    "What we need to do instead is to let the model output 6 numbers: [c0,c1,x0,y0,x1,y1]. c0 and c1 are the class scores for background and bird, respectively (a softmax turns them into probabilities), while (x0, y0) and (x1, y1) are the upper-left and lower-right corners of the bounding box.\n",
    "\n",
    "The bounding box is relative to the size of the image, just like in the lecture slide examples.  The bounding box, though, is defined using its upper-left and lower-right corners instead of its center and height/width as seen in the lecture slides.\n",
    "\n",
    "A suggestion for solving this first task is to replace the \"classifier\" function with a simple linear function transforming the 7x7x512 values into a six-number output.\n",
    "\n",
    "Replacing the \"classifier\" function can be done by writing `model.classifier = ...`, and a linear transform function can be created using `nn.Linear(..)`.\n"
   ]
  },
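  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the claim above that the average pooling yields a 7x7 grid regardless of input size, the following sketch runs differently sized dummy feature maps through a stand-alone pooling layer and a linear head.  It is an illustration only: the names `pool` and `head` and the chosen feature-map sizes are made up for the example, and `pool` merely mimics the adaptive average pooling that torchvision's VGG16 uses.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Stand-in for the VGG16 avgpool step: adaptive average pooling with a fixed 7x7 output\n",
    "pool = nn.AdaptiveAvgPool2d((7, 7))\n",
    "# Hypothetical linear head mapping the flattened 512*7*7 values to six numbers\n",
    "head = nn.Linear(512 * 7 * 7, 6)\n",
    "\n",
    "shapes = []\n",
    "for side in (7, 10, 14):  # Illustrative feature-map sizes, as if from differently sized images\n",
    "    feats = torch.rand(2, 512, side, side)  # Dummy batch of two feature maps\n",
    "    pooled = pool(feats)  # Shape (2, 512, 7, 7) for every value of side\n",
    "    out = head(torch.flatten(pooled, 1))  # Shape (2, 6)\n",
    "    shapes.append((tuple(pooled.shape), tuple(out.shape)))\n",
    "\n",
    "print(shapes)\n",
    "```\n",
    "\n",
    "Because the pooled shape is fixed, the input size of the linear layer does not depend on the image resolution."
   ]
  },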
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct the model based on a pretrained VGG16 \"backbone\".\n",
    "model = models.vgg16(pretrained=True)\n",
    "\n",
    "# Here you need to modify the model.classifier so that it outputs six numbers: [c0 c1 x0 y0 x1 y1]\n",
    "# TASK! <Do this task first>\n",
    "# <solution>\n",
    "model.classifier = nn.Linear(512*7*7, 6)  # Output: [c0 c1 x0 y0 x1 y1], where c0 is probability of background and c1 is of bird\n",
    "# </solution>\n",
    "\n",
    "model = model.cuda()  # Put the model on GPU\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify that your model produces the correct output shape\n",
    "out = model(torch.rand((1, 3, 224, 224)).cuda())  # Run a single dummy image through the model\n",
    "assert out.shape == (1, 6), \"The output of the model for a single image should be an array of 6 numbers.\"\n"
   ]
  },
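  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lecture slides describe boxes by center and width/height, whereas this notebook uses relative corner coordinates.  The sketch below shows one way to convert between the two conventions and to scale a relative box to pixel coordinates.  The helper names and the explicit `width`/`height` arguments are assumptions made for illustration; they are not part of `additionals`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def corners_to_center(boxes):\n",
    "    # boxes: (N, 4) tensor of [x0, y0, x1, y1] -> (N, 4) tensor of [cx, cy, w, h]\n",
    "    x0, y0, x1, y1 = boxes.unbind(dim=1)\n",
    "    return torch.stack([(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0], dim=1)\n",
    "\n",
    "def center_to_corners(boxes):\n",
    "    # boxes: (N, 4) tensor of [cx, cy, w, h] -> (N, 4) tensor of [x0, y0, x1, y1]\n",
    "    cx, cy, w, h = boxes.unbind(dim=1)\n",
    "    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=1)\n",
    "\n",
    "def relative_to_pixels(boxes, width, height):\n",
    "    # Scale relative [x0, y0, x1, y1] boxes by the image width and height\n",
    "    scale = torch.tensor([width, height, width, height], dtype=boxes.dtype)\n",
    "    return boxes * scale\n",
    "\n",
    "rel_box = torch.tensor([[0.25, 0.25, 0.75, 0.5]])  # One relative corner-format box\n",
    "center_box = corners_to_center(rel_box)\n",
    "round_trip = center_to_corners(center_box)\n",
    "pixel_box = relative_to_pixels(rel_box, width=400, height=200)\n",
    "print(center_box, pixel_box)\n",
    "```\n",
    "\n",
    "Working in relative coordinates keeps the regression targets in the same range for all images, independent of their resolution."
   ]
  },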
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify the functions needed for our final loss criterion\n",
    "criterion_bbox = nn.MSELoss().cuda()\n",
    "criterion_cls = nn.CrossEntropyLoss().cuda()\n",
    "\n",
    "# Specify the optimizer and the learning-rate scheduler -- please feel free to play around with these!\n",
    "optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-2)\n",
    "scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=0, max_lr=1e-3, step_size_up=822*1, step_size_down=822*5, base_momentum=0.9, max_momentum=0.9)  # Increase lr from 0 to 1e-3 linearly during the first epoch, then let the lr gradually fall back to zero during the next 5 epochs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the model\n",
    "\n",
    "Here, the first task you should solve (for the program to run at all) is to define and calculate the loss function.\n",
    "\n",
    "Note that the places where we suggest you edit the code are marked with \"TASK!\"."
   ]
  },
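  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the IoU task in the training loop, one possible way to compute the IoU between predicted and ground-truth boxes in the [x0, y0, x1, y1] format is sketched below.  The function name `pairwise_iou` is made up for this example, and treating degenerate boxes (x1 < x0 or y1 < y0) as having zero area is an assumption, not something the exercise prescribes.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def pairwise_iou(pred, gt):\n",
    "    # pred, gt: (N, 4) tensors of [x0, y0, x1, y1]; returns an (N,) tensor of IoU values\n",
    "    ix0 = torch.max(pred[:, 0], gt[:, 0])\n",
    "    iy0 = torch.max(pred[:, 1], gt[:, 1])\n",
    "    ix1 = torch.min(pred[:, 2], gt[:, 2])\n",
    "    iy1 = torch.min(pred[:, 3], gt[:, 3])\n",
    "    # Intersection area; clamping handles boxes that do not overlap\n",
    "    inter = (ix1 - ix0).clamp(min=0) * (iy1 - iy0).clamp(min=0)\n",
    "    area_pred = (pred[:, 2] - pred[:, 0]).clamp(min=0) * (pred[:, 3] - pred[:, 1]).clamp(min=0)\n",
    "    area_gt = (gt[:, 2] - gt[:, 0]).clamp(min=0) * (gt[:, 3] - gt[:, 1]).clamp(min=0)\n",
    "    union = area_pred + area_gt - inter\n",
    "    return inter / union.clamp(min=1e-8)  # Guard against division by zero\n",
    "\n",
    "pred = torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.5, 1.0], [0.0, 0.0, 0.25, 0.25]])\n",
    "gt = torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [0.5, 0.5, 1.0, 1.0]])\n",
    "ious = pairwise_iou(pred, gt)  # Identical, half-overlapping and disjoint boxes\n",
    "print(ious)\n",
    "```\n",
    "\n",
    "In the training loop, such a function would only be applied to the rows where the ground truth has a box (cf. `ids_has_box`), and the per-image values could be summed to obtain an epoch average."
   ]
  },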
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_epochs = 6  # Specify for how many epochs we are prepared to train the model\n",
    "\n",
    "epoch_loss = {'train': [], 'val': []}\n",
    "epoch_metrics_and_more = {'train': [], 'val': []}\n",
    "\n",
    "# Train the model by looping through all the images 'total_epochs' times\n",
    "for epoch in range(total_epochs):\n",
    "\n",
    "    for phase in ('train', 'val'):\n",
    "        is_training = phase == 'train'\n",
    "        model.train(is_training)\n",
    "        torch.set_grad_enabled(is_training)\n",
    "\n",
    "        total_loss = 0\n",
    "        total_images = 0\n",
    "\n",
    "        start_time = time.time()\n",
    "\n",
    "        for ims, boxes, im_sizes in tqdm(dataloaders[phase]):\n",
    "\n",
    "            ims = ims.cuda()\n",
    "            ids_has_box = np.bitwise_not(np.all(boxes.numpy() == -1, axis=1))\n",
    "            targets_cls = torch.tensor(ids_has_box, dtype=torch.long).cuda()\n",
    "\n",
    "            # Run images through the model, producing a set of [c0, c1, x0, y0, x1, y1] for each\n",
    "            outputs = model(ims)\n",
    "\n",
    "            # Find average IoU, and store values needed for calculating precision, recall etc.\n",
    "            # TASK! <We suggest you complete the \"loss\" task before doing this one>\n",
    "\n",
    "            # Calculate the loss\n",
    "            # TASK! <Do this task first!>\n",
    "            # loss_cls = # Compute the image classification loss | Hint: Use \"criterion_cls(..)\" on all [c0 c1] outputs\n",
    "            # loss_bbox = # Compute the bbox loss | Hint: Use \"criterion_bbox(..)\" on the [x0 y0 x1 y1] outputs where the ground truth says that we have a bbox (cf. \"ids_has_box\"/\"targets_cls\")\n",
    "            # if not any(ids_has_box):\n",
    "            #     loss_bbox = torch.zeros(1).cuda()\n",
    "            # loss = # Compute the total loss | Hint: Combine the cls and bbox loss\n",
    "            # <solution>\n",
    "            loss_cls = criterion_cls(outputs[:, 0:2], targets_cls)\n",
    "            loss_bbox = criterion_bbox(outputs[ids_has_box, 2:], boxes[ids_has_box, :].cuda())\n",
    "            if not any(ids_has_box):\n",
    "                loss_bbox = torch.zeros(1).cuda()\n",
    "            loss = loss_bbox + loss_cls\n",
    "            # </solution>\n",
    "\n",
    "            # Add the loss to the total loss this epoch\n",
    "            total_loss += loss.item()\n",
    "            total_images += ims.size(0)\n",
    "\n",
    "            if is_training:\n",
    "                # Update the weights by doing back-propagation\n",
    "                # [Note that by doing this last, the CPU can work on loading images for the next iteration while the GPU handles this task]\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                scheduler.step()\n",
    "\n",
    "        del loss, loss_cls, loss_bbox, outputs, targets_cls  # Make sure we allow freeing these on GPU\n",
    "\n",
    "        # Calculate precision and recall\n",
    "        # TASK! <We suggest you complete the \"loss\" task before doing this one>\n",
    "\n",
    "        # Print summary for this epoch\n",
    "        # TASK! <Update this to print more statistics as you add them>\n",
    "        elapsed_time = time.time() - start_time\n",
    "        print('[{}]\\tEpoch: {}/{}\\tAvg loss*10: {:.5f}\\tTime: {:.3f} | lr={:.6f}'.format(\n",
    "            phase, epoch + 1, total_epochs, 10 * total_loss / total_images, elapsed_time, get_lr(optimizer)))\n",
    "        epoch_loss[phase].append([total_loss / total_images])\n",
    "        epoch_metrics_and_more[phase].append([get_lr(optimizer)])\n",
    "\n",
    "    # Save the model to file after each epoch, if an output folder has been specified\n",
    "    if pth_work is not None:\n",
    "        torch.save(model, os.path.join(pth_work, 'latest.pth'))\n"
   ]
  },
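  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the precision and recall task in the training loop above, one possible bookkeeping scheme is sketched below.  It assumes that a prediction counts as a true positive only when the image is predicted as bird, the ground truth has a bird, and the IoU reaches the threshold (cf. `threshold_IoU`); other definitions are possible.  The function names and the example values are made up for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def count_detections(pred_is_bird, gt_is_bird, ious, threshold=0.75):\n",
    "    # pred_is_bird, gt_is_bird: 1-D boolean tensors; ious: 1-D float tensor of the same length\n",
    "    tp = pred_is_bird & gt_is_bird & (ious >= threshold)  # Correct class and good enough box\n",
    "    fp = pred_is_bird & ~tp  # Predicted bird, but no bird present or box too poor\n",
    "    fn = gt_is_bird & ~tp  # Bird present, but missed or box too poor\n",
    "    return int(tp.sum()), int(fp.sum()), int(fn.sum())\n",
    "\n",
    "def precision_recall(tp, fp, fn):\n",
    "    # Return 0.0 when a ratio is undefined (no predictions or no ground-truth birds)\n",
    "    precision = tp / (tp + fp) if tp + fp > 0 else 0.0\n",
    "    recall = tp / (tp + fn) if tp + fn > 0 else 0.0\n",
    "    return precision, recall\n",
    "\n",
    "pred_is_bird = torch.tensor([True, True, False, True])\n",
    "gt_is_bird = torch.tensor([True, True, True, False])\n",
    "ious = torch.tensor([0.9, 0.5, 0.0, 0.0])\n",
    "counts = count_detections(pred_is_bird, gt_is_bird, ious)\n",
    "precision, recall = precision_recall(*counts)\n",
    "print(counts, precision, recall)\n",
    "```\n",
    "\n",
    "In practice the three counts would be summed over all batches of an epoch before precision and recall are computed, since averaging per-batch ratios gives a different result."
   ]
  },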
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Now let us visualize some results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot losses and other metrics\n",
    "plt.figure(figsize=(10, 8))\n",
    "for phase in ('train', 'val'):\n",
    "    plt.plot(range(len(epoch_loss[phase])), np.asarray(epoch_loss[phase])[:, 0]*10, label=(phase + '_loss*10'))\n",
    "    plt.plot(range(len(epoch_metrics_and_more[phase])), np.asarray(epoch_metrics_and_more[phase])[:, 0], label=(phase + '_lr'))\n",
    "    # If you store more, add lines for plotting them here, or make new figures for them..\n",
    "plt.legend(prop={'size': 12})\n",
    "plt.grid(True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize some prediction results [rerun cell to see more examples]\n",
    "ind = random.choice(range(len(datasets['val'])))\n",
    "im, box_gt, im_size = datasets['val'][ind]\n",
    "\n",
    "outputs = model(im.unsqueeze(0).cuda()).detach()\n",
    "pred_box = outputs[0, 2:].cpu().numpy()\n",
    "pred_cls = torch.nn.functional.softmax(outputs[0, 0:2], dim=0)\n",
    "box_gt = box_gt.numpy()\n",
    "\n",
    "im = im.detach().numpy()\n",
    "im = np.transpose(im, [1, 2, 0])\n",
    "\n",
    "if pred_cls[1] > pred_cls[0]:\n",
    "    imshow(im, box_gt, pred_box)\n",
    "    print('Predicted bbox    : {}'.format(pred_box))\n",
    "    print('Ground-truth bbox : {}'.format(box_gt))\n",
    "else:\n",
    "    imshow(im, box_gt)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Additional questions\n",
    "\n",
    "- Let us say you now encounter images containing multiple birds.  Luckily, you have at your disposal a tool that provides (somewhat sensible) proposals of sub-regions (parts of the image) that might contain the birds.  How could you (\"easily\") apply your already-developed image classification and object localization code to handle this multi-bird setting?\n",
    "- If you did not cover it already when answering the previous question, explain what non-maximum suppression is and why you would likely need to apply it here.\n",
    "- In the context of object detection, where we can have multiple objects in a given image, what are anchor boxes (also known as priors or default boxes)?\n",
    "- Take a look at the figure in the lecture slides illustrating the concepts of the SSD approach to object detection.  Explain the basic building blocks and concepts behind the algorithm.\n",
    "- A refinement of the SSD approach just mentioned is to add a feature pyramid network (FPN) step.  What is the key issue such a step is meant to alleviate?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
