AboutWriting

OpenAI-CLIP-Powered Search Engine

TL;DR

  • Run code in your terminal
  • Play the web demo in your browser
  • macOS app demo:

Framer's Component Picker

Output:

Framer's Component Picker

Framer's Component Picker

OpenAI CLIP

Framer's Component Picker

The versatility of CLIP is quite amazing. Below, I want to demonstrate its zero-shot capabilities on various tasks, such as text-prompted detection.

CLIP's Performance

Training Efficiency:

Framer's Component Picker

CLIP is one of the most training-efficient models: after 400 million training images it reaches 41% zero-shot ImageNet accuracy, compared with 27% for a bag-of-words prediction objective and 16% for a Transformer language model trained on the same number of images. In other words, CLIP's contrastive objective needs far fewer training images than these alternatives to reach a given accuracy.

Generalization: CLIP has been trained on such a wide variety of image styles that it is far more flexible than models trained only on ImageNet. Note, however, that CLIP generalizes well to images resembling its training distribution, not to images that fall outside of it.

Automatically generate proposal regions with selective search, compute their similarity with a natural language query in CLIP embedding space, and return the top-k detections with non-maximum suppression.

Install dependencies


%%capture
!pip install ftfy regex tqdm matplotlib selectivesearch
!pip install git+https://github.com/openai/CLIP.git

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import urllib.request
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import clip
from PIL import Image
from torchvision import transforms
import selectivesearch
from collections import OrderedDict

def load_image(img_path, resize=None, pil=False):
    """Load an image file as RGB.

    Args:
        img_path: Path of the image file to open.
        resize: If not None, resize to a square ``resize`` x ``resize`` image.
        pil: If True, return the PIL image instead of a NumPy array.

    Returns:
        The PIL RGB image when ``pil`` is True; otherwise a float32 NumPy
        array of shape (H, W, 3) with values scaled from [0, 255] to [0, 1].
    """
    # Use the caller's path (previously the global `image_path` was read,
    # silently ignoring this argument). The `with` releases the file handle;
    # `convert` returns a fully loaded, independent image.
    with Image.open(img_path) as src:
        image = src.convert("RGB")
    if resize is not None:
        image = image.resize((resize, resize))
    if pil:
        return image
    return np.asarray(image).astype(np.float32) / 255.

# Reference: https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
def nms(dets, thresh):
    """Greedy non-maximum suppression.

    Args:
        dets: Array whose rows are [x1, y1, x2, y2, score]; only columns
            0-4 are read. Coordinates are treated as inclusive pixel
            indices (hence the ``+ 1`` in widths/heights).
        thresh: IoU threshold. Any remaining box whose IoU with the current
            best box exceeds this value is discarded.

    Returns:
        List of row indices into ``dets`` that survive, ordered from highest
        to lowest score.
    """
    left, top = dets[:, 0], dets[:, 1]
    right, bottom = dets[:, 2], dets[:, 3]
    confidence = dets[:, 4]

    area = (right - left + 1) * (bottom - top + 1)
    # Candidate indices, best score first.
    queue = confidence.argsort()[::-1]

    survivors = []
    while queue.size > 0:
        head, tail = queue[0], queue[1:]
        survivors.append(head)

        inter_w = np.maximum(
            0.0, np.minimum(right[head], right[tail]) - np.maximum(left[head], left[tail]) + 1)
        inter_h = np.maximum(
            0.0, np.minimum(bottom[head], bottom[tail]) - np.maximum(top[head], top[tail]) + 1)
        inter = inter_w * inter_h
        iou = inter / (area[head] + area[tail] - inter)

        # Keep only the boxes that do not overlap `head` too much.
        queue = tail[iou <= thresh]

    return survivors

