Optimizing Transformer Models for Variable-Length Input Sequences
    How PyTorch NestedTensors, FlashAttention2, and xFormers can Boost Performance and Reduce AI Costs


    Photo by Tanja Zöllner on Unsplash

    As generative AI (genAI) models grow in both popularity and scale, so do the computational demands and costs associated with their training and deployment. Optimizing these models is crucial for enhancing their runtime performance and reducing their operational expenses. At the heart of modern genAI systems is the Transformer architecture and its attention mechanism, which is notably compute-intensive.

    In a previous post, we demonstrated how using optimized attention kernels can significantly accelerate the performance of Transformer models. In this post, we continue our exploration by addressing the challenge of variable-length input sequences — an inherent property of real-world data, including documents, code, time-series, and more.

    The Challenge of Batching Variable-Length Input


In a typical deep learning workload, individual samples are grouped into batches before being copied to the GPU and fed to the AI model. Batching improves computational efficiency and often aids model convergence during training. Usually, batching involves stacking all of the sample tensors along a new dimension — the batch dimension. However, torch.stack requires that all tensors have the same shape, which is not the case with variable-length sequences.
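
As a minimal sketch of the problem (the sequence lengths below are arbitrary and unrelated to the model defined later in this post), torch.stack fails on tensors of differing lengths:

import torch

# three token sequences of different (arbitrary) lengths
seqs = [torch.randint(1, 100, (n,)) for n in (5, 8, 3)]

try:
    batch = torch.stack(seqs)  # requires identical shapes
except RuntimeError as e:
    print(f'torch.stack failed: {e}')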

    Padding and its Inefficiencies


    The traditional way to address this challenge is to pad the input sequences to a fixed length and then perform stacking. This solution requires appropriate masking within the model so that the output is not affected by the irrelevant tensor elements. In the case of attention layers, a padding mask indicates which tokens are padding and should not be attended to (e.g., see PyTorch MultiheadAttention). However, padding can waste considerable GPU resources, increasing costs and slowing development. This is especially true for large-scale AI models.
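
The sketch below is illustrative only; it uses PyTorch's nn.MultiheadAttention rather than the custom block defined later in this post. It pads three embedded sequences to a fixed length and passes a boolean key_padding_mask so that the padded positions are ignored:

import torch
import torch.nn as nn

MAX_LEN, DIM = 8, 64
lengths = [5, 8, 3]

# embedded sequences of varying length, each of shape [length, DIM]
seqs = [torch.randn(n, DIM) for n in lengths]

# pad the sequence dimension up to MAX_LEN and stack into a batch
padded = torch.stack(
    [nn.functional.pad(s, (0, 0, 0, MAX_LEN - s.shape[0])) for s in seqs]
)  # shape: [3, MAX_LEN, DIM]

# True marks padded positions that should NOT be attended to
key_padding_mask = torch.tensor(
    [[i >= n for i in range(MAX_LEN)] for n in lengths]
)

mha = nn.MultiheadAttention(embed_dim=DIM, num_heads=4, batch_first=True)
out, _ = mha(padded, padded, padded, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([3, 8, 64])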

    Don’t Pad, Concatenate


One way to avoid padding is to concatenate sequences along an existing dimension instead of stacking them along a new dimension. Contrary to torch.stack, torch.cat allows inputs of different shapes. The output of concatenation is a single sequence whose length equals the sum of the lengths of the individual sequences. For this solution to work, our single sequence would need to be supplemented by an attention mask that ensures that each token only attends to other tokens in the same original sequence, in a process sometimes referred to as document masking. Denoting the sum of the lengths of all of the individual sequences by N and adopting "big O" notation, the size of this mask would need to be O(N²), as would the compute complexity of a standard attention layer, making this solution highly inefficient.
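
To make the quadratic cost concrete, here is a sketch (again illustrative, and not part of the model below) of such a document mask: a block-diagonal boolean matrix of shape N x N that permits attention only within each original sequence:

import torch

seq_lens = [5, 8, 3]              # lengths of the original sequences
N = sum(seq_lens)                 # length of the concatenated sequence

# assign each position the index of the sequence ("document") it came from
doc_ids = torch.repeat_interleave(
    torch.arange(len(seq_lens)), torch.tensor(seq_lens)
)

# True where attention is allowed: both positions belong to the same sequence
document_mask = doc_ids.unsqueeze(0) == doc_ids.unsqueeze(1)
print(document_mask.shape)  # torch.Size([16, 16]) -- O(N^2) memory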

    Attention Layer Optimization


The solution to this problem comes in the form of specialized attention layers. Contrary to the standard attention layer that computes the full set of O(N²) attention scores only to mask out the irrelevant ones, these optimized attention kernels are designed to calculate only the scores that matter. In this post we will explore several such solutions, each with its own distinct characteristics, including PyTorch's scaled_dot_product_attention (SDPA) with NestedTensor inputs, FlashAttention's flash_attn_varlen_func, and xFormers' memory_efficient_attention with a block-diagonal mask.

    Integration into Existing HuggingFace Models


    For teams working with pre-trained models, transitioning to these optimizations might seem challenging. We will demonstrate how HuggingFace’s APIs simplify this process, enabling developers to integrate these techniques with minimal code changes and effort.

    Disclaimers

    • Please do not interpret our use of any platforms, libraries, or optimization techniques as an endorsement for their use. The best options for you will depend greatly on the specifics of your own use-case.
    • Some of the APIs discussed here are in prototype or beta stages and may change in the future.
    • The code examples provided are for demonstrative purposes only. We make no claims regarding their accuracy, optimality, or robustness.

    Special thanks to Yitzhak Levi and Peleg Nahaliel for their contributions to this post.

    Toy LLM Model


    To facilitate our discussion we will define a simple generative model (partially inspired by the GPT model defined here). For a more comprehensive guide on building language models, please see one of the many excellent tutorials available online (e.g., here).

    Transformer Block


    We begin by constructing a basic Transformer block, specifically designed to facilitate experimentation with different attention mechanisms and optimizations. While our block performs the same computation as standard Transformer blocks, we make slight modifications to the usual choice of operators in order to support the possibility of PyTorch NestedTensor inputs (as described here).

    # general imports
    import time, functools

    # torch imports
    import torch
    from torch.utils.data import Dataset, DataLoader
    import torch.nn as nn

    # Define Transformer settings
    BATCH_SIZE = 32
    NUM_HEADS = 16
    HEAD_DIM = 64
    DIM = NUM_HEADS * HEAD_DIM
    DEPTH = 24
    NUM_TOKENS = 1024
    MAX_SEQ_LEN = 1024
    PAD_ID = 0
    DEVICE = 'cuda'

class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            dim,
            num_heads,
            format=None,
            **kwargs
    ):
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # mlp layers
        self.fc1 = nn.Linear(dim, dim * 4)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * 4, dim)

        self.permute = functools.partial(torch.transpose, dim0=1, dim1=2)
        if format == 'bshd':
            self.permute = nn.Identity()

    def mlp(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

    def reshape_and_permute(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return self.permute(x)

    def forward(self, x_in, attn_mask=None):
        batch_size = x_in.size(0)
        x = self.norm1(x_in)
        qkv = self.qkv(x)

        # rather than first reformatting and then splitting the input
        # state, we first split and then reformat q, k, v in order to
        # support PyTorch Nested Tensors
        q, k, v = qkv.chunk(3, -1)
        q = self.reshape_and_permute(q, batch_size)
        k = self.reshape_and_permute(k, batch_size)
        v = self.reshape_and_permute(v, batch_size)

        # call the attn_fn with the input attn_mask
        x = self.attn_fn(q, k, v, attn_mask=attn_mask)

        # reformat output
        x = self.permute(x).reshape(batch_size, -1, self.dim)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x

Transformer Decoder Model


    Building on our programmable Transformer block, we construct a typical Transformer decoder model.

class MyDecoder(nn.Module):
    def __init__(
            self,
            block_fn,
            num_tokens,
            dim,
            num_heads,
            num_layers,
            max_seq_len,
            pad_idx=None
    ):
        super().__init__()
        self.num_heads = num_heads
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(num_tokens, dim, padding_idx=pad_idx)
        self.positional_embedding = nn.Embedding(max_seq_len, dim)
        self.blocks = nn.ModuleList([
            block_fn(
                dim=dim,
                num_heads=num_heads
            )
            for _ in range(num_layers)])
        self.output = nn.Linear(dim, num_tokens)

    def embed_tokens(self, input_ids, position_ids=None):
        x = self.embedding(input_ids)
        if position_ids is None:
            position_ids = torch.arange(input_ids.shape[1],
                                        device=x.device)
        x = x + self.positional_embedding(position_ids)
        return x

    def forward(self, input_ids, position_ids=None, attn_mask=None):
        # Embed tokens and add positional encoding
        x = self.embed_tokens(input_ids, position_ids)
        if self.pad_idx is not None:
            assert attn_mask is None
            # create a padding mask - we assume boolean masking
            attn_mask = (input_ids != self.pad_idx)
            attn_mask = attn_mask.view(BATCH_SIZE, 1, 1, -1) \
                .expand(-1, self.num_heads, -1, -1)

        for b in self.blocks:
            x = b(x, attn_mask)

        logits = self.output(x)
        return logits

Variable Length Sequence Input


Next, we create a dataset containing sequences of variable lengths, where each sequence is made up of randomly generated tokens. For simplicity, we (arbitrarily) select a fixed distribution for the sequence lengths. In real-world scenarios, the distribution of sequence lengths typically reflects the nature of the data, such as the length of documents or audio segments. Note that the distribution of lengths directly affects the computational inefficiencies caused by padding.

    # Use random data
    class FakeDataset(Dataset):
    def __len__(self):
    return 1000000

    def __getitem__(self, index):
    length = torch.randint(1, MAX_SEQ_LEN, (1,))
    sequence = torch.randint(1, NUM_TOKENS, (length + 1,))
    input = sequence[:-1]
    target = sequence[1:]
    return input, target

    def pad_sequence(sequence, length, pad_val):
    return torch.nn.functional.pad(
    sequence,
    (0, length - sequence.shape[0]),
    value=pad_val
    )

    def collate_with_padding(batch):
    padded_inputs = []
    padded_targets = []
    for b in batch:
    padded_inputs.append(pad_sequence(b[0], MAX_SEQ_LEN, PAD_ID))
    padded_targets.append(pad_sequence(b[1], MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
    'inputs': padded_inputs,
    'targets': padded_targets
    }

    def data_to_device(data, device):
    if isinstance(data, dict):
    return {
    key: data_to_device(val,device)
    for key, val in data.items()
    }
    elif isinstance(data, (list, tuple)):
    return type(data)(
    data_to_device(val, device) for val in data
    )
    elif isinstance(data, torch.Tensor):
    return data.to(device=device, non_blocking=True)
    else:
    return data.to(device=device)
    Training/Evaluation Loop


    Lastly, we implement a main function that performs training/evaluation on input sequences of varying length.

def main(
    block_fn,
    data_collate_fn=collate_with_padding,
    pad_idx=None,
    train=True,
    compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    data_set = FakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=data_collate_fn,
        num_workers=12,
        pin_memory=True,
        drop_last=True
    )

    model = MyDecoder(
        block_fn=block_fn,
        num_tokens=NUM_TOKENS,
        dim=DIM,
        num_heads=NUM_HEADS,
        num_layers=DEPTH,
        max_seq_len=MAX_SEQ_LEN,
        pad_idx=pad_idx
    ).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())

    def train_step(model, inputs, targets,
                   position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            outputs = outputs.view(-1, NUM_TOKENS)
            targets = targets.flatten()
            loss = criterion(outputs, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    @torch.no_grad()
    def eval_step(model, inputs, targets,
                  position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            if outputs.is_nested:
                outputs = outputs.data._values
                targets = targets.data._values
            else:
                outputs = outputs.view(-1, NUM_TOKENS)
                targets = targets.flatten()
            loss = criterion(outputs, targets)
        return loss

    if train:
        model.train()
        step_fn = train_step
    else:
        model.eval()
        step_fn = eval_step

    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device=device)
        step_fn(model, data['inputs'], data['targets'],
                position_ids=data.get('indices'),
                attn_mask=data.get('attn_mask'))

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break
    print(f'average step time: {summ / count}')

PyTorch SDPA with Padding


For our baseline experiments, we configure our Transformer block to utilize PyTorch's SDPA mechanism. In our experiments, we run both training and evaluation, with and without torch.compile. These experiments were run on an NVIDIA H100 with CUDA 12.4 and PyTorch 2.5.1.

from torch.nn.functional import scaled_dot_product_attention as sdpa
block_fn = functools.partial(MyAttentionBlock, attn_fn=sdpa)
causal_block_fn = functools.partial(
    MyAttentionBlock,
    attn_fn=functools.partial(sdpa, is_causal=True)
)

for mode in ['eval', 'train']:
    for compile in [False, True]:
        block_func = causal_block_fn \
            if mode == 'train' else block_fn
        print(f'{mode} with padding, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             pad_idx=PAD_ID,
             train=mode == 'train',
             compile=compile)

    Performance Results:

    • Evaluation: 132 milliseconds (ms) without torch.compile, 130 ms with torch.compile
    • Training: 342 ms without torch.compile, 299 ms with torch.compile
    Optimizing for Variable Length Input


    In this section, we will explore several optimization techniques for handling variable-length input sequences in Transformer models.

    Padding Optimization


    Our first optimization relates not to the attention kernel but to our padding mechanism. Rather than padding the sequences in each batch to a constant length, we pad to the length of the longest sequence in the batch. The following block of code consists of our revised collation function and updated experiments.

def collate_pad_to_longest(batch):
    padded_inputs = []
    padded_targets = []
    max_length = max([b[0].shape[0] for b in batch])
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], max_length, PAD_ID))
        padded_targets.append(pad_sequence(b[1], max_length, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }

for mode in ['eval', 'train']:
    for compile in [False, True]:
        block_func = causal_block_fn \
            if mode == 'train' else block_fn
        print(f'{mode} with padding to longest, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             data_collate_fn=collate_pad_to_longest,
             pad_idx=PAD_ID,
             train=mode == 'train',
             compile=compile)

    Padding to the longest sequence in each batch results in a slight performance acceleration:

    • Evaluation: 129 ms without torch.compile, 116 ms with torch.compile
    • Training: 337 ms without torch.compile, 294 ms with torch.compile
    SDPA with PyTorch NestedTensors


Next, we take advantage of the built-in support for PyTorch NestedTensors in SDPA in evaluation mode. Currently a prototype feature, PyTorch NestedTensors allow for grouping together tensors of varying length. These are sometimes referred to as jagged or ragged tensors. In the code block below, we define a collation function for grouping our sequences into NestedTensors. We also define an indices entry so that we can properly calculate the positional embeddings.

    PyTorch NestedTensors are supported by a limited number of PyTorch ops. Working around these limitations can require some creativity. For example, addition between NestedTensors is only supported when they share precisely the same “jagged” shape. In the code below we use a workaround to ensure that the indices entry shares the same shape as the model inputs.

def nested_tensor_collate(batch):
    inputs = torch.nested.as_nested_tensor([b[0] for b in batch],
                                           layout=torch.jagged)
    targets = torch.nested.as_nested_tensor([b[1] for b in batch],
                                            layout=torch.jagged)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])

    # workaround for creating a NestedTensor with identical "jagged" shape
    xx = torch.empty_like(inputs)
    xx.data._values[:] = indices

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': xx
    }

for compile in [False, True]:
    print(f'eval with nested tensors, '
          f'{"compiled" if compile else "uncompiled"}')
    main(
        block_fn=block_fn,
        data_collate_fn=nested_tensor_collate,
        train=False,
        compile=compile
    )

Without torch.compile, the NestedTensor optimization results in a step time of 131 ms, similar to our baseline result. In compiled mode, however, the step time drops to 42 ms, an impressive ~3x improvement.

    FlashAttention2


In our previous post we demonstrated the use of FlashAttention and its impact on the performance of a Transformer model. In this post we demonstrate the use of flash_attn_varlen_func from flash-attn (2.7.0), an API designed for variable-sized inputs. To use this function, we concatenate all of the sequences in the batch into a single sequence. We also create a cu_seqlens tensor that points to the indices within the concatenated tensor where each of the individual sequences starts. The code block below includes our collation function followed by evaluation and training experiments. Note that flash_attn_varlen_func does not support torch.compile (at the time of this writing).

def collate_concat(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = torch.tensor([b[0].shape[0] for b in batch])
    seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(seqlens, (1, 0))

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': cu_seqlens
    }

from flash_attn import flash_attn_varlen_func

fa_varlen = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN
).unsqueeze(0)

fa_varlen_causal = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN,
    causal=True
).unsqueeze(0)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=fa_varlen,
                             format='bshd')

causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=fa_varlen_causal,
                                    format='bshd')

print('flash-attn eval')
main(
    block_fn=block_fn,
    data_collate_fn=collate_concat,
    train=False
)

print('flash-attn train')
main(
    block_fn=causal_block_fn,
    data_collate_fn=collate_concat,
    train=True,
)

The impact of this optimization is dramatic: 51 ms for evaluation and 160 ms for training, amounting to 2.6x and 2.1x performance boosts compared to our baseline experiment.

    XFormers Memory Efficient Attention


In our previous post we demonstrated the use of the memory_efficient_attention operator from xFormers (0.0.28). Here we demonstrate the use of BlockDiagonalMask, specifically designed for input sequences of arbitrary length. The required collation function appears in the code block below, followed by the evaluation and training experiments. Note that torch.compile failed in training mode.

from xformers.ops import fmha
from xformers.ops import memory_efficient_attention as mea

def collate_xformer(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = [b[0].shape[0] for b in batch]
    batch_sizes = [1 for b in batch]
    block_diag = fmha.BlockDiagonalMask.from_seqlens(seqlens, device='cpu')
    block_diag._batch_sizes = batch_sizes

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': block_diag
    }

mea_eval = lambda q, k, v, attn_mask: mea(
    q, k, v, attn_bias=attn_mask)

mea_train = lambda q, k, v, attn_mask: mea(
    q, k, v, attn_bias=attn_mask.make_causal())

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea_eval,
                             format='bshd')

causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=mea_train,
                                    format='bshd')

print('xFormer Attention')
for compile in [False, True]:
    print(f'eval with xFormer Attention, '
          f'{"compiled" if compile else "uncompiled"}')
    main(block_fn=block_fn,
         train=False,
         data_collate_fn=collate_xformer,
         compile=compile)

print('train with xFormer Attention')
main(block_fn=causal_block_fn,
     train=True,
     data_collate_fn=collate_xformer)

The resultant step times were 50 ms and 159 ms for evaluation and training without torch.compile, respectively. Evaluation with torch.compile resulted in a step time of 42 ms.

    Results


    The table below summarizes the results of our optimization methods.


    Step time results for different optimization methods (lower is better) — by Author

The best performer for our toy model is xFormers' memory_efficient_attention, which delivered a ~3x performance boost for evaluation and a ~2x boost for training. We caution against deriving any conclusions from these results, as the performance impact of different attention functions can vary significantly depending on the specific model and use case.

    Optimizing a HuggingFace Model for Variable-Length Input


    The tools and techniques described above are easy to implement when creating a model from scratch. However, these days it is not uncommon for ML developers to adopt existing (pretrained) models and finetune them for their use case. While the optimizations we have described can be integrated without changing the set of model weights and without altering the model behavior, it is not entirely clear what the best way to do this is. In an ideal world, our ML framework would allow us to program the use of an attention mechanism that is optimized for variable-length inputs. In this section we demonstrate how to optimize HuggingFace models for variable-length inputs.

    A Toy HuggingFace Model - GPT2LMHeadModel


    To facilitate the discussion, we create a toy example in which we train a HuggingFace GPT2LMHead model on variable-length sequences. This requires adapting our random dataset and data-padding collation function according to HuggingFace's input specifications.

from transformers import GPT2Config, GPT2LMHeadModel

# Use random data
class HuggingFaceFakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,))
        input_ids = torch.randint(1, NUM_TOKENS, (length,))
        labels = input_ids.clone()
        labels[0] = PAD_ID  # ignore first token
        return {
            'input_ids': input_ids,
            'labels': labels
        }

def hf_collate_with_padding(batch):
    padded_inputs = []
    padded_labels = []
    for b in batch:
        input_ids = b['input_ids']
        labels = b['labels']
        padded_inputs.append(pad_sequence(input_ids, MAX_SEQ_LEN, PAD_ID))
        padded_labels.append(pad_sequence(labels, MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_labels = torch.stack(padded_labels, dim=0)
    return {
        'input_ids': padded_inputs,
        'labels': padded_labels,
        'attention_mask': (padded_inputs != PAD_ID)
    }

Training Function


    Our training function instantiates a GPT2LMHeadModel based on the requested GPT2Config and proceeds to train it on our variable-length sequences.

def hf_main(
    config,
    collate_fn=hf_collate_with_padding,
    compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    data_set = HuggingFaceFakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=12 if DEVICE == 'cuda' else 0,
        pin_memory=True,
        drop_last=True
    )

    model = GPT2LMHeadModel(config).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device=device)
        input_ids = data['input_ids']
        labels = data['labels']
        position_ids = data.get('position_ids')
        attn_mask = data.get('attention_mask')
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(input_ids=input_ids,
                            position_ids=position_ids,
                            attention_mask=attn_mask)
            logits = outputs.logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
            loss = criterion(logits.view(-1, NUM_TOKENS), labels.flatten())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break
    print(f'average step time: {summ / count}')

SDPA with Padding


In the code block below, we call our training function with the default sequence-padding collator.

config = GPT2Config(
    n_layer=DEPTH,
    n_embd=DIM,
    n_head=NUM_HEADS,
    vocab_size=NUM_TOKENS,
)

for compile in [False, True]:
    print(f"HF GPT2 train with SDPA, compile={compile}")
    hf_main(config=config, compile=compile)

    The resultant step times are 815 ms without torch.compile and 440 ms with torch.compile.

    FlashAttention2


We now take advantage of HuggingFace's built-in support for FlashAttention2 by setting the attn_implementation parameter to "flash_attention_2". Behind the scenes, HuggingFace will unpad the padded input data and then pass it to the optimized flash_attn_varlen_func function we saw above:

flash_config = GPT2Config(
    n_layer=DEPTH,
    n_embd=DIM,
    n_head=NUM_HEADS,
    vocab_size=NUM_TOKENS,
    attn_implementation='flash_attention_2'
)

print("HF GPT2 train with flash")
hf_main(config=flash_config)

The resultant step time is 620 ms, amounting to a 30% boost (in uncompiled mode) with just a simple flick of a switch.

    FlashAttention2 with Unpadded Input


Of course, padding the sequences in the collation function only to have them unpadded hardly seems sensible. In a recent update to HuggingFace, support was added for passing concatenated (unpadded) sequences to a select number of models. Unfortunately (as of the time of this writing), our GPT2 model did not make the cut. However, adding support requires just five small additions to modeling_gpt2.py in order to propagate the sequence position_ids to the flash-attention kernel. The full patch appears in the block below:

    @@ -370,0 +371 @@
    + position_ids = None
    @@ -444,0 +446 @@
    + position_ids=position_ids
    @@ -611,0 +614 @@
    + position_ids=None
    @@ -621,0 +625 @@
    + position_ids=position_ids
    @@ -1140,0 +1145 @@
    + position_ids=position_ids

We define a collate function that concatenates our sequences and train our HuggingFace model on unpadded sequences. (Also see the built-in DataCollatorWithFlattening utility, sketched below.)

def collate_flatten(batch):
    input_ids = torch.concat([b['input_ids'] for b in batch]).unsqueeze(0)
    labels = torch.concat([b['labels'] for b in batch]).unsqueeze(0)
    position_ids = [torch.arange(b['input_ids'].shape[0]) for b in batch]
    position_ids = torch.concat(position_ids)

    return {
        'input_ids': input_ids,
        'labels': labels,
        'position_ids': position_ids
    }

print("HF GPT2 train with flash, no padding")
hf_main(config=flash_config, collate_fn=collate_flatten)

    The resulting step time is 323 ms, 90% faster than running flash-attention on the padded input.
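
Alternatively, assuming a recent transformers release that includes it, the built-in DataCollatorWithFlattening utility can produce a similar flattened batch (input_ids, position_ids, and labels concatenated into a single sequence). The sketch below assumes its output keys line up with what hf_main expects; note that the utility marks sequence boundaries in the labels with -100 by default, whereas our manual collator relies on PAD_ID, so results may differ slightly.

from transformers import DataCollatorWithFlattening

# a sketch: flatten each batch into one concatenated sequence with
# matching position_ids, similarly to our manual collate_flatten above
hf_flatten_collator = DataCollatorWithFlattening()

print("HF GPT2 train with flash, DataCollatorWithFlattening")
hf_main(config=flash_config, collate_fn=hf_flatten_collator)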

    Results


    The results of our HuggingFace experiments are summarized below.


    Step time results for different optimization methods (lower is better) — by Author

    With little effort, we were able to boost our runtime performance by 2.5x when compared to the uncompiled baseline experiment, and by 36% when compared to the compiled version.

    In this section, we demonstrated how the HuggingFace APIs allow us to leverage the optimized kernels in FlashAttention2, significantly boosting the training performance of existing models on sequences of varying length.

    Summary


    As AI models continue to grow in both popularity and complexity, optimizing their performance has become essential for reducing runtime and costs. This is especially true for compute-intensive components like attention layers. In this post, we have continued our exploration of attention layer optimization, and demonstrated new tools and techniques for enhancing Transformer model performance. For more insights on AI model optimization, be sure to check out the first post in this series as well as our many other posts on this topic.


    Optimizing Transformer Models for Variable-Length Input Sequences was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.
     
