The Rise of Pallas: Unlocking TPU Potential with Custom Kernels

    Accelerating AI/ML Model Training with Custom Operators — Part 3


    Photo by Hendrik Morkel on Unsplash

This is the third part of a series of posts on the topic of building custom operators for optimizing AI/ML workloads. In our previous post we demonstrated the simplicity and accessibility of Triton. Named for the Greek god of the sea, Triton empowers Python developers to increase their control over the GPU and optimize its use for the specific workload at hand. In this post we move one step down the lineage of Greek mythology to Triton’s daughter, Pallas, and discuss her namesake, the JAX extension for writing custom kernels for GPU and TPU.

One of the most important features of NVIDIA GPUs — and a significant factor in their rise to prominence — is their programmability. A key ingredient of the GPU offering is the availability of frameworks for creating General-Purpose GPU (GPGPU) operators, such as CUDA and Triton.

    In previous posts (e.g., here) we discussed the opportunity for running ML workloads on Google TPUs and the potential for a meaningful increase in price performance and a reduction in training costs. One of the disadvantages that we noted at the time was the absence of tools for creating custom operators. As a result, models requiring unique operators that were either unsupported by the underlying ML framework (e.g., TensorFlow/XLA) or implemented in a suboptimal manner, would underperform on TPU compared to GPU. This development gap was particularly noticeable over the past few years with the frequent introduction of newer and faster solutions for computing attention on GPU. Enabled by GPU kernel development frameworks, these led to a significant improvement in the efficiency of transformer models.

    On TPUs, on the other hand, the lack of appropriate tooling prevented this innovation and transformer models were stuck with the attention mechanisms that were supported by the official SW stack. Fortunately, with the advent of Pallas this gap has been addressed. Built as an extension to JAX and with dedicated support for PyTorch/XLA, Pallas enables the creation of custom kernels for GPU and TPU. For its GPU support Pallas utilizes Triton, and for its TPU support it uses a library called Mosaic. Although we will focus on custom kernels for TPU, it is worth noting that when developing in JAX, GPU kernel customization with Pallas offers some advantages over Triton (e.g., see here).

Our intention in this post is to draw attention to Pallas and demonstrate its potential. Please do not view this post as a replacement for the official Pallas documentation. The examples we will share were chosen for demonstrative purposes only. We have made no effort to optimize these or verify their robustness, durability, or accuracy.

    Importantly, at the time of this writing Pallas is an experimental feature and still under active development. The samples we share (which are based on JAX version 0.4.32 and PyTorch version 2.4.1) may become outdated by the time you read this. Be sure to use the most up-to-date APIs and resources available for your Pallas development.

    Many thanks to Yitzhak Levi for his contributions to this post.

    Environment Setup


    For the experiments described below we use the following environment setup commands:

    # create TPU node
    gcloud alpha compute tpus queued-resources create v5litepod-1-resource \
    --node-id v5litepod \
    --project <project-id> \
    --zone us-central1-a \
    --accelerator-type v5litepod-1 \
    --runtime-version v2-alpha-tpuv5-lite \
    --valid-until-duration 1d \
--service-account <service-account>

    # check TPU node status (wait for state to be ACTIVE)
    gcloud alpha compute tpus queued-resources describe v5litepod-1-resource \
    --project <project-id> \
    --zone us-central1-a

    # SSH to TPU node
    gcloud alpha compute tpus tpu-vm ssh v5litepod \
    --project <project-id> \
    --zone us-central1-a

    # install dependencies
    pip install torch_xla[tpu] \
    -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas]
    pip install timm

    # run tests
    python train.py

# exit ssh
    exit

    # delete TPU node
    gcloud alpha compute tpus queued-resources delete v5litepod-1-resource \
    --project <project-id> \
    --zone us-central1-a --force --quiet
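
As a side note, after SSHing to the node and installing the dependencies, it can be helpful to verify that the TPU is visible from PyTorch/XLA before running train.py. A minimal check (assuming the dependencies above installed cleanly) might look like this:

# verify that PyTorch/XLA sees the TPU (should print an XLA device rather than raise an error)
python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
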
    Pallas Kernels for TPU


In the toy example of our first post in this series, we distinguished between two different ways in which custom kernel development can potentially boost performance. The first is by combining (fusing) together multiple operations in a manner that reduces the overhead of: 1) loading multiple individual kernels, and 2) reading and writing intermediate values (e.g., see PyTorch’s tutorial on multiply-add fusion). The second is by meticulously applying the resources of the underlying accelerator in a manner that optimizes the function at hand. We briefly discuss these two opportunities as they pertain to developing custom TPU kernels and make note of the limitations of the Pallas support.

    Operator Fusion on TPU


The TPU is an XLA (Accelerated Linear Algebra) device, i.e., it runs code that has been generated by the XLA compiler. When training an AI model in a framework such as JAX or PyTorch/XLA, the training step is first transformed into an intermediate graph representation (IR). This computation graph is then fed to the XLA compiler, which converts it into machine code that can run on the TPU. In contrast to eager execution mode, in which operations are executed individually, this mode of running models enables XLA to identify and implement opportunities for operator fusion during compilation. And, in fact, operator fusion is the XLA compiler’s most important optimization. Naturally, no compiler is perfect and we are certain to come across additional opportunities for fusion through custom kernels. But, generally speaking, we might expect the opportunity for boosting runtime performance in this manner to be lower than in the case of eager execution.
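
To make the fusion opportunity concrete, the toy snippet below (our own illustration, not taken from the XLA documentation; scale_shift_activate is an illustrative name) chains three element-wise operations. Under jax.jit, XLA will typically compile them into a single fused kernel rather than launching one kernel per operation and writing the intermediate tensors to memory:

import jax
import jax.numpy as jnp

# three element-wise operations that XLA can fuse into a single kernel
def scale_shift_activate(x, scale, shift):
    y = x * scale          # multiply
    y = y + shift          # add
    return jax.nn.relu(y)  # activation

fused = jax.jit(scale_shift_activate)
x = jnp.ones((1024, 1024))
print(fused(x, 2.0, 1.0).mean())

# the compiled program can be inspected, e.g. via
# jax.jit(scale_shift_activate).lower(x, 2.0, 1.0).compile().as_text(),
# to see the fused computation
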

    Optimizing TPU Utilization


    Creating optimal kernels for TPU requires a comprehensive and intimate understanding of the TPU system architecture. Importantly, TPUs are very different from GPUs: expertise in GPUs and CUDA does not immediately carry over to TPU development. For example, while GPUs contain a large number of processors and draw their strength from their ability to perform massive parallelization, TPUs are primarily sequential with dedicated engines for running highly vectorized operations and support for asynchronous scheduling and memory loading.

The differences between the underlying architectures of the GPU and TPU can have significant implications on how custom kernels should be designed. Mastering TPU kernel development requires: 1) appropriate overlapping of memory and compute operations via pipelining, 2) knowing how to mix the use of the scalar, vector (VPU), and matrix (MXU) compute units and their associated scalar and vector registers (SREG and VREG) and memory caches (SMEM and VMEM), 3) a comprehension of the costs of different low-level operations, 4) appropriate megacore configuration (on supporting TPU generations), 5) a grasp of the different types of TPU topologies and their implications for distributed computing, and more.

    Framework Limitations


    While the ability to create custom operators in Python using JAX functions and APIs greatly increases the simplicity and accessibility of Pallas kernel development, it also limits its expressivity. Additionally, (as of the time of this writing) there are some JAX APIs that are not supported by Pallas on TPU (e.g., see here). As a result, you may approach Pallas with the intention of implementing a particular operation only to discover that the framework does not support the APIs that you need. This is in contrast to frameworks such as CUDA which enable a great deal of flexibility when developing custom kernels (for GPU).

    The matrix multiplication tutorial in the Pallas documentation provides an excellent introduction to Pallas kernel development, highlighting the potential for operator fusion and customization alongside the challenges involved in optimizing performance (e.g., appropriate tuning of the input block size). The tutorial clearly illustrates that maximizing the full potential of the TPU requires a certain degree of specialization. However, as we intend to demonstrate, even the novice ML developer can benefit from Pallas kernels.

    Integrating the Use of Existing Pallas Kernels


    To benefit from custom Pallas kernels you do not necessarily need to know how to build them. In our first example we demonstrate how you can leverage existing Pallas kernels from dedicated public repositories.

    Example — Flash Attention in Torch/XLA


The JAX GitHub repository includes implementations of a number of Pallas kernels, including flash attention. Here we will demonstrate its use in a Torch/XLA Vision Transformer (ViT) model. Although Pallas kernels are developed in JAX, they can be adopted into Torch/XLA, e.g., via the make_kernel_from_pallas utility (see the documentation for details). In the case of flash attention, the adoption is implemented by Torch/XLA.
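
To illustrate what such an adoption looks like, here is a minimal sketch, loosely based on the add_vectors example in the PyTorch/XLA custom-kernel documentation (the kernel, its wrapper, and the pt_add name are illustrative and not part of the flash attention integration). The second argument to make_kernel_from_pallas is a function that deduces the output shape and dtype from the inputs:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

# a trivial Pallas kernel: element-wise addition
def add_vectors_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
    )(x, y)

# wrap the JAX function as a PyTorch/XLA operator
pt_add = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])

device = xm.xla_device()
a = torch.ones(8, 128, device=device)
b = torch.ones(8, 128, device=device)
print(pt_add(a, b))
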

    In the following code block we define a stripped down version of the classic timm attention block with an option to define the underlying attention operator in the constructor. We will use this option to compare the performance of the flash attention Pallas kernel to its alternatives.

# general imports
import os, time, functools
# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch_xla.core.xla_model as xm
# custom kernel import
from torch_xla.experimental.custom_kernel import flash_attention
# timm imports
from timm.layers import Mlp
from timm.models.vision_transformer import VisionTransformer

class TPUAttentionBlock(nn.Module):
    def __init__(
            self,
            dim: int = 768,
            num_heads: int = 12,
            attn_fn=None,
            **kwargs
    ) -> None:
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=dim * 4,
        )

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x_in)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.attn_fn is None:
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        else:
            x = self.attn_fn(q, k, v)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x

    In the following block we train a simple ViT-backed classification model using the input dataset and attention function (attn_fn) of choice.

def train(dataset, attn_fn=None):
    device = xm.xla_device()

    train_loader = DataLoader(
        dataset,
        batch_size=128,
        num_workers=os.cpu_count(),
        pin_memory=True
    )

    # configure the VisionTransformer in a manner that complies with the
    # Pallas flash_attention kernel constraints
    model = VisionTransformer(
        block_fn=functools.partial(TPUAttentionBlock, attn_fn=attn_fn),
        img_size=256,
        class_token=False,
        global_pool="avg"
    )

    optimizer = torch.optim.SGD(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()

    # copy the model to the TPU
    model = model.to(device)

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(train_loader):
        # copy data to TPU
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast('xla', dtype=torch.bfloat16):
            output = model(inputs)
            loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()
        xm.mark_step()

        # capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 100:
            break

    print(f'average step time: {summ / count}')

    Note the specific configuration we chose for the VisionTransformer. This is to comply with certain restrictions (as of the time of this writing) of the custom flash attention kernel (e.g., on tensor shapes).

Finally, we define a dataset and compare the runtimes of training with three different attention routines: 1) using native PyTorch functions, 2) using PyTorch’s built-in SDPA function, and 3) using the custom Pallas operator:

# use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 256, 256], dtype=torch.float32)
        label = torch.tensor(data=index % 1024, dtype=torch.int64)
        return rand_image, label

ds = FakeDataset()

print('PyTorch native')
train(ds, attn_fn=None)

print('PyTorch SDPA')
train(ds, attn_fn=functools.partial(F.scaled_dot_product_attention, scale=1.0))

print('Pallas flash_attention')
train(ds, attn_fn=flash_attention)

    The comparative results are captured in the table below:


    Step time for different attention blocks (lower is better) — by Author

    Although our Pallas kernel clearly underperforms when compared to its alternatives, we should not be discouraged:

    1. It is likely that these results could be improved with appropriate tuning.
    2. These results are specific to the model and runtime environment that we chose. The Pallas kernel may exhibit wholly different comparative results in other use cases.
    3. The real power of Pallas is in the ability to create and adjust low level operators to our specific needs. Although runtime performance is important, a 23% performance penalty (as in our example) may be a small price to pay for this flexibility. Moreover, the opportunity for customization may open up possibilities for optimizations that are not supported by the native framework operations.
    Enhancing Existing Kernels


Oftentimes it may be easier to tweak an existing Pallas kernel to your specific needs rather than creating one from scratch. This is especially recommended if the kernel has already been optimized, as performance tuning can be tedious and time-consuming. The official matrix multiplication tutorial includes a few examples of how to extend and enhance an existing kernel. Here we undertake one of the suggested exercises: we implement int8 matrix multiplication and assess its performance advantage over its bfloat16 alternative.

    Example — Int8 Matrix Multiplication


    In the code block below we implement an int8 version of the matrix multiplication example.

import functools, timeit
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


# set to True to develop/debug on CPU
interpret = False


def matmul_kernel_int8(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref)

    acc_ref[...] += jnp.dot(
        x_ref[...], y_ref[...], preferred_element_type=jnp.int32
    )

    @pl.when(pl.program_id(2) == nsteps - 1)
    def _():
        z_ref[...] = acc_ref[...]


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul_int8(
        x: jax.Array,
        y: jax.Array,
        *,
        bm: int = 128,
        bk: int = 128,
        bn: int = 128,
):
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        functools.partial(matmul_kernel_int8, nsteps=k // bk),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(block_shape=(bm, bk),
                             index_map=lambda i, j, k: (i, k)),
                pl.BlockSpec(block_shape=(bk, bn),
                             index_map=lambda i, j, k: (k, j)),
            ],
            out_specs=pl.BlockSpec(block_shape=(bm, bn),
                                   index_map=lambda i, j, k: (i, j)),
            scratch_shapes=[pltpu.VMEM((bm, bn), jnp.int32)],
            grid=(m // bm, n // bn, k // bk),
        ),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel", "arbitrary"))),
        interpret=interpret
    )(x, y)

    Note our use of an int32 accumulation matrix for addressing the possibility of overflow. Also note our use of the interpret flag for debugging of Pallas kernels on CPU (as recommended here).
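
For example, with interpret set to True at the top of the script, a quick sanity check of the kernel logic can be run on CPU by comparing its output to jnp.matmul on small inputs. A minimal sketch under those assumptions:

# sanity check in interpreter mode (run with interpret = True, no TPU required)
x = jnp.ones((256, 256), dtype=jnp.int8)
y = jnp.ones((256, 256), dtype=jnp.int8)
expected = jnp.matmul(x, y, preferred_element_type=jnp.int32)
actual = matmul_int8(x, y, bm=128, bk=128, bn=128)
print(jnp.array_equal(expected, actual))  # expect True
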

To assess our kernel, we introduce a slight modification to the benchmarking utilities defined in the tutorial and compare the runtime results to both the jnp.bfloat16 Pallas matmul kernel and the built-in JAX matmul API:

def benchmark(f, ntrials: int = 100):
    def run(*args, **kwargs):
        # Compile function first
        jax.block_until_ready(f(*args, **kwargs))
        # Time function
        res = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                            number=ntrials)
        time = res / ntrials
        # print(f"Time: {time}")
        return time

    return run


def analyze_matmul(m: int, k: int, n: int, dtype: jnp.dtype,
                   mm_func):
    x = jnp.ones((m, k), dtype=dtype)
    y = jnp.ones((k, n), dtype=dtype)
    time = benchmark(mm_func)(x, y)
    print("Matmul time: ", time)
    mm_ops = 2*m*k*n/time
    v5e_ops = 394e12 if dtype == jnp.int8 else 197e12
    print(f"OP/s utilization: {mm_ops / v5e_ops * 100:.4f}%")
    print()


print("bfloat16 Pallas matmul")
# 'matmul' refers to the kernel defined in the Pallas matmul tutorial
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)


print("int8 Pallas matmul")
mm = functools.partial(matmul_int8, bm=512, bk=1024, bn=1024)
analyze_matmul(8192, 8192, 8192, jnp.int8, mm)

print("XLA int8 matmul")
mm = functools.partial(jnp.matmul, preferred_element_type=jnp.int32)
analyze_matmul(8192, 8192, 8192, jnp.int8, mm)

    The results of our experiment are captured in the table below:


    Matmul time and utilization (by Author)

By using int8 matrices (rather than bfloat16 matrices) on TPU v5e we can boost the runtime performance of our custom matrix multiplication kernel by 71%. However, as in the case of the bfloat16 example, additional tuning is required to match the performance of the built-in matmul operator. The potential for improvement is highlighted by the drop in system utilization when compared to bfloat16.
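
One simple, if brute-force, way to approach that tuning is to sweep a handful of block-size combinations with the analyze_matmul utility defined above and keep the fastest; the candidates below are arbitrary choices for illustration only:

# an illustrative block-size sweep for the int8 Pallas kernel
candidates = [(256, 1024, 1024), (512, 512, 1024),
              (512, 1024, 1024), (512, 2048, 1024)]
for bm, bk, bn in candidates:
    print(f"int8 Pallas matmul with bm={bm}, bk={bk}, bn={bn}")
    mm = functools.partial(matmul_int8, bm=bm, bk=bk, bn=bn)
    analyze_matmul(8192, 8192, 8192, jnp.int8, mm)
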

    Creating a Kernel from Scratch


    While leveraging existing kernels can be greatly beneficial, it is unlikely to solve all of your problems. Inevitably, you may need to implement an operation that is either unsupported on TPU or exhibits suboptimal performance. Here we demonstrate the creation of a relatively simple pixel-wise kernel. For the sake of continuity, we choose the same Generalized Intersection Over Union (GIOU) operation as in our previous posts.

    Example — A GIOU Pallas Kernel


In the code block below we define a Pallas kernel that implements GIOU on pairs of batches of bounding boxes, each of dimension BxNx4 (where we denote the batch size by B and the number of boxes per sample by N). The function returns a tensor of scores of dimension BxN. We choose a block size of 128 on both the batch axis and the boxes axis, i.e., we unstack the four box coordinates into separate BxN tensors, divide each into blocks of 128x128, and pass them to our kernel function. The grid and BlockSpec index_map are defined accordingly.

import timeit
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

# set to True to develop/debug on CPU
interpret = False

# perform giou on a single block
def giou_kernel(preds_left_ref,
                preds_top_ref,
                preds_right_ref,
                preds_bottom_ref,
                targets_left_ref,
                targets_top_ref,
                targets_right_ref,
                targets_bottom_ref,
                output_ref):
    epsilon = 1e-5

    # copy tensors into local memory
    preds_left = preds_left_ref[...]
    preds_top = preds_top_ref[...]
    preds_right = preds_right_ref[...]
    preds_bottom = preds_bottom_ref[...]

    gt_left = targets_left_ref[...]
    gt_top = targets_top_ref[...]
    gt_right = targets_right_ref[...]
    gt_bottom = targets_bottom_ref[...]

    # Compute the area of each box
    area1 = (preds_right - preds_left) * (preds_bottom - preds_top)
    area2 = (gt_right - gt_left) * (gt_bottom - gt_top)

    # Compute the intersection
    left = jnp.maximum(preds_left, gt_left)
    top = jnp.maximum(preds_top, gt_top)
    right = jnp.minimum(preds_right, gt_right)
    bottom = jnp.minimum(preds_bottom, gt_bottom)

    # intersection width and height
    inter_w = jnp.maximum(right - left, 0)
    inter_h = jnp.maximum(bottom - top, 0)

    # intersection area
    inter_area = inter_w * inter_h

    # union of two boxes
    union_area = area1 + area2 - inter_area

    iou_val = inter_area / jnp.maximum(union_area, epsilon)

    # Compute the smallest enclosing box
    enclose_left = jnp.minimum(preds_left, gt_left)
    enclose_top = jnp.minimum(preds_top, gt_top)
    enclose_right = jnp.maximum(preds_right, gt_right)
    enclose_bottom = jnp.maximum(preds_bottom, gt_bottom)

    # enclosing box width and height
    enclose_w = jnp.maximum(enclose_right - enclose_left, 0)
    enclose_h = jnp.maximum(enclose_bottom - enclose_top, 0)

    # enclosing box area
    enclose_area = enclose_w * enclose_h

    # Compute GIOU
    delta_area = (enclose_area - union_area)
    enclose_area = jnp.maximum(enclose_area, epsilon)
    output_ref[...] = iou_val - delta_area / enclose_area


@jax.jit
def batch_giou(preds, targets):
    m, n, _ = preds.shape
    output = pl.pallas_call(
        giou_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), preds.dtype),
        in_specs=[pl.BlockSpec(block_shape=(128, 128),
                               index_map=lambda i, j: (i, j))] * 8,
        out_specs=pl.BlockSpec(block_shape=(128, 128),
                               index_map=lambda i, j: (i, j)),
        grid=(m // 128, n // 128),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel"))),
        interpret=interpret
    )(*jnp.unstack(preds, axis=-1), *jnp.unstack(targets, axis=-1))
    return output

Although the creation of a new TPU kernel is certainly cause for celebration (especially if it enables a previously blocked ML workload), our work is not done. A critical part of Pallas kernel development is tuning the operator (e.g., the block size) for optimal runtime performance. We omit this stage in the interest of brevity.
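
That said, a rough sketch of what such tuning might look like is shown below: a variant of batch_giou (our own illustrative rewrite, not part of the original kernel) that accepts the block size as an argument, so that different block shapes can be timed with the benchmark utility defined earlier:

import functools  # not imported in the script above

# illustrative variant of batch_giou with a configurable block size
@functools.partial(jax.jit, static_argnames=['block'])
def batch_giou_tunable(preds, targets, block: int = 128):
    m, n, _ = preds.shape
    return pl.pallas_call(
        giou_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), preds.dtype),
        in_specs=[pl.BlockSpec(block_shape=(block, block),
                               index_map=lambda i, j: (i, j))] * 8,
        out_specs=pl.BlockSpec(block_shape=(block, block),
                               index_map=lambda i, j: (i, j)),
        grid=(m // block, n // block),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel"))),
        interpret=interpret
    )(*jnp.unstack(preds, axis=-1), *jnp.unstack(targets, axis=-1))

Each candidate block size (e.g., 128 or 256, subject to the input dimensions dividing evenly) could then be timed on representative inputs and the fastest configuration kept.
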

To assess the performance of our kernel, we compare it to the following native JAX GIOU implementation:

def batched_box_iou(boxes1, boxes2):
    epsilon = 1e-5

    # Compute areas of both sets of boxes
    area1 = (boxes1[..., 2]-boxes1[..., 0])*(boxes1[..., 3]-boxes1[..., 1])
    area2 = (boxes2[..., 2]-boxes2[..., 0])*(boxes2[..., 3]-boxes2[..., 1])

    # corners of intersection
    lt = jnp.maximum(boxes1[..., :2], boxes2[..., :2])
    rb = jnp.minimum(boxes1[..., 2:], boxes2[..., 2:])

    # width and height of intersection
    wh = jnp.clip(rb - lt, a_min=0)

    # area of the intersection
    inter = wh[..., 0] * wh[..., 1]

    # union of the two boxes
    union = area1 + area2 - inter
    iou = inter / jnp.clip(union, a_min=epsilon)

    # corners of enclosing box
    lti = jnp.minimum(boxes1[..., :2], boxes2[..., :2])
    rbi = jnp.maximum(boxes1[..., 2:], boxes2[..., 2:])

    # Width and height of the enclosing box
    whi = jnp.clip(rbi - lti, a_min=0)

    # Area of the enclosing box
    areai = jnp.clip(whi[..., 0] * whi[..., 1], a_min=epsilon)

    # Generalized IoU
    return iou - (areai - union) / areai

    We generate two batches of randomly generated bounding boxes and measure the performance of our functions using the benchmark function defined above.

from jax import random

batch_size = 1024
n_boxes = 256
img_size = 256
boxes = []
for i in range(2):
    k1, k2 = random.split(random.key(i), 2)

    # Randomly generate box sizes and positions
    box_sizes = random.randint(k1, shape=(batch_size, n_boxes, 2), minval=1, maxval=img_size)
    top_left = random.randint(k2, shape=(batch_size, n_boxes, 2), minval=0, maxval=img_size - 1)
    bottom_right = jnp.clip(top_left + box_sizes, 0, img_size - 1)

    # Concatenate top-left and bottom-right coordinates
    rand_boxes = jnp.concatenate((top_left, bottom_right), axis=2)

    boxes.append(rand_boxes.astype(jnp.float32))


time = benchmark(batch_giou)(boxes[0], boxes[1])
print(f'Pallas kernel: {time}')
time = benchmark(batched_box_iou)(boxes[0], boxes[1])
print(f'JAX function: {time}')
time = benchmark(jax.jit(batched_box_iou))(boxes[0], boxes[1])
print(f'Jitted function: {time}')
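
Alongside the timing comparison, it is also worth a quick check that the Pallas kernel and the native JAX function agree numerically. A minimal check, assuming the functions and boxes defined above:

# sanity check: both implementations should produce (nearly) identical scores
pallas_out = batch_giou(boxes[0], boxes[1])
jax_out = batched_box_iou(boxes[0], boxes[1])
print(jnp.allclose(pallas_out, jax_out, atol=1e-5))  # expect True
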

    The comparative results appear in the table below:


    Avg time of different GIOU implementations (lower is better) — by Author

    We can see that JIT-compiling our naive JAX implementation results in slightly better performance than our Pallas kernel. Once again, we can see that matching or surpassing the performance results of JIT compilation (and its inherent kernel fusion) would require fine-tuning of our custom kernel.

    Utilizing the Sequential Nature of TPUs


While the ability to develop custom kernels for TPU offers great potential, our examples thus far have demonstrated that reaching optimal runtime performance can be challenging. One way to overcome this is to seek opportunities to utilize the unique properties of the TPU architecture. One example of this is the sequential nature of the TPU processor. Although deep learning workloads tend to rely on operations that are easily parallelizable (e.g., matrix multiplication), on occasion they require algorithms that are inherently sequential. These can pose a serious challenge for the SIMT (single instruction, multiple threads) model of GPUs and can sometimes have a disproportionate impact on runtime performance. In a sequel to this post, we demonstrate how to implement sequential algorithms in a way that takes advantage of the TPU's sequential processor and minimizes their performance penalty.

    Summary


The introduction of Pallas marks an important milestone in the evolution of TPUs. By enabling the customization of TPU operations, it can potentially unlock new opportunities for TPU programmability, particularly in the world of ML. Our intention in this post was to demonstrate the accessibility of this powerful new feature. While our examples have indeed shown this, they have also highlighted the effort required to reach optimal runtime performance.

    This post has merely scratched the surface of Pallas kernel development. Be sure to see the official documentation to learn more about automatic differentiation in Pallas, developing sparse kernels, and more.


The Rise of Pallas: Unlocking TPU Potential with Custom Kernels was originally published in Towards Data Science on Medium.
     
