import dataclasses

 

import datasets

import torch

import torch.nn as nn

import tqdm

 

 

@dataclasses.dataclass
class BertConfig:
    """Configuration for BERT model.

    The defaults correspond to the BERT-base architecture
    (12 layers, hidden size 768, 12 attention heads).
    """

    vocab_size: int = 30522  # rows of the word-embedding table and width of the MLM output
    num_layers: int = 12  # number of stacked transformer blocks
    hidden_size: int = 768  # model width; nn.MultiheadAttention requires it divisible by num_heads
    num_heads: int = 12  # attention heads per transformer block
    dropout_prob: float = 0.1  # dropout probability for embeddings and transformer blocks
    pad_id: int = 0  # padding token id (used as padding_idx of the word embeddings)
    max_seq_len: int = 512  # size of the learned position-embedding table
    num_types: int = 2  # number of segment (token type) ids

 

 

 

class BertBlock(nn.Module):
    """One post-norm transformer encoder block in BERT.

    Args:
        hidden_size: width of the input/output features.
        num_heads: number of attention heads (must divide ``hidden_size``).
        dropout_prob: dropout probability for attention weights and for the
            outputs of both sub-layers before the residual add.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        """Apply self-attention and feed-forward sub-layers (post-norm).

        Args:
            x: input of shape (batch, seq_len, hidden_size) (``batch_first=True``).
            pad_mask: boolean (batch, seq_len) mask, True at positions that
                must not be attended to (passed as ``key_padding_mask``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # self-attention with padding mask and post-norm; the averaged attention
        # weights are not used, so don't ask MultiheadAttention to compute them
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask,
                                        need_weights=False)
        # MultiheadAttention's `dropout` only acts on the attention weights; apply
        # residual dropout to the sub-layer output as the feed-forward branch does
        x = self.attn_norm(x + self.dropout(attn_output))
        # feed-forward with GeLU activation and post-norm
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x

 

 

class BertPooler(nn.Module):
    """Pooler layer for BERT to process the [CLS] token output.

    Applies a dense ``hidden_size -> hidden_size`` projection followed by
    ``tanh``, so every output element lies in [-1, 1].
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(dense(x))``; ``x`` has ``hidden_size`` as its last dimension."""
        x = self.dense(x)
        x = self.activation(x)
        return x

 

 

class BertModel(nn.Module):
    """Backbone of BERT model: embeddings, transformer blocks and a [CLS] pooler."""

    def __init__(self, config: BertConfig):
        super().__init__()
        # embedding layers
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                            padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)
        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])
        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of token sequences.

        Args:
            input_ids: integer token ids of shape (batch, seq_len).
            token_type_ids: segment ids, broadcast-compatible with ``input_ids``.
            pad_id: token id treated as padding in the attention mask. The word
                embeddings use ``config.pad_id`` as ``padding_idx``, so the two
                should agree.

        Returns:
            ``(hidden_states, pooled_output)``: per-token hidden states of shape
            (batch, seq_len, hidden_size) and the pooled vector of the first
            (``[CLS]``) position, shape (batch, hidden_size).

        Raises:
            ValueError: if ``seq_len`` exceeds the position-embedding table size.
        """
        # create attention mask for padding tokens (True = ignore this key)
        pad_mask = input_ids == pad_id
        # convert integer tokens to embedding vectors; unpacking enforces a 2-D input
        _, seq_len = input_ids.shape
        max_len = self.position_embeddings.num_embeddings
        if seq_len > max_len:
            # fail with a clear message instead of an out-of-range embedding lookup
            # (an IndexError on CPU, a device-side assert on CUDA)
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {max_len}")
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(position_ids)
        type_embeddings = self.type_embeddings(token_type_ids)
        token_embeddings = self.word_embeddings(input_ids)
        x = token_embeddings + type_embeddings + position_embeddings
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)
        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)
        # pool the hidden state of the `[CLS]` token
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output

 

 

class BertPretrainingModel(nn.Module):
    """BERT backbone topped with masked-LM and next-sentence-prediction heads."""

    def __init__(self, config: BertConfig):
        super().__init__()
        width = config.hidden_size
        self.bert = BertModel(config)
        # MLM head: dense transform -> GELU -> LayerNorm -> projection to vocabulary logits
        mlm_layers = [
            nn.Linear(width, width),
            nn.GELU(),
            nn.LayerNorm(width),
            nn.Linear(width, config.vocab_size),
        ]
        self.mlm_head = nn.Sequential(*mlm_layers)
        # NSP head: two-way classifier over the pooled [CLS] vector
        self.nsp_head = nn.Linear(width, 2)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the backbone and both pretraining heads.

        Returns:
            ``(mlm_logits, nsp_logits)``: vocabulary logits for every token
            position and two-class logits from the pooled output.
        """
        hidden_states, cls_vector = self.bert(input_ids, token_type_ids, pad_id)
        return self.mlm_head(hidden_states), self.nsp_head(cls_vector)

 

 

# Training parameters
epochs = 10
# AdamW step size. 1e-4 is the peak LR used for BERT pretraining; the previous
# value 1e4 was a sign typo that makes training diverge on the first step.
learning_rate = 1e-4
batch_size = 32

 

# Load dataset and set up dataloader
# NOTE(review): each row is expected to carry the fields read by collate_fn
# (tokens, segment_ids, is_random_next, masked_positions, masked_labels);
# confirm against the preprocessing step that produced this parquet file.
dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")

 

def collate_fn(batch: list[dict]):
    """Custom collate function to handle variable-length sequences in dataset.

    ``tokens`` and ``segment_ids`` must already share one length across the
    batch (they are stacked into (batch, seq_len) tensors). The variable-length
    masking fields are flattened across the whole batch.

    Args:
        batch: example dicts with keys ``tokens``, ``segment_ids``,
            ``is_random_next``, ``masked_positions`` and ``masked_labels``.

    Returns:
        ``(input_ids, token_type_ids, is_random_next, masked_pos, masked_labels)``
        where the tensors are int64, ``masked_pos`` is a Python list of
        ``(example_index, position)`` pairs, and ``masked_labels`` is aligned
        with ``masked_pos``.
    """
    # always at max length: tokens, segment_ids; always singleton: is_random_next
    input_ids = torch.tensor([item["tokens"] for item in batch], dtype=torch.long)
    # NOTE(review): .abs() maps negative segment ids (e.g. -1) to positive ones;
    # presumably a padding marker in the source data -- confirm.
    token_type_ids = torch.tensor([item["segment_ids"] for item in batch],
                                  dtype=torch.long).abs()
    is_random_next = torch.tensor([item["is_random_next"] for item in batch],
                                  dtype=torch.long)
    # variable length: masked_positions, masked_labels
    masked_pos = [(idx, pos) for idx, item in enumerate(batch)
                  for pos in item["masked_positions"]]
    # explicit dtype: an empty label list would otherwise become a float32 tensor,
    # which CrossEntropyLoss rejects as class-index targets
    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]],
                                 dtype=torch.long)
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels

 

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                         collate_fn=collate_fn, num_workers=8)
# NOTE(review): with num_workers > 0 and the "spawn" start method (the default on
# Windows and macOS), DataLoader workers re-import this module and re-run its
# top-level statements; guarding the script body with `if __name__ == "__main__":`
# avoids that.

# train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertPretrainingModel(BertConfig()).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# StepLR(step_size=1) multiplies the LR by `gamma` on every scheduler.step(), so it
# is stepped once per epoch below. Stepping it per batch drove the LR to ~0 within
# the first few dozen batches.
# NOTE(review): gamma=0.1 still shrinks the LR 10x per epoch; consider a gentler
# gamma or a warmup + linear-decay schedule.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    # iterating over the tqdm wrapper advances the bar by one per batch on its own
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch in pbar:
        # get batched data
        input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch
        input_ids = input_ids.to(device)
        token_type_ids = token_type_ids.to(device)
        is_random_next = is_random_next.to(device)
        masked_labels = masked_labels.to(device)
        # extract output from model
        mlm_logits, nsp_logits = model(input_ids, token_type_ids)
        # Compute the loss for the NSP task
        nsp_loss = loss_fn(nsp_logits, is_random_next)
        # MLM loss: masked_pos is a list of (batch_index, position) pairs; gather the
        # matching rows of mlm_logits (B, S, V) into a (num_masked, V) tensor
        if masked_pos:
            batch_indices, token_positions = zip(*masked_pos)
            batch_indices = torch.tensor(batch_indices, device=device)
            token_positions = torch.tensor(token_positions, device=device)
            mlm_loss = loss_fn(mlm_logits[batch_indices, token_positions], masked_labels)
        else:
            # no masked tokens in this batch: zip(*[]) cannot be unpacked and a
            # mean cross-entropy over zero targets would be NaN
            mlm_loss = nsp_loss.new_zeros(())
        # backward with total loss
        total_loss = mlm_loss + nsp_loss
        pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    pbar.close()
    # decay the learning rate once per epoch
    scheduler.step()

# Save the model
torch.save(model.state_dict(), "bert_pretraining_model.pth")
torch.save(model.bert.state_dict(), "bert_model.pth")



# Source link


# Leave a Reply

# Your email address will not be published. Required fields are marked *