# Source code for torchjpeg.dct._nn

import math
from typing import Optional

import torch
from torch import Tensor

from ._block import blockify, deblockify


def double_size_tensor() -> Tensor:
    """Box (nearest-neighbor) 2x upsampling operator: takes an 8 x 8 image and returns a 16 x 16 image.

    Returns:
        Tensor: Operator of shape ``(8, 8, 16, 16)`` where ``op[i, j, u, v] == 1`` exactly when
        output pixel ``(u, v)`` copies input pixel ``(i, j)``, i.e. ``i == u // 2`` and
        ``j == v // 2``; every other entry is 0.
    """
    op = torch.zeros((8, 8, 16, 16))

    # Each output pixel has exactly one source pixel, so visit the 256 output positions
    # directly instead of testing all 8 * 8 * 16 * 16 index combinations.
    for u in range(16):
        for v in range(16):
            op[u // 2, v // 2, u, v] = 1

    return op


def half_size_tensor() -> Tensor:
    """Box (nearest-neighbor) 2x downsampling operator: takes a 16 x 16 image and returns an 8 x 8.

    Output pixel ``(u, v)`` is taken from input pixel ``(2 * u, 2 * v)``, i.e. the even/even
    sample of each 2 x 2 neighborhood; the other three samples are dropped.

    Returns:
        Tensor: Operator of shape ``(16, 16, 8, 8)`` with ``op[2 * u, 2 * v, u, v] == 1`` and 0 elsewhere.
    """
    halver = torch.zeros((16, 16, 8, 8))
    for out_row in range(8):
        for out_col in range(8):
            halver[2 * out_row, 2 * out_col, out_row, out_col] = 1
    return halver


def A(alpha) -> float:
    """Orthonormal DCT scale factor: ``1 / sqrt(2)`` for the zero frequency, ``1.0`` for every other frequency."""
    return 1.0 / math.sqrt(2) if alpha == 0 else 1.0


def D() -> Tensor:
    """Build the 8 x 8 orthonormal 2D DCT basis tensor.

    Returns:
        Tensor: Tensor of shape ``(8, 8, 8, 8)`` indexed as ``[i, j, alpha, beta]``, where ``(i, j)``
        is a spatial position and ``(alpha, beta)`` a frequency. Each entry is::

            0.25 * A(alpha) * A(beta)
                 * cos((2 * i + 1) * alpha * pi / 16) * cos((2 * j + 1) * beta * pi / 16)
    """
    # cosines[k][f]: 1D cosine term for spatial index k and frequency f.
    cosines = [[math.cos(((2 * k + 1) * f * math.pi) / 16) for f in range(8)] for k in range(8)]
    # scales[f]: orthonormal scale factor for frequency f.
    scales = [A(f) for f in range(8)]

    basis = torch.zeros((8, 8, 8, 8))
    for i in range(8):
        for j in range(8):
            for alpha in range(8):
                for beta in range(8):
                    # Same multiplication order as the separable formula above.
                    basis[i, j, alpha, beta] = 0.25 * scales[alpha] * scales[beta] * cosines[i][alpha] * cosines[j][beta]
    return basis


def reblock() -> Tensor:
    """Reblocking operator: splits a 16 x 16 block into four 8 x 8 blocks.

    Block ``n`` is offset by ``8 * (n % 2)`` along the first axis and by ``8 * (n // 2)`` along
    the second, so ``B_t[s_x, s_y, n, i, j] == 1`` exactly when ``s_x == 8 * (n % 2) + i`` and
    ``s_y == 8 * (n // 2) + j``; all other entries are 0. This is the same block layout that
    ``macroblock()`` uses to reassemble the four blocks.

    Returns:
        Tensor: Operator of shape ``(16, 16, 4, 8, 8)``.
    """
    B_t = torch.zeros((16, 16, 4, 8, 8))

    # Each (n, i, j) reads exactly one source position, so write it directly instead of
    # scanning all 16 * 16 source positions for a match.
    for n in range(4):
        x = n % 2
        y = n // 2
        for i in range(8):
            for j in range(8):
                B_t[x * 8 + i, y * 8 + j, n, i, j] = 1.0

    return B_t


def macroblock() -> Tensor:
    """Macroblocking operator: assembles four 8 x 8 blocks into one 16 x 16 block.

    Block ``n`` is placed at offset ``(dx, dy)`` = ``(0, 0)``, ``(8, 0)``, ``(0, 8)``, ``(8, 8)``
    for ``n = 0, 1, 2, 3`` (offset along the first and second output axes respectively).

    Returns:
        Tensor: Operator of shape ``(4, 8, 8, 16, 16)`` with
        ``B_t[n, alpha, beta, alpha + dx, beta + dy] == 1`` and 0 elsewhere.
    """
    B_t = torch.zeros((4, 8, 8, 16, 16))

    # Offsets (first axis, second axis) per block. Reading the first axis as horizontal:
    # 0 top left, 1 top right, 2 bottom left, 3 bottom right.
    offsets = ((0, 0), (8, 0), (0, 8), (8, 8))
    for n, (dx, dy) in enumerate(offsets):
        for alpha in range(8):
            for beta in range(8):
                B_t[n, alpha, beta, alpha + dx, beta + dy] = 1

    return B_t


class ResizeOps:  # pylint: disable=too-few-public-methods
    """
    Caches the resize operation tensors and enables them to be built on demand. Generally not for public use
    """

    # 8 x 8 -> 16 x 16 nearest-neighbor upsampling operator, shape (8, 8, 16, 16)
    resizer: Optional[Tensor] = None
    # 16 x 16 -> 8 x 8 nearest-neighbor downsampling operator, shape (16, 16, 8, 8)
    halfsizer: Optional[Tensor] = None
    # 8 x 8 DCT basis, shape (8, 8, 8, 8) indexed [i, j, alpha, beta]
    dct: Optional[Tensor] = None
    # 16 x 16 -> four 8 x 8 blocks, shape (16, 16, 4, 8, 8)
    reblocker: Optional[Tensor] = None
    # four 8 x 8 blocks -> 16 x 16, shape (4, 8, 8, 16, 16)
    macroblocker: Optional[Tensor] = None
    # DCT-domain doubling: one 8 x 8 coefficient block -> 16 x 16 coefficients, shape (8, 8, 16, 16)
    block_doubler: Optional[Tensor] = None
    # DCT-domain halving: 16 x 16 coefficients -> one 8 x 8 coefficient block, shape (16, 16, 8, 8)
    block_halver: Optional[Tensor] = None

    @classmethod
    def lazy_build_ops(cls) -> None:
        """
        Builds the resize operations on first use and caches them as class attributes.

        Every operator is computed into a local first and published only once all of them exist,
        with the two composite operators (``block_doubler`` and ``block_halver``) assigned last.
        The "already built" check looks at those composites, so if a build raises part-way
        through, the cache is not left half-populated. The next call retries the build instead
        of returning ``None`` operators.
        """
        if cls.block_doubler is not None and cls.block_halver is not None:
            return

        resizer = double_size_tensor()
        halfsizer = half_size_tensor()
        dct = D()
        reblocker = reblock()
        macroblocker = macroblock()

        # block doubler, applied to one 8 x 8 coefficient block (ab):
        #   IDCT -> 8 x 8 pixels (ij) -> NN doubling -> 16 x 16 pixels (mn)
        #   -> split into 4 x 8 x 8 (zxy) -> DCT per block (zpq) -> reassemble as 16 x 16 (rw)
        block_doubler = torch.einsum("ijab,ijmn,mnzxy,xypq,zpqrw->abrw", dct, resizer, reblocker, dct, macroblocker)
        # block halver, applied to 16 x 16 coefficients (mn):
        #   split into 4 x 8 x 8 (zab) -> IDCT per block (zij) -> reassemble 16 x 16 pixels (rw)
        #   -> NN halving -> 8 x 8 pixels (xy) -> DCT (pq)
        block_halver = torch.einsum("mnzab,ijab,zijrw,rwxy,xypq->mnpq", reblocker, dct, macroblocker, halfsizer, dct)

        cls.resizer = resizer
        cls.halfsizer = halfsizer
        cls.dct = dct
        cls.reblocker = reblocker
        cls.macroblocker = macroblocker
        # Composites last: their presence is what marks the cache as complete.
        cls.block_doubler = block_doubler
        cls.block_halver = block_halver


def double_nn_dct(input_dct: Tensor, op: Optional[Tensor] = None) -> Tensor:
    r"""
    double_nn_dct(input_dct: Tensor, op: Tensor = block_doubler) -> Tensor:

    DCT domain nearest neighbor doubling

    The function computes a 2x nearest neighbor upsampling on DCT coefficients without converting them to pixels.
    It is equivalent to the following procedure: IDCT -> 2x upsampling -> DCT

    Args:
        input_dct (Tensor): The input DCT coefficients in the format :math:`(N, C, H, W)`
        op (Tensor): The doubling operation tensor, mostly used to satisfy torchscript. Should be of shape
            :math:`8 \times 8 \times 16 \times 16`. Leave as default unless you know what you're doing.

    Returns:
        Tensor: The coefficients of the resized image, double the height and width of the input.

    Note:
        For floating-point inputs, ``op`` is cast to the input's dtype, so e.g. float64 or float16
        coefficients produce output of the same dtype.
    """
    if op is None:
        ResizeOps.lazy_build_ops()
        op = ResizeOps.block_doubler

    # Re-checking for None narrows Optional[Tensor] to Tensor (presumably for TorchScript's type refinement).
    if op is not None:
        if input_dct.is_floating_point():
            # The cached operator is built in the default dtype; match the input so einsum sees one dtype.
            op = op.to(device=input_dct.device, dtype=input_dct.dtype)
        else:
            op = op.to(input_dct.device)

    # Per the einsum subscripts, blockify presumably yields (N, C, num_blocks, 8, 8) — confirm in ._block.
    dct_blocks = blockify(input_dct, 8)
    dct_doubled = torch.einsum("abrw,ncdab->ncdrw", [op, dct_blocks])
    deblocked_doubled = deblockify(dct_doubled, (input_dct.shape[2] * 2, input_dct.shape[3] * 2))

    return deblocked_doubled
def half_nn_dct(input_dct: Tensor, op: Optional[Tensor] = None) -> Tensor:
    r"""
    half_nn_dct(input_dct: Tensor, op: Tensor = block_halver) -> Tensor:

    DCT domain nearest neighbor half-sizing

    The function computes a 2x nearest neighbor downsampling on DCT coefficients without converting them to pixels.
    It is equivalent to the following procedure: IDCT -> 2x downsampling -> DCT

    Args:
        input_dct (Tensor): The input DCT coefficients in the format :math:`(N, C, H, W)`
        op (Tensor): The halving operation tensor, mostly used to satisfy torchscript. Should be of shape
            :math:`16 \times 16 \times 8 \times 8`. Leave as default unless you know what you're doing.

    Returns:
        Tensor: The coefficients of the resized image, half the height and width of the input.

    Note:
        For floating-point inputs, ``op`` is cast to the input's dtype, so e.g. float64 or float16
        coefficients produce output of the same dtype.
    """
    if op is None:
        ResizeOps.lazy_build_ops()
        op = ResizeOps.block_halver

    # Re-checking for None narrows Optional[Tensor] to Tensor (presumably for TorchScript's type refinement).
    if op is not None:
        if input_dct.is_floating_point():
            # The cached operator is built in the default dtype; match the input so einsum sees one dtype.
            op = op.to(device=input_dct.device, dtype=input_dct.dtype)
        else:
            op = op.to(input_dct.device)

    # Per the einsum subscripts, blockify presumably yields (N, C, num_blocks, 16, 16) — confirm in ._block.
    dct_blocks = blockify(input_dct, 16)
    dct_halved = torch.einsum("abrw,ncdab->ncdrw", [op, dct_blocks])
    deblocked_halved = deblockify(dct_halved, (input_dct.shape[2] // 2, input_dct.shape[3] // 2))

    return deblocked_halved