# Reference: https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/demo.py
def vis_detections(im, dets, thresh=0.5, caption=None):
    """Plot detection boxes whose score is at least ``thresh`` over ``im``.

    Each row of ``dets`` is read as box coordinates in columns 0-3
    (x1, y1, x2, y2) and a score in the last column. The single
    highest-scoring row in ``dets`` is outlined in red, the others in green,
    and each drawn box is labelled with its score. If ``caption`` is given it
    is used as the figure title. Nothing is plotted when no score reaches
    ``thresh``.
    """
    scores = dets[:, -1]
    shown = np.flatnonzero(scores >= thresh)
    if shown.size == 0:
        return

    best = scores.argmax()

    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for idx in shown:
        box = dets[idx, :4]
        outline = 'red' if idx == best else 'green'
        ax.add_patch(
            plt.Rectangle(
                (box[0], box[1]),
                box[2] - box[0],
                box[3] - box[1],
                fill=False,
                edgecolor=outline,
                linewidth=3.5,
            )
        )
        label_style = dict(facecolor='blue', alpha=0.5)
        ax.text(box[0], box[1] - 2, '{:.3f}'.format(scores[idx]),
                bbox=label_style, fontsize=14, color='white')

    plt.axis('off')
    plt.tight_layout()
    plt.draw()
    if caption is not None:
        plt.title(caption, fontsize=20)
    plt.show()

Generate Bounding Boxes with Selective Search


image_url = 'http://archive.jsonline.com/Services/image.ashx?domain=www.jsonline.com&file=30025294_messykitchen1.jpg&resize=' #@param {type:"string"}
resize = None#@param {type:"raw"}
topk = 50#@param {type:"integer"}
scale =  200#@param {type:"integer"}
sigma =  0.8#@param {type:"number"}
min_size = 50#@param {type:"integer"}

# Region filters applied to selective-search output: drop regions covering
# fewer pixels than this, and boxes more elongated than this aspect ratio.
min_region_pixels = 1000
max_aspect_ratio = 1.5

# Download the image from the web.
image_path = 'image.png'
urllib.request.urlretrieve(image_url, image_path)

# `resize` is the side length of the square image fed to selective search
# (None = search the full-resolution image). Validate with real exceptions:
# `assert` is stripped when Python runs with -O.
if resize is not None:
    if not isinstance(resize, int):
        raise TypeError("resize should be an integer.")
    if resize <= 0:
        raise ValueError("resize should be a positive integer.")

img = load_image(image_path)
oh, ow = img.shape[:2]
print(f"Image resolution: {oh, ow}")

# Selective search.
img_search = load_image(image_path, resize=resize)
img_lbl, regions = selectivesearch.selective_search(
    img_search, scale=scale, sigma=sigma, min_size=min_size)

# Rects (in search-image coordinates) already accepted. The previous version
# checked membership against a dict keyed by region index, so duplicates were
# never detected and consumed the top-k budget.
accepted_rects = set()
candidates = []
for r in regions:
    rect = r['rect']
    if rect in accepted_rects:
        continue
    if r['size'] < min_region_pixels:
        continue
    x, y, w, h = rect
    # Zero-width/height rects would raise ZeroDivisionError in the ratio test.
    if w == 0 or h == 0:
        continue
    if w / h > max_aspect_ratio or h / w > max_aspect_ratio:
        continue
    accepted_rects.add(rect)
    if resize is not None:
        # Map the box from the resize x resize search image back to original
        # pixel coordinates (truncating toward zero, as before).
        sx = ow / resize
        sy = oh / resize
        x = int(np.clip(x * sx, 0, ow))
        y = int(np.clip(y * sy, 0, oh))
        w = int(np.clip(w * sx, 0, ow))
        h = int(np.clip(h * sy, 0, oh))
    candidates.append((x, y, w, h))
print(f"Generated {len(candidates)} bounding boxes. Taking the top {topk}.")
# "Top" means the first `topk` in the order selective search returned them.
candidates = candidates[:topk]

# Display topk bounding boxes.
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 8))
ax.imshow(img)
for x, y, w, h in candidates:
    rect = mpatches.Rectangle(
        (x, y), w, h, fill=False, edgecolor='red', linewidth=1)
    ax.add_patch(rect)
plt.show()