# A BERT for laptops, from scratch

This is a simple BERT lookalike that was developed for training on a laptop (with an NVIDIA RTX 3070 GPU). The notebook is developed for educational purposes more than for performance, but in a bit more than half a day of training you can get a model that (after further finetuning) obtains ~94% of the performance of the original BERT-base on the GLUE benchmark. The code here builds on work by Geiping & Goldstein [0], Izsak et al. [1] and Karpathy [2], who have all made large language models (LLMs) more accessible on modest budgets.

You can execute this notebook from start to end to see the full process of setting up and training a tokenizer, pretraining a BERT model, and finetuning a BERT model on downstream NLP tasks.
Most of the code from the notebook can also be found in this repository as regular Python files if you prefer, together with a few extra bits (e.g. [SpanBERT](https://arxiv.org/abs/1907.10529)-style sample generation).

The document is split into three sections:
* _Data:_ This is where we obtain and preprocess the data for pretraining, and build and train the BPE tokenizer.
* _Architecture:_ This is where we define our BERT.
* _Training:_ First we pretrain the model on a "masked language modeling" (MLM) objective with a lot of data, then we finetune on a few smaller tasks from the GLUE benchmark.

If you want to run the full notebook on a full-size model, expect training the tokenizer to take ~15 hours, pretraining with the MLM objective to take ~17 hours (on an RTX 3070; adjust expectations for your own system), and finetuning to take about an hour. The notebook was tested with 32GB of regular RAM and 8GB of GPU memory. If you have less, you might need to make some changes.

This BERT variant uses:
* BPE (Byte Pair Encoding) tokenization.
* Relative position embeddings.
* Pre-layernorm.
* No dropout.
* Automatic mixed precision.

[0] Geiping, Jonas, and Tom Goldstein. "Cramming: Training a Language Model on a Single GPU in One Day." _International Conference on Machine Learning_. PMLR, 2023.

[1] Izsak, Peter, Moshe Berchansky, and Omer Levy. "How to Train BERT with an Academic Budget." _arXiv preprint arXiv:2104.07705_ (2021).

[2] Karpathy, Andrej.
"[minGPT](https://github.com/karpathy/minGPT)" (2020).

```python
# Torch, CUDA, NumPy, SciPy, Matplotlib, etc. are all assumed to be present

# Other dependencies; uncomment this line if you don't have these installed:
# ! pip install datasets tqdm unidecode

import math
import os
import pickle
import random
import re
import string
import time

from collections import Counter
from multiprocessing import Pool

from matplotlib import pyplot as plt

import numpy as np
import scipy

import torch
import torch.nn as nn
from torch import optim
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.nn import functional as F

from datasets import load_dataset
from tqdm import tqdm
from unidecode import unidecode
```

## Data

We will be working with two datasets: [BookCorpus](https://huggingface.co/datasets/bookcorpus) and [English Wikipedia](https://huggingface.co/datasets/wikipedia), both of which are available from [Hugging Face](https://huggingface.co/). We are going to do a few things with the data:

1. Fetch the data.
2. Clean the data and collect unique words with their statistics.
3. Generate a byte pair encoding for the words in the data.
4. Chunk the training data into sequences of 128 tokens.
5. Generate samples for training BERT.

### Fetching

Fetch the datasets from Hugging Face. The first time you run this, it will download ~26GB of data, so plan accordingly. Later calls will use a local cache. The Wikipedia data is split into chunks of roughly similar size to support parallel processing later on.

```python
bc = load_dataset("bookcorpus", split="train")
wp_a = load_dataset("wikipedia", "20220301.en", split="train[0:750000]")
wp_b = load_dataset("wikipedia", "20220301.en", split="train[750000:1500000]")
wp_c = load_dataset("wikipedia", "20220301.en", split="train[1500000:3250000]")
wp_d = load_dataset("wikipedia", "20220301.en", split="train[3250000:5000000]")
wp_e = load_dataset("wikipedia", "20220301.en", split="train[5000000:]")
```

### Cleaning and collecting word frequencies

Define the cleaning logic:

```python
def clean_string(s):
    s = unidecode(s)              # Make sure we have only ASCII characters
    s = s.lower()                 # Lowercase
    s = re.sub('[ \t]+', ' ', s)  # Replace tabs and sequences of spaces with a single space
    s = s.replace('\n', '\\n')    # Escape newlines as a literal backslash followed by 'n'
    return s.strip()              # Remove leading and trailing whitespace
```

Preprocess the data.
The code in this section will generate some large files in the `data/` directory.

```python
def preprocess_dataset(d, tag):
    c = Counter()  # Will keep track of word counts

    # Save clean data to a local text file, one sample per line
    with open(f"data/{tag}.txt", "w") as f:
        for sample in d:
            s_clean = clean_string(sample['text'])
            f.write(s_clean + '\n')
            # Replace escaped newlines first, to avoid capturing their 'n's as words
            words = re.findall(r'[a-zA-Z]+', s_clean.replace('\\n', ' '))
            c.update(words)

    # Pickle counts
    with open(f"data/{tag}_counts.pkl", "wb") as f:
        pickle.dump(c, f)
```

The function above cleans the input data, stores the clean data, and counts word frequencies in the clean version of the data. These counts will be used to train the byte pair encoding later on.

We'll process the data in parallel to save quite a bit of time.
The cell below might still take half an hour or so to complete.

```python
with Pool(6) as p:
    p.starmap(preprocess_dataset, [(bc, "bookcorpus"),
                                   (wp_a, "wikipedia_a"),
                                   (wp_b, "wikipedia_b"),
                                   (wp_c, "wikipedia_c"),
                                   (wp_d, "wikipedia_d"),
                                   (wp_e, "wikipedia_e")])
```

The cell above generated counts per subset of the data. Let's merge them into a single data structure:

```python
word_counts = Counter()
for tag in ["bookcorpus", "wikipedia_a", "wikipedia_b", "wikipedia_c", "wikipedia_d", "wikipedia_e"]:
    with open(f"data/{tag}_counts.pkl", "rb") as f:
        word_counts.update(pickle.load(f))
```

Then let's have a quick look at common and rare words, as a sanity check:

```python
print("Most common:")
print(sorted(word_counts.items(), key=lambda x: -x[1])[:10])
print("Least common:")
print(sorted(word_counts.items(), key=lambda x: x[1])[:10])
```

```
Most common:
[('the', 238227230), ('of', 117781415), ('and', 104334764), ('in', 97520033), ('to', 78612948), ('a', 76470732), ('was', 44506401), ('he', 30383513), ('s', 29998804), ('for', 29685590)]
Least common:
[('mattinglly', 1), ('tennley', 1), ('thatnormally', 1), ('forsince', 1), ('convincingthe', 1), ('darkchocolate', 1), ('towardyour', 1), ('nfidence', 1), ('thepenthouse', 1), ('andunyielding', 1)]
```

The most common words generally seem to make sense. It also makes sense that the least common words are mainly misspellings and pairs of words that were run together (e.g. `thatnormally`).

### Byte Pair Encoding (BPE)

For a general description of BPE, check out [this Hugging Face tutorial](https://huggingface.co/learn/nlp-course/chapter6/5?fw=pt). At a high level, BPE is used to shorten the sequences we feed into the language model in a meaningful way. Training the model on raw characters would be expensive. By using BPE, we can represent common words by single unique tokens, and less common words by sequences of a few tokens, where each token hopefully carries some meaning of its own. We will briefly look at what the tokenizer does later on.

We start by defining the alphabet for our data. It consists of lowercase letters, digits, ASCII punctuation, the escaped newline `\n`, and a few special tokens. The `[CLS]` token is fed as the first token of every sequence and may aid finetuning. The `[MASK]` token is used during pretraining to indicate hidden tokens. The `[SEP]` token can be used to separate sequences in the input (this one becomes important during finetuning). The `[PAD]` token is used to pad all input sequences to a fixed number of tokens.

We also differentiate between characters that are at the beginning of a word and characters that are in the middle.
For example, the token `"_a"` corresponds to an "a" starting a word, while `"a"` corresponds to an "a" inside a word.

```python
alphabet = (["[CLS]", "[MASK]", "[SEP]", "[PAD]"] +
            [c for c in string.ascii_lowercase] +              # Characters inside a word
            [f"_{c}" for c in string.ascii_lowercase] +        # Characters starting a word
            [symbol for symbol in '0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'] +
            ["\\n"])                                           # Escaped newline, a single token
```

First we will define a class to take care of encoding sequences. When byte pair encoding a sequence, we first split it into alphabet elements and then apply an ordered list of merge rules, one after another.

```python
class BPEncoder:

    def __init__(self, alphabet, merge_rules, bpe_cache=None):
        """
        alphabet: a list of strings
        merge_rules: an ordered list of string pairs to be merged
        bpe_cache: optional dict mapping strings to their encodings
        """
        self.alphabet = alphabet
        self.merge_rules = merge_rules
        # Use a fresh dict per instance: a shared mutable default argument would
        # leak cached encodings between encoders with different merge rules.
        self.bpe_cache = {} if bpe_cache is None else bpe_cache

    def total_tokens(self):
        return len(self.alphabet) + len(self.merge_rules)

    def all_tokens(self):
        return self.alphabet + [a + b for a, b in self.merge_rules]

    def token_mapping(self):
        tokens = self.all_tokens()
        return {tok: i for i, tok in enumerate(tokens)}

    def add_merge_rule(self, merge_rule):
        self.merge_rules.append(merge_rule)

    def split_seq(self, s):
        """Split s into units from the alphabet, always taking the longest matching prefix."""
        ret = []
        while s:
            t = max((a for a in self.alphabet if s.startswith(a)), key=len)
            ret.append(t)
            s = s[len(t):]
        return ret

    def apply_merge_rule(self, merge_rule, bpe_seq):
        ret = []
        delta_dict = Counter()
        i = 0
        while i < len(bpe_seq) - 1:
            if merge_rule == (bpe_seq[i], bpe_seq[i+1]):
                ret.append(bpe_seq[i] + bpe_seq[i+1])

                # This part is a bit hairy and only really necessary for training the encoder (done further down).
                # It's essentially accounting logic to keep track of the occurrence of
                # sequential pairs: when we apply a merge rule, the merged pair disappears
                # everywhere in the sequence, but new pairs also appear. Keeping track of
                # that change this way is a bit more efficient than just re-counting all pairs.

                # Example:
                # We have the sequence [t1, t2, t3, t4] and the merge rule (t2, t3).
                # The pairs (t1, t2), (t2, t3) and (t3, t4) disappear,
                # and the pairs (t1, t2+t3) and (t2+t3, t4) appear.

                delta_dict.update({merge_rule: -1})
                if i > 0:
                    delta_dict.update({(ret[-2], bpe_seq[i]): -1})
                    delta_dict.update({(ret[-2], bpe_seq[i] + bpe_seq[i+1]): 1})
                if i < len(bpe_seq) - 2:
                    delta_dict.update({(bpe_seq[i+1], bpe_seq[i+2]): -1})
                    delta_dict.update({(bpe_seq[i] + bpe_seq[i+1], bpe_seq[i+2]): 1})

                i += 2
            else:
                ret.append(bpe_seq[i])
                i += 1
        if i == len(bpe_seq) - 1:
            ret.append(bpe_seq[i])
        return ret, delta_dict

    def encode(self, s):
        """
        Apply BPE to s.
        This implementation is very slow for encodings with many merge rules.
        In our case that doesn't matter much, since we cache encodings.
        Note: the cache is not invalidated when merge rules are added later.
        """
        if s in self.bpe_cache:
            return self.bpe_cache[s]
        ret = self.split_seq(s)
        for mr in self.merge_rules:
            ret, _ = self.apply_merge_rule(mr, ret)
        self.bpe_cache[s] = ret
        return ret
```

Let's go through an example to get a feel for this:

```python
demo_bpe = BPEncoder(alphabet, [("h", "e"), ("_t", "he"), ("o", "r"), ("f", "or")])

print("Split to alphabet elements:")
print(demo_bpe.split_seq("_therefore"))
print()
print("Apply BPE:")
print(demo_bpe.encode("_therefore"))
print()
tok2idx = demo_bpe.token_mapping()
print("Convert to numeric representation:")
print([tok2idx[tok] for tok in demo_bpe.encode("_therefore")])
```

```
Split to alphabet elements:
['_t', 'h', 'e', 'r', 'e', 'f', 'o', 'r', 'e']

Apply BPE:
['_the', 'r', 'e', 'for', 'e']

Convert to numeric representation:
[100, 21, 8, 102, 8]
```

Now that we have a BPE encoder, we need to learn the encoding from our data. This is what we collected the `word_counts` for earlier. We are only going to generate merge rules for alphabetic sequences: any special tokens, digits or symbols will be kept at the alphabet level.

The algorithm is roughly:
1. Split all words into alphabet elements.
2. Count the number of occurrences of all adjacent pairs.
3. Add a merge rule for the pair that occurs most often.
4. Apply the merge rule to all words.
5. If we have not reached our desired number of tokens, repeat starting at step 2.

The code has been set up to be a bit more efficient than a naive algorithm. Pair occurrences are not recounted on every iteration. Instead, we keep track of the changes resulting from each new merge rule (see the `apply_merge_rule` method in `BPEncoder`), and we only revisit words that contain both tokens of the new rule. However, this algorithm could still be sped up significantly. That would add complexity and is beyond the scope of this project.

```python
class BPETrainer:

    def __init__(self, word_counts, alphabet):
        """
        word_counts: a list of (word, count) pairs, e.g. list(some_counter.items())
        alphabet: a list of strings
        """
        self.bpe = BPEncoder(alphabet, [])

        self.data = []                # Encoded version of each word from our data, with its count
        self.pair_counts = Counter()  # Occurrence frequencies of token pairs
        # Maps each token to the (sorted) indices of the words in which it occurs
        self.token_word_index = {token: [] for token in self.bpe.all_tokens()}
        for i, (word, count) in enumerate(word_counts):
            word_enc = self.bpe.split_seq('_' + word)  # Prepend underscore to differentiate leading tokens
            self.data.append((word_enc, count))
            for j in range(len(word_enc) - 1):
                self.pair_counts.update({(word_enc[j], word_enc[j+1]): count})
            for tok in set(word_enc):
                self.token_word_index[tok].append(i)

    def add_merge_rule(self, t1, t2):
        """Adds the rule to merge t1 and t2 to the BPE and updates internal statistics."""

        # Add the new (merged) token to the word mapping
        self.token_word_index[t1 + t2] = []

        # The code below finds words that contain *both* t1 and t2 in a somewhat efficient way.
        # It relies on the fact that the lists in self.token_word_index are in sorted order.
        i = 0
        j = 0
        while i < len(self.token_word_index[t1]) and j < len(self.token_word_index[t2]):

            if self.token_word_index[t1][i] < self.token_word_index[t2][j]:
                i += 1

            elif self.token_word_index[t2][j] < self.token_word_index[t1][i]:
                j += 1

            else:
                # This word contains both t1 and t2: we might need to merge pairs here.

                word_idx = self.token_word_index[t1][i]
                word_enc, count = self.data[word_idx]

                # Get the encoded word after applying the merge rule, and the changes
                # that we need to make to our `pair_counts`.
                word_enc_post, delta = self.bpe.apply_merge_rule((t1, t2), word_enc)
                self.data[word_idx] = (word_enc_post, count)
                self.pair_counts.update({pair: d * count for pair, d in delta.items()})

                # Update the word index
                if t1 not in word_enc_post:
                    del self.token_word_index[t1][i]
                else:
                    i += 1
                if t2 not in word_enc_post:
                    if t2 != t1:
                        del self.token_word_index[t2][j]
                else:
                    j += 1
                if t1 + t2 in word_enc_post:
                    self.token_word_index[t1 + t2].append(word_idx)

        # Update the BPE to include the new merge rule
        self.bpe.add_merge_rule((t1, t2))

    def find_merge_rules(self, token_limit, verbose=False):
        """Add merge rules to the BPE until token_limit is reached."""

        while self.bpe.total_tokens() < token_limit:

            # Find the most frequent pair.
            # This call could be sped up with a better data structure.
            t1, t2 = max(self.pair_counts, key=self.pair_counts.get, default=(None, None))
            count = self.pair_counts.get((t1, t2), 0)

            if count == 0:
                print("No more tokens to add, every word has its own token already.")
                break

            if verbose:
                print(f"{self.bpe.total_tokens()}: {t1} + {t2} -> {t1}{t2} (count = {count})")

            # Add the most frequent pair as a merge rule.
            self.add_merge_rule(t1, t2)
```

For the purpose of this notebook, let's run this process only on a random subset of 2^13 words that occur at least 256 times, and only up to a relatively small vocabulary of 2^13 tokens. Running the process on the full `word_counts` dictionary up to the full token count (2^15 tokens) takes about 15 hours. We will also print some of the words on which the small tokenizer is trained and show how they are broken up into tokens:

```python
word_counts_small = random.sample([wc for wc in word_counts.items() if wc[1] >= 256], 2**13)

bpet_small = BPETrainer(word_counts_small, alphabet)
bpet_small.find_merge_rules(2**13)

for word, count in word_counts_small[:10]:
    print(f"{word} ({count} occ): {bpet_small.bpe.encode('_' + word)}")
```

```
gladding (282 occ): ['_gl', 'ad', 'ding']
memoire (3863 occ): ['_memoire']
camisa (380 occ): ['_cam', 'isa']
democratico (771 occ): ['_dem', 'oc', 'rat', 'ico']
templars (4578 occ): ['_templars']
overconfident (1118 occ): ['_overconfid', 'ent']
agenesis (598 occ): ['_ag', 'enes', 'is']
bernese (2982 occ): ['_bernese']
bonnies (655 occ): ['_bon', 'nies']
shax (405 occ): ['_sh', 'ax']
```
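As a cross-check on the five-step algorithm described earlier, here is a minimal sketch of the naive version that `BPETrainer` avoids. It recounts every adjacent pair after each merge instead of tracking deltas. This is a hypothetical reference and not part of the notebook's pipeline: the helpers `split_word`, `merge_pair` and `learn_merges` and the toy word counts are made up for illustration. Ties between equally frequent pairs go to whichever pair was counted first, which need not match `BPETrainer`.

```python
from collections import Counter


def split_word(word):
    """Step 1: split a non-empty word into characters, marking the first one
    with '_' (the same leading-underscore convention as the notebook's alphabet)."""
    return ["_" + word[0]] + list(word[1:])


def merge_pair(tokens, pair):
    """Replace non-overlapping occurrences of `pair`, scanning left to right."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def learn_merges(word_counts, num_merges):
    """Naive BPE training: recount all adjacent pairs after every merge.

    word_counts: dict mapping word -> number of occurrences
    Returns the ordered merge rules and the final encoding of each word.
    """
    encoded = {word: split_word(word) for word in word_counts}
    merges = []
    for _ in range(num_merges):  # Step 5: repeat until enough merges
        # Step 2: count adjacent pairs, weighted by word frequency
        pair_counts = Counter()
        for word, tokens in encoded.items():
            for pair in zip(tokens, tokens[1:]):
                pair_counts[pair] += word_counts[word]
        if not pair_counts:
            break  # every word is already a single token
        # Step 3: the most frequent pair becomes the next merge rule
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Step 4: apply the new rule to every word
        encoded = {word: merge_pair(tokens, best) for word, tokens in encoded.items()}
    return merges, encoded


merges, encoded = learn_merges({"the": 100, "then": 10, "zebra": 1}, num_merges=3)
print(merges)   # [('_t', 'h'), ('_th', 'e'), ('_the', 'n')]
print(encoded)  # {'the': ['_the'], 'then': ['_then'], 'zebra': ['_z', 'e', 'b', 'r', 'a']}
```

With three merges on this toy input, the frequent words `the` and `then` each become a single token, while the rare `zebra` stays split into alphabet elements. Running both trainers on a small word list and comparing the encodings is one way to test the incremental bookkeeping.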
As you can see, common words are likely to be represented by a single token. Less common words are broken into chunks of common sequences, and these chunks often (though definitely not always) carry some meaning of their own.

To run the process on the full training data, up to the full token count, execute this cell:

```python
# This code generates the data loaded below

# BPETrainer expects a list of (word, count) pairs, not a Counter
bpet = BPETrainer(list(word_counts.items()), alphabet)
bpet.find_merge_rules(2**15, verbose=True)

with open("data/bert_mr.pkl", "wb") as f:
    pickle.dump(bpet.bpe.merge_rules, f)

with open("data/bert_bpet_data.pkl", "wb") as f:
    pickle.dump(bpet.data, f)
```

For reruns, it's easier to just load the saved results from earlier:

```python
with open("data/bert_mr.pkl", "rb") as f:
    bert_mr = pickle.load(f)

with open("data/bert_bpet_data.pkl", "rb") as f:
    bert_bpet_data = pickle.load(f)
```

A nice side effect of structuring the training process this way is that for every single word in our pretraining data, we get the encoded form as part of the BPE training
process. That means that rather than re-encoding a word every time we encounter it, we can use a lookup table to find its encoding, which is a lot faster (at least for this poorly optimized encoder):

```python
# Mapping from each word (including its leading underscore) to its encoding
bpe_cache = {''.join(w_enc): w_enc for w_enc, _ in bert_bpet_data}

bert_bpe = BPEncoder(alphabet, bert_mr, bpe_cache)
```

### Chunking

The cleaned data that we generated earlier did not take into account that we want to train on fixed sample lengths: during pretraining, all our samples will be 128 tokens long. For this purpose we will apply one more preprocessing step, generating a dataset where each line of data has a suitable length given our trained BPE encoding.

We will need to do this slightly differently for the BookCorpus data and the Wikipedia data. In the BookCorpus data, each sample represents a sentence from a book, and the next sample is the following sentence. For our training process, we will merge such samples to try and fill up the 128-token frames. The Wikipedia data, on the other hand, has one sample per article, and we want to avoid combining unrelated sentences on different topics.
For the Wikipedia data, once we reach the end of an article, we will fill the rest of the current training sample with `[PAD]` tokens.

```python
def atomize(s):
    """Break down a sample into symbols and words."""
    atom_re = r'(\[(CLS|SEP|PAD|MASK)\]|[a-z]+|\\n|[0-9]|\\|[!#$%&\'()*+,-./:;<=>?@[\]^_`{|}~])'
    return [m[0] for m in re.findall(atom_re, s)]

def chunks(fname, bpe, max_length, merge_lines=False):
    """
    Yield chunks of at most max_length tokens from fname, as space-separated atoms.
    If merge_lines is False, every input line also ends the current chunk.
    """
    ret_list = []
    ret_tok_len = 0

    with open(fname, "r") as f:
        for line in f:
            for atom in atomize(line):
                if atom.isalpha():
                    n_tok = len(bpe.encode('_' + atom))
                    # Deal with some weird sequences in the training data
                    if n_tok > max_length:
                        continue
                    if ret_tok_len + n_tok > max_length:
                        yield ' '.join(ret_list)
                        ret_list = []
                        ret_tok_len = 0
                    ret_list.append(atom)
                    ret_tok_len += n_tok
                else:
                    # Symbols, digits and special tokens are single tokens
                    if ret_tok_len == max_length:
                        yield ' '.join(ret_list)
                        ret_list = []
                        ret_tok_len = 0
                    ret_list.append(atom)
                    ret_tok_len += 1
            if not merge_lines:
                yield ' '.join(ret_list)
                ret_list = []
                ret_tok_len = 0
    yield ' '.join(ret_list)
```

Let's see what a sample looks like.
We will take `max_length` to be 126, because during training we will still prepend a `[CLS]` token and append a `[SEP]` token.

```python
test_chunk = chunks("data/wikipedia_d.txt", bert_bpe, 126, merge_lines=True)
print(next(test_chunk))
```

```
the camelen iv 4 4 0 is a four wheel drive modular mission system vehicle designed by jez hermer mbe , ceo of ovik special vehicles . designed and developed in 2 0 1 0 , it is based upon the iveco daily 4 x 4 chassis but incorporates a number of modifications designed by ovik plus a range of specialist mission modules which can be interchanged rapidly , giving the vehicle a multi - functional utility . \n \n concept of use \n the general concept behind the cameleom system is to provide military forces , civil and emergency services and commercial users with a modular vehicle which can be reconfigured , rapidly , into
```

The code below generates a chunked version of all the training data. Again we will parallelize the process. It will take up quite a bit of disk space and take approximately half an hour.

```python
def chunk_samples(fname, merge_lines):
    cs = chunks("data/" + fname, bert_bpe, 126, merge_lines)
    with open(f"data/chunked_{fname}", "w") as f:
        for c in cs:
            f.write(c + '\n')

with Pool(6) as p:
    p.starmap(chunk_samples, [("bookcorpus.txt", True),
                              ("wikipedia_a.txt", False),
                              ("wikipedia_b.txt", False),
                              ("wikipedia_c.txt", False),
                              ("wikipedia_d.txt", False),
                              ("wikipedia_e.txt", False)])
```

We will now merge all the data into one file and shuffle it. First we shuffle the individual files, then we randomly merge the resulting files. The reason for this roundabout procedure is that the combined file is too large to fit in memory (at least on my machine). During the merge, we pick each file with a probability proportional to its size, so that shorter files are not over-represented in the earlier parts of the training data. The files we generated vary a bit in size and also in content, so that could otherwise cause problems.

```python
tags = ["bookcorpus", "wikipedia_a", "wikipedia_b", "wikipedia_c", "wikipedia_d", "wikipedia_e"]
```

```python
for tag in tags:
    with open(f"data/chunked_{tag}.txt", "r") as f:
        lines = f.readlines()
    random.shuffle(lines)
    with open(f"data/shuffled_chunked_{tag}.txt", "w") as f:
        f.writelines(lines)
```

```python
read_handles = [open(f"data/shuffled_chunked_{tag}.txt", "r") for tag in tags]
sizes = [os.stat(f"data/shuffled_chunked_{tag}.txt").st_size for tag in tags]
names = list(tags)

with open("data/pretrain.txt", "w") as f:
    while len(read_handles) > 0:
        # Pick a file at random, with probability proportional to its size
        i = random.choices(range(len(read_handles)), weights=sizes)[0]
        try:
            line = next(read_handles[i])
            if line != '\n':  # A few cases of empty lines show up
                f.write(line)
        except StopIteration:
            read_handles[i].close()
            print(f"Done with {names[i]}")
            del read_handles[i], sizes[i], names[i]
```

```
Done with wikipedia_a
Done with bookcorpus
Done with wikipedia_b
Done with wikipedia_c
Done with wikipedia_d
Done with wikipedia_e
```

### Generating training samples

We now have input data of the right size in a workable format. The next step is to generate the actual training samples for pretraining BERT. Different people have used different approaches for this. We will go with a relatively simple one here. We are going to train a BERT on 128 tokens at a time, and during pretraining we will mask 15% of the input tokens and score the model on how well it manages to predict the missing tokens.

To give some intuition for what we are trying to do, an input sample might look something like

`['i', 'went', 'to', 'the', '[MASK]', 'for', 'lunch']`

for which the model would need to predict

`[ - , - , - , - , 'cafeteria', - , - ]`

The common case will be to leave out a token and replace it with `[MASK]` (80% of the time). However, in 10% of cases we will use a random token instead of `[MASK]`, and in 10% of cases we won't perform a replacement.
This is similar to how training was done in the original BERT paper.””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 21,”,” “id”: “185a4460-dc29-4e2d-be9b-d095d3481dd0″,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “def encode_sample(sample, bpe):\n”,”,” ” encoded = []\n”,”,” ” for item in sample.strip().split(‘ ‘):\n”,”,” ” if item.isalpha():\n”,”,” ” encoded += [tok for tok in bpe.encode(‘_’ + item)]\n”,”,” ” else:\n”,”,” ” encoded.append(item)\n”,”,” ” return encoded\n”,”,” “\n”,”,” “def samples_and_masks(fname, length, bpe):\n”,”,” ” \”\”\”Assumes that all samples in fname have been sized not to exceed `length`.\”\”\”\n”,”,” ” tok2idx = bpe.token_mapping()\n”,”,” ” \n”,”,” ” with open(fname, \”r\”) as f:\n”,”,” ” for sample in f:\n”,”,” ” \n”,”,” ” # Apply BPE to the sample\n”,”,” ” encoded = [tok2idx[e] for e in encode_sample(sample, bpe)]\n”,”,” ” total_tokens = len(encoded)\n”,”,” ” \n”,”,” ” # Generate mask in the shape of the sample\n”,”,” ” mask_count = math.ceil(0.15 * total_tokens)\n”,”,” ” mask = [1] * mask_count + [0] * (total_tokens – mask_count)\n”,”,” ” random.shuffle(mask)\n”,”,” ” \n”,”,” ” # Generate ground truth and mask in matching shape\n”,”,” ” training_output = [tok2idx[\”[CLS]\”]] + encoded + [tok2idx[\”[SEP]\”]] + [tok2idx[\”[PAD]\”]] * (length – total_tokens – 2)\n”,”,” ” training_mask = [0] + mask + [0] + [0] * (length – total_tokens – 2)\n”,”,” ” \n”,”,” ” # Generate input data\n”,”,” ” training_input = [t for t in training_output]\n”,”,” ” for i in range(length):\n”,”,” ” if training_mask[i] == 1: # Mask this token\n”,”,” ” r = random.random()\n”,”,” ” if r < 0.8: # Regular masking\n","," " training_input[i] = tok2idx[\"[MASK]\"]\n","," " elif r < 0.9: # Random other token instead of [MASK]\n","," " training_input[i] = random.randrange(bpe.total_tokens())\n","," " else: # Feed the original token as input, untouched\n","," " pass\n","," " \n","," " yield [training_input, training_output, 
training_mask]""," ]"," },"," {"," "cell_type": "markdown","," "id": "b7b77786-d270-49b7-aed2-bd228ff2ce31","," "metadata": {},"," "source": ["," "This time we will not write training samples to disk; instead we generate them on the fly during training. This makes it easier to apply different transformations (e.g. a SpanBERT objective) or to re-randomize for subsequent epochs.""," ]"," },"," {"," "cell_type": "markdown","," "id": "70fae2c2-8e22-43b5-8602-ea4db87665ed","," "metadata": {},"," "source": ["," "## Architecture\n","," "\n","," "We are now at a point where we can define the BERT model. We will pretty closely follow the architecture described in the [Cramming paper](https://arxiv.org/abs/2212.14034). A lot of the structure of the below code is borrowed from Andrej Karpathy's [minGPT](https://github.com/karpathy/minGPT).\n","," "\n","," "The main differences from the original BERT implementation are the following:\n","," "* Pre-LayerNorm (marked in the code below).\n","," "* No dropout.\n","," "* No biases in transformer attention, in transformer MLPs, or in the decoder of the model.\n","," "* An additional LayerNorm right after the embedding layer.\n","," "\n","," "The main difference from the Cramming paper is that the architecture here uses relative position embeddings as introduced by [Shaw et al.](https://arxiv.org/abs/1803.02155). The implementation here only adds position representations to keys, not to values. Compared to the original absolute position embeddings, this approach slows down training by ~4% but it makes up for it in improved model performance.\n","," "\n","," "This is not the time and place to explain transformers in detail. For that, if you have a lot of time, go through [Stanford's CS 224N](https://web.stanford.edu/class/cs224n/). 
If you don't have a lot of time, Andrej Karpathy gives [a good high-level overview](https://www.youtube.com/watch?v=kCc8FmEb1nY).\n","," "\n","," "The first two components define the transformer block:""," ]"," },"," {"," "cell_type": "code","," "execution_count": 22,"," "id": "f687cbb2-8c7b-48c2-91db-ca7773ec0569","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "class SelfAttention(nn.Module):\n","," " \"\"\"\n","," " Bi-directional transformer self-attention.\n","," " Uses relative position embeddings, shared across tokens and attention heads, but unique for each layer.\n","," " \"\"\"\n","," "\n","," " def __init__(self, config):\n","," " super().__init__()\n","," " \n","," " self.config = config\n","," " embed_size = config[\"embed_size\"]\n","," " n_head = config[\"n_head\"]\n","," " assert embed_size % n_head == 0\n","," " \n","," " # This is clipping distance (k) in Shaw et al\n","," " pos_emb_radius = config[\"pos_emb_radius\"]\n","," " pos_emb_units = config[\"embed_size\"] // config[\"n_head\"]\n","," " \n","," " # Position embedding vectors for use on keys\n","," " # This is w^K in Shaw et al\n","," " self.pos_emb_k = nn.Parameter(torch.zeros(2 * pos_emb_radius, pos_emb_units))\n","," " torch.nn.init.normal_(self.pos_emb_k, mean=0.0, std=0.02)\n","," " \n","," " # key, query, value projections for all heads\n","," " self.key = nn.Linear(embed_size, embed_size, bias = False)\n","," " self.query = nn.Linear(embed_size, embed_size, bias = False)\n","," " self.value = nn.Linear(embed_size, embed_size, bias = False)\n","," " \n","," " # output projection\n","," " self.proj = nn.Linear(embed_size, embed_size, bias = False)\n","," "\n","," " def forward(self, x):\n","," " batch_size, context_size, embed_size = x.size()\n","," " assert embed_size == self.config[\"embed_size\"]\n","," " \n","," " n_head = self.config[\"n_head\"]\n","," " head_size = embed_size // n_head\n","," " \n","," " pos_emb_size, head_size = 
self.pos_emb_k.size()\n","," " pos_emb_radius = self.config[\"pos_emb_radius\"]\n","," " assert pos_emb_size == 2 * pos_emb_radius\n","," "\n","," " # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n","," " k = self.key(x).view(batch_size, context_size, n_head, head_size).transpose(1, 2) # (batch_size, n_head, context_size, head_size)\n","," " q = self.query(x).view(batch_size, context_size, n_head, head_size).transpose(1, 2) # (batch_size, n_head, context_size, head_size)\n","," " v = self.value(x).view(batch_size, context_size, n_head, head_size).transpose(1, 2) # (batch_size, n_head, context_size, head_size)\n","," " \n","," " # Below section implements x_i W^Q (a_{ij}^K)^T from Shaw et al\n","," " # position attention: (batch_size, n_head, context_size, head_size) x (1, 1, pos_emb_size, head_size) -> (batch_size, n_head, context_size, pos_emb_size)\n”,”,” ” att_rel_pos = q @ self.pos_emb_k.view(1, 1, pos_emb_size, head_size).transpose(-2, -1)\n”,”,” ” att_idxs = (torch.clamp(torch.arange(context_size)[None, :] – torch.arange(context_size)[:, None], -pos_emb_radius, pos_emb_radius-1) % pos_emb_size).to(\”cuda\”)\n”,”,” ” att_pos = torch.gather(att_rel_pos, 3, att_idxs.expand((batch_size, n_head, context_size, context_size)))\n”,”,” ” assert att_pos.shape == (batch_size, n_head, context_size, context_size)\n”,”,” ” \n”,”,” ” # value attention: (batch_size, n_head, context_size, head_size) x (batch_size, n_head, context_size, head_size) -> (batch_size, n_head, context_size, context_size)\n”,”,” ” att_val = q @ k.transpose(-2, -1)\n”,”,” ” \n”,”,” ” # combined attention\n”,”,” ” att_scale = 1.0 / math.sqrt(k.size(-1))\n”,”,” ” att = F.softmax((att_val + att_pos) * att_scale, dim=-1) # Equation (5) from Shaw et al\n”,”,” ” \n”,”,” ” y = att @ v # (batch_size, n_head, context_size, context_size) x (batch_size, n_head, context_size, head_size) -> (batch_size, n_head, context_size, head_size)\n”,”,” ” y = y.transpose(1, 
2).contiguous().view(batch_size, context_size, embed_size) # re-assemble all head outputs side by side\n”,”,” “\n”,”,” ” # output projection\n”,”,” ” y = self.proj(y)\n”,”,” ” return y””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “f2a02333-c9ac-406e-85c9-ac52d018196e”,”,” “metadata”: {},”,” “source”: [“,” “The relative position embeddings make the attention code a bit more complex than it could be (minGPT, for example, has simpler attention logic), and they slow down training by about 4%. They seem to make up for this with improved model performance, however, and they provide a lot of flexibility for training on longer samples (more than 128 tokens) after pretraining. With absolute position embeddings it’s not straightforward to change the context size.””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 23,”,” “id”: “41a4d5a1-ec0f-45a3-84c7-5919e554125e”,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “class Block(nn.Module):\n”,”,” ” \”\”\”Pre-LayerNorm transformer block.\”\”\”\n”,”,” “\n”,”,” ” def __init__(self, config):\n”,”,” ” super().__init__()\n”,”,” ” \n”,”,” ” embed_size = config[\”embed_size\”]\n”,”,” ” \n”,”,” ” self.norm1 = nn.LayerNorm(embed_size, eps = 1e-6)\n”,”,” ” self.attn = SelfAttention(config)\n”,”,” ” \n”,”,” ” self.norm2 = nn.LayerNorm(embed_size, eps = 1e-6)\n”,”,” ” self.mlp = nn.Sequential(\n”,”,” ” nn.Linear(embed_size, 4 * embed_size, bias = False),\n”,”,” ” nn.GELU(),\n”,”,” ” nn.Linear(4 * embed_size, embed_size, bias = False),\n”,”,” ” )\n”,”,” “\n”,”,” ” def forward(self, x):\n”,”,” ” # This is Pre-LayerNorm\n”,”,” ” x = x + self.attn(self.norm1(x))\n”,”,” ” x = x + self.mlp(self.norm2(x))\n”,”,” ” \n”,”,” ” # Post-LayerNorm would look more like\n”,”,” ” # x = self.norm1(x + self.attn(x))\n”,”,” ” # x = self.norm2(x + self.mlp(x))\n”,”,” ” \n”,”,” ” return x””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “c35d2a7a-8914-4739-8c55-96c2e0705478″,”,” “metadata”: {},”,” “source”: [“,” “We then build BERT out of an 
embedding layer and a sequence of transformer blocks, with some carefully placed layernorms.””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 24,”,” “id”: “8b3eed86-67eb-49db-b76f-0945a469e1ea”,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “class BERT(nn.Module):\n”,”,” ” \”\”\”Headless BERT.\”\”\”\n”,”,” “\n”,”,” ” def __init__(self, config):\n”,”,” ” super().__init__()\n”,”,” ” \n”,”,” ” self.config = config\n”,”,” ” vocab_size = config[\”vocab_size\”]\n”,”,” ” embed_size = config[\”embed_size\”]\n”,”,” ” n_layer = config[\”n_layer\”]\n”,”,” “\n”,”,” ” # token embedding\n”,”,” ” self.tok_emb = nn.Embedding(vocab_size, embed_size)\n”,”,” ” self.norm_emb = nn.LayerNorm(embed_size, eps = 1e-6)\n”,”,” ” \n”,”,” ” # transformer\n”,”,” ” self.transformer = nn.Sequential(*[Block(config) for _ in range(n_layer)])\n”,”,” ” \n”,”,” ” # final layernorm\n”,”,” ” self.norm_final = nn.LayerNorm(embed_size, eps = 1e-6)\n”,”,” “\n”,”,” ” print(\”number of parameters: {}\”.format(sum(p.numel() for p in self.parameters())))\n”,”,” ” \n”,”,” ” self.apply(self._init_weights)\n”,”,” ” for pn, p in self.named_parameters():\n”,”,” ” if pn.endswith(‘proj.weight’):\n”,”,” ” torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * n_layer))\n”,”,” “\n”,”,” ” def _init_weights(self, module):\n”,”,” ” if isinstance(module, nn.Linear):\n”,”,” ” torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n”,”,” ” if module.bias is not None:\n”,”,” ” torch.nn.init.zeros_(module.bias)\n”,”,” ” elif isinstance(module, nn.Embedding):\n”,”,” ” torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n”,”,” ” elif isinstance(module, nn.LayerNorm):\n”,”,” ” torch.nn.init.zeros_(module.bias)\n”,”,” ” torch.nn.init.ones_(module.weight)\n”,”,” “\n”,”,” ” def forward(self, x):\n”,”,” ” batch_size, context_size = x.size()\n”,”,” ” \n”,”,” ” x = self.tok_emb(x)\n”,”,” ” x = self.norm_emb(x)\n”,”,” ” x = self.transformer(x)\n”,”,” ” x = self.norm_final(x)\n”,”,” ” 
\n”,”,” ” return x””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “17ed9e50-3a49-4f5c-8bf5-85f0e7984141″,”,” “metadata”: {},”,” “source”: [“,” “We have BERT, but we can’t do anything with it yet: the top layer generates only hidden unit activations/embeddings. Depending on the training scenario, we will put a different \”head\” on BERT. During pre-training, BERT is learning masked language modeling (MLM), and it will take an MLM head that makes a token prediction for every position in the input. For our fine-tuning, BERT learns a single output for each input sentence (e.g. a true/false classification or some score), and we will use two different heads, one for classification and one for regression. These heads ignore most of the activations at BERT’s top, and use only the activations for the very first token as input features during training.””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 25,”,” “id”: “4d5f8222-3422-43ea-9292-675da6ab427d”,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “class MLMHead(nn.Module):\n”,”,” ” \”\”\”\n”,”,” ” BERT head for masked language modeling.\n”,”,” ” Note that this does *not* implement sparse prediction as mentioned in the Cramming paper. 
Predictions are calculated for all tokens.\n”,”,” ” \”\”\”\n”,”,” ” \n”,”,” ” def __init__(self, config):\n”,”,” ” super().__init__()\n”,”,” ” \n”,”,” ” self.config = config\n”,”,” ” vocab_size = config[\”vocab_size\”]\n”,”,” ” embed_size = config[\”embed_size\”]\n”,”,” ” \n”,”,” ” self.tok_unemb = nn.Linear(embed_size, vocab_size, bias = False)\n”,”,” ” \n”,”,” ” def forward(self, x, y):\n”,”,” ” logits = self.tok_unemb(x)\n”,”,” ” loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index = 0)\n”,”,” ” return logits, loss””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 26,”,” “id”: “7698a79c-764b-4f56-8cad-55bb20ac7f77″,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “class CLSHead(nn.Module):\n”,”,” ” \”\”\”\n”,”,” ” BERT head for classification.\n”,”,” ” A prediction is only calculated for the first ([CLS]) token.\n”,”,” ” \”\”\”\n”,”,” ” \n”,”,” ” def __init__(self, config, n_classes):\n”,”,” ” super().__init__()\n”,”,” ” \n”,”,” ” self.config = config\n”,”,” ” embed_size = config[\”embed_size\”]\n”,”,” ” \n”,”,” ” self.classifier = nn.Linear(embed_size, n_classes)\n”,”,” ” \n”,”,” ” def forward(self, x, y = None):\n”,”,” ” logits = self.classifier(x[:, 0, :])\n”,”,” ” loss = None\n”,”,” ” if y is not None:\n”,”,” ” loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n”,”,” ” return logits, loss””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 27,”,” “id”: “c35e1242-197b-42bf-8ef8-c421823c92b7″,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “class RegHead(nn.Module):\n”,”,” ” \”\”\”\n”,”,” ” BERT head for regression.\n”,”,” ” A prediction is only calculated for the first ([CLS]) token.\n”,”,” ” \”\”\”\n”,”,” ” \n”,”,” ” def __init__(self, config):\n”,”,” ” super().__init__()\n”,”,” ” \n”,”,” ” self.config = config\n”,”,” ” embed_size = config[\”embed_size\”]\n”,”,” ” \n”,”,” ” self.regressor = nn.Linear(embed_size, 1)\n”,”,” ” self.loss_fn 
= nn.MSELoss()\n”,”,” ” \n”,”,” ” def forward(self, x, y = None):\n”,”,” ” y_hat = self.regressor(x[:, 0, :])\n”,”,” ” loss = None\n”,”,” ” if y is not None:\n”,”,” ” loss = self.loss_fn(y_hat.view(-1), y.view(-1))\n”,”,” ” return y_hat, loss””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “977de2bb-ec76-4f87-9064-079c34267cb3″,”,” “metadata”: {“,” “tags”: []”,” },”,” “source”: [“,” “## Training\n”,”,” “\n”,”,” “We have our data and we have our model. Time to set up training loops.\n”,”,” “\n”,”,” “We will train in two phases:\n”,”,” “1. Unsupervised pre-training on the data we’re generating with the `samples_and_masks` function we defined earlier.\n”,”,” “2. Supervised fine-tuning on specific tasks where we take a model from (1) and specialize it further.\n”,”,” “\n”,”,” “### Pre-training””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “a6d4a4f3-45a0-4906-805e-a3b287660213″,”,” “metadata”: {},”,” “source”: [“,” “Configure the model:””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 28,”,” “id”: “30a264fd-19ef-4968-ac38-46c5de7d7a4a”,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [“,” {“,” “name”: “stdout”,”,” “output_type”: “stream”,”,” “text”: [“,” “number of parameters: 110164992\n””,” ]”,” }”,” ],”,” “source”: [“,” “# BERT base, but with relative position embeddings\n”,”,” “config = {\”vocab_size\”: 2**15, \”embed_size\”: 768, \”context_size\”: 128, \”n_layer\”: 12, \”n_head\”: 12, \”pos_emb_radius\”: 16}\n”,”,” “device = \”cuda\”\n”,”,” “bert = BERT(config).to(device)\n”,”,” “mlm_head = MLMHead(config).to(device)””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “57526ca9-a12f-4629-80ff-1367d44394be”,”,” “metadata”: {},”,” “source”: [“,” “Configure the different levels of batching. In this notebook, a minibatch is a set of training samples for which gradients are calculated simultaneously (they are processed by the GPU at the same time). 
A (regular) batch is a set of training samples for which gradients are accumulated before a training step is taken; batch sizes change throughout training and are defined by the `get_batch_size` function below. A macrobatch is a set of training samples that gets transferred to the GPU simultaneously. It is the granularity at which learning rate and batch size are varied, and the full duration of the training procedure is expressed in number of macrobatches (`macrobatch_count`). On a 3070 RTX GPU one macrobatch takes a bit over 4 minutes to process.””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 29,”,” “id”: “0353b3f4-bbdd-428e-a68d-1c0f73e5cb24″,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “minibatch_size = 2**5 # Number of samples on which we compute gradients simultaneously\n”,”,” “macrobatch_size = 2**15 # Number of samples we transfer to the GPU at the same time\n”,”,” “macrobatch_count = 2**8 # Total number of macrobatches we transfer in the whole training process””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “b30f9be6-fb3e-4e5b-84b1-badc672149fc”,”,” “metadata”: {},”,” “source”: [“,” “The function below yields the training data as NumPy arrays, one macrobatch at a time; the training loop converts them to Torch tensors. 
Each macrobatch consists of two arrays: the inputs (the \”x\” data) and the training targets (the \”y\” data). The targets are multiplied by the mask, so every unmasked position holds 0 and is skipped by the loss (`ignore_index = 0` in `MLMHead`).””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 30,”,” “id”: “929ee0e3-7bfc-4645-aeba-776383f46f88″,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “def macrobatches(macrobatch_size):\n”,”,” ” \”\”\”Group the output of `samples_and_masks` into NumPy int16 matrices of size macrobatch_size * 128.\”\”\”\n”,”,” ” ss = samples_and_masks(\”data/pretrain.txt\”, 128, bert_bpe)\n”,”,” ” training_data = []\n”,”,” ” for s in ss:\n”,”,” ” training_data.append(s)\n”,”,” ” if len(training_data) == macrobatch_size:\n”,”,” ” training_data = np.array(training_data, dtype = ‘int16’)\n”,”,” ” yield training_data[:, 0, :], training_data[:, 1, :] * training_data[:, 2, :]\n”,”,” ” training_data = []””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “bbbc13a8-cc4e-46dd-85b8-7cb5bec47712″,”,” “metadata”: {},”,” “source”: [“,” “We vary the learning rate throughout training. 
The learning rate is first gradually increased (\”warmup\”) and at the end is gradually decreased (\”annealed\”).””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 31,”,” “id”: “ae62d551-e161-4238-94d1-7e61b78229c2″,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “def get_lr(macrobatch, max_lr = 1e-3):\n”,”,” ” \”\”\”\n”,”,” ” One-cycle LR schedule, scaled by fraction of training time remaining, as described in Cramming.\n”,”,” ” See plot below.\n”,”,” ” \”\”\”\n”,”,” ” c = macrobatch + 0.5 # Midpoint of chunk\n”,”,” ” lr = max_lr\n”,”,” ” if c / macrobatch_count < 0.5:\n","," " lr = lr * 2 * c / macrobatch_count\n","," " else:\n","," " lr = lr * 2 * (macrobatch_count - c) / macrobatch_count\n","," " lr = lr * (macrobatch_count - c) / macrobatch_count\n","," " return lr""," ]"," },"," {"," "cell_type": "markdown","," "id": "e7ab5a95-8d82-4946-ad80-1c5bfd1e65ae","," "metadata": {},"," "source": ["," "Here is a visualization of how learning rate changes throughout pretraining:""," ]"," },"," {"," "cell_type": "code","," "execution_count": 32,"," "id": "8b424de5-0774-41ee-979a-73cc51e13445","," "metadata": {"," "tags": []"," },"," "outputs": ["," {"," "data": {"," "image/png": "","," "text/plain": ["," "

“”,” ]”,” },”,” “metadata”: {},”,” “output_type”: “display_data””,” }”,” ],”,” “source”: [“,” “plt.plot(range(macrobatch_count), [get_lr(m) for m in range(macrobatch_count)])\n”,”,” “plt.title(\”Learning rate schedule\”)\n”,”,” “plt.xlabel(\”Macrobatch\”)\n”,”,” “plt.ylabel(\”Learning rate\”)\n”,”,” “plt.show()””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “b941d9d9-631b-4b05-b359-218f0404cbdd”,”,” “metadata”: {},”,” “source”: [“,” “Batch size is also varied throughout pretraining – in the beginning our network is still poorly fit and our priority is to make rapid progress, so we use small batch sizes (noisy gradients). As the network gets better, it becomes more important to take our optimization steps carefully, so we gradually increase batch sizes (for more precise gradients).””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 33,”,” “id”: “239b68eb-f9fd-446b-913e-7650f49b48a7″,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “def get_batch_size(macrobatch):\n”,”,” ” \”\”\”Gradually increasing batch size through the training process based on description in Cramming.\”\”\”\n”,”,” ” if macrobatch >= 2**7:\n”,”,” ” return 2**11\n”,”,” ” elif macrobatch >= 2**6:\n”,”,” ” return 2**10\n”,”,” ” elif macrobatch >= 2**5:\n”,”,” ” return 2**9\n”,”,” ” elif macrobatch >= 2**4:\n”,”,” ” return 2**8\n”,”,” ” elif macrobatch >= 2**3:\n”,”,” ” return 2**7\n”,”,” ” elif macrobatch >= 2**2:\n”,”,” ” return 2**6\n”,”,” ” else:\n”,”,” ” return 2**5 # == 32, is minibatch_size, which we don’t want to go below””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “9761539d-1371-46a7-a67e-c158c7d7afcf”,”,” “metadata”: {},”,” “source”: [“,” “Again, a visualization might help:””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 34,”,” “id”: “4a5e31f7-5898-4f9f-8458-10ad5a66517f”,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [“,” {“,” “data”: {“,” “image/png”: “”,”,” “text/plain”: [“,” “

“”,” ]”,” },”,” “metadata”: {},”,” “output_type”: “display_data””,” }”,” ],”,” “source”: [“,” “plt.plot(range(macrobatch_count), [get_batch_size(m) for m in range(macrobatch_count)])\n”,”,” “plt.title(\”Batch size schedule\”)\n”,”,” “plt.xlabel(\”Macrobatch\”)\n”,”,” “plt.ylabel(\”Batch size\”)\n”,”,” “plt.show()””,” ]”,” },”,” {“,” “cell_type”: “markdown”,”,” “id”: “135dd4c6-7497-4559-bc7d-25ee9b3c099d”,”,” “metadata”: {},”,” “source”: [“,” “The optimizer is configured so that regularization is only applied to weights, not biases:””,” ]”,” },”,” {“,” “cell_type”: “code”,”,” “execution_count”: 35,”,” “id”: “be266c48-482d-4a01-9983-1256f24bc04e”,”,” “metadata”: {“,” “tags”: []”,” },”,” “outputs”: [],”,” “source”: [“,” “param_groups = [{‘params’: [p for p in list(bert.parameters()) + list(mlm_head.parameters()) if p.dim() >= 2], ‘weight_decay’: 0.01},\n”,”,” ” {‘params’: [p for p in list(bert.parameters()) + list(mlm_head.parameters()) if p.dim() < 2], 'weight_decay': 0}]\n","," "optimizer = optim.AdamW(param_groups, lr = get_lr(0), betas = (0.9, 0.98), eps = 1e-12, fused = True)\n","," "scaler = GradScaler() # This is for automatic mixed precision""," ]"," },"," {"," "cell_type": "markdown","," "id": "37b3c0a5-6bf0-441e-b042-f784027e7495","," "metadata": {},"," "source": ["," "And finally we get to the training loop. A small amount of complexity is introduced by the use of automatic mixed precision, but this is worthwhile as it speeds up training approximately two-fold (!). Note that the code below will run for 2 iterations only, to do a full training run comment out/remove the indicated lines. 
Later on in the notebook there will be a sleight of hand where we actually load weights fitted during a full run.""," ]"," },"," {"," "cell_type": "code","," "execution_count": 36,"," "id": "ee60f6c0-6f40-4103-913b-9253bdfe99f3","," "metadata": {"," "tags": []"," },"," "outputs": ["," {"," "name": "stderr","," "output_type": "stream","," "text": ["," " 1%|▎ | 2/256 [08:46<18:34:23, 263.24s/it]\n""," ]"," }"," ],"," "source": ["," "cumulative_samples = 0\n","," "mbs = macrobatches(macrobatch_size)\n","," "\n","," "f_log = \"BERT.csv\"\n","," "\n","," "with open(f_log, \"w\") as f:\n","," " f.write(\"macrobatch,cumulative_samples,duration,loss,lr\\n\")\n","," "\n","," "for macrobatch in tqdm(range(macrobatch_count)):\n","," " \n","," " # REMOVE THE BELOW TWO LINES IF YOU WANT TO DO A FULL TRAINING RUN\n","," " if macrobatch == 2:\n","," " break\n","," " \n","," " # Set chunk training parameters\n","," " batch_size = get_batch_size(macrobatch)\n","," " lr = get_lr(macrobatch)\n","," " for g in optimizer.param_groups:\n","," " g['lr'] = lr\n","," " \n","," " # Load a new macrobatch\n","," " xs, ys = next(mbs)\n","," " torch_xs = torch.LongTensor(xs).to(device)\n","," " torch_ys = torch.LongTensor(ys).to(device)\n","," " \n","," " # Iterate over the batches in the macrobatch\n","," " for i in range(0, xs.shape[0] // batch_size):\n","," " batch_start_time = time.time()\n","," " \n","," " batch_loss = 0\n","," " \n","," " batch_start_idx = i * batch_size\n","," " batch_data_torch_xs = torch_xs[batch_start_idx:batch_start_idx+batch_size, :]\n","," " batch_data_torch_ys = torch_ys[batch_start_idx:batch_start_idx+batch_size, :]\n","," "\n","," " optimizer.zero_grad(set_to_none = True)\n","," " \n","," " # Iterate over the minibatches in the batch\n","," " for j in range(0, batch_size // minibatch_size):\n","," " mb_start_idx = minibatch_size * j\n","," " mb_end_idx = mb_start_idx + minibatch_size\n","," "\n","," " # Use automatic mixed precision for (much) 
better performance\n","," " with autocast(device_type='cuda', dtype=torch.float16):\n","," " _, loss = mlm_head(bert(batch_data_torch_xs[mb_start_idx:mb_end_idx]), batch_data_torch_ys[mb_start_idx:mb_end_idx])\n","," "\n","," " # Correct for the fact that we are minibatching\n","," " corrected_loss = loss / (batch_size // minibatch_size)\n","," " batch_loss += corrected_loss\n","," "\n","," " # Need to use scaler.scale for automatic mixed precision\n","," " scaler.scale(corrected_loss).backward()\n","," " \n","," " cumulative_samples += minibatch_size\n","," "\n","," " # If we don't scaler.unscale_ here, gradient clipping will fail spectacularly, because it will act on arbitrarily scaled gradients\n","," " scaler.unscale_(optimizer)\n","," " torch.nn.utils.clip_grad_norm_(bert.parameters(), 0.5)\n","," " scaler.step(optimizer)\n","," " scaler.update()\n","," " \n","," " batch_duration = time.time() - batch_start_time\n","," " \n","," " with open(f_log, \"a\") as f:\n","," " f.write(f\"{macrobatch:03d},{cumulative_samples:09d},{batch_duration:05.3f},{batch_loss.item():05.2f},{lr:0.6f}\\n\")\n","," " \n","," " torch.save(bert.state_dict(), f\"BERT.weights\")\n","," " torch.save(mlm_head.state_dict(), f\"MLMHead.weights\")""," ]"," },"," {"," "cell_type": "markdown","," "id": "b46ceabe-5be7-4231-b7da-8441b3b98ab8","," "metadata": {},"," "source": ["," "### Fine-tuning""," ]"," },"," {"," "cell_type": "markdown","," "id": "b84246f0-eead-4d52-ba7b-cf0a19cff44d","," "metadata": {},"," "source": ["," "We have a pre-trained BERT! That's great, but now let's do something real with it (because who cares about guessing tokens that we hid on purpose?).\n","," "\n","," "We will now train on two tasks from the GLUE benchmark. 
There are more tasks in the benchmark (BERT was evaluated on 6 other tasks as well) and you can find code for those in this repository, but for this notebook we will keep it simple.\n","," "\n","," "The first task we will use is STS-B, the Semantic Textual Similarity Benchmark. In this task, the model gets two input sentences and has to predict how similar they are in meaning on a scale of 1 to 5. An example from the dataset is: (sentence 1) \"People are playing cricket.\", (sentence 2) \"Men are playing cricket.\", which has a similarity score of 3.2 in the data, indicating that the sentences are fairly similar, but not perfectly so (people could include women, after all). This is the only regression task in the GLUE benchmark on which BERT is evaluated; all other tasks are classification tasks.\n","," "\n","," "The second task is SST-2, the Stanford Sentiment Treebank. This task takes only a single sentence as input, and the requirement is to determine whether the sentence has a positive (1) or negative (0) sentiment. A positive example from the dataset is \"a gorgeous , witty , seductive movie . 
\", while a negative example is \"unflinchingly bleak and desperate \".\n","," "\n","," "Because we are using different datasets than during pretraining, we need to redo some of the logic for cleaning input sentences:""," ]"," },"," {"," "cell_type": "code","," "execution_count": 37,"," "id": "ad990858-dbc5-44db-bbaf-220a3cfe49c2","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def encode_sentence(sentence, bpe):\n","," " \"\"\"Take a string sentence and turn it into a list of BPE tokens.\"\"\"\n","," " encoded = []\n","," " for atom in atomize(clean_string(sentence)):\n","," " if atom.isalpha():\n","," " encoded += [tok for tok in bpe.encode('_' + atom)]\n","," " else:\n","," " encoded.append(atom)\n","," " return encoded""," ]"," },"," {"," "cell_type": "markdown","," "id": "cd15a973-30c8-4ee6-8a12-b0b8e17f3f9a","," "metadata": {},"," "source": ["," "For a number of downstream tasks (including STS-B), the structure of the input data is also different because now we have _2_ sentences as input. This is where the `[SEP]` token that we introduced earlier comes in. Training data is fed in in the form `['[CLS]'] + sentence1 + ['[SEP]'] + sentence2 + ['[PAD]'] * x` where the number of `[PAD]` tokens at the end is chosen such that the total number of tokens has the right length.\n","," "\n","," "Some downstream tasks also include data that requires more than 128 tokens to represent. 
This is where we benefit from the fact that our BERT uses relative position embeddings rather than absolute ones: we can run the model on longer samples as needed.""," ]"," },"," {"," "cell_type": "code","," "execution_count": 38,"," "id": "fe263ed3-4613-4890-a2b5-83512fe06dbe","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def prep_data(left_sentences, right_sentences, targets, bpe, length = 128, classification_target = True):\n","," " \"\"\"\n","," " Take two lists of string sentences and a list of targets and generate Torch matrices for training.\n","," " If the targets are not categorical (i.e. we're doing regression), set classification_target = False.\n","," " \"\"\"\n","," " assert len(left_sentences) == len(right_sentences) == len(targets)\n","," " num_samples = len(left_sentences)\n","," " tok2idx = bpe.token_mapping()\n","," " xs = []\n","," " ys = []\n","," " skipped = 0\n","," " for i in range(num_samples):\n","," " left_encoded = encode_sentence(left_sentences[i], bpe)\n","," " right_encoded = encode_sentence(right_sentences[i], bpe)\n","," " x = ([tok2idx[\"[CLS]\"]] + \n","," " [tok2idx[e] for e in left_encoded] +\n","," " [tok2idx[\"[SEP]\"]] +\n","," " [tok2idx[e] for e in right_encoded] +\n","," " [tok2idx[\"[PAD]\"]] * (length - len(left_encoded) - len(right_encoded) - 2))\n","," " # Over-long samples get no padding (a negative repeat count gives []), so they fail this length check\n","," " if len(x) == length:\n","," " xs.append(x)\n","," " ys.append(targets[i])\n","," " else:\n","," " print(f\"WARNING: Skipping sample of length {len(x)} at index {i}\")\n","," " skipped += 1\n","," " print(f\"Skipped {skipped} samples ({skipped/num_samples * 100}%)\")\n","," " joint = list(zip(xs, ys))\n","," " random.shuffle(joint)\n","," " xs, ys = zip(*joint)\n","," " xs = torch.LongTensor(xs).to(device)\n","," " if classification_target:\n","," " ys = torch.LongTensor(ys).to(device)\n","," " else:\n","," " ys = torch.tensor(ys, device = device)\n","," " return xs, ys""," ]"," },"," {"," "cell_type": "markdown","," "id": "8f131088-6d65-4649-b9a4-b05e9103a77e","," "metadata": {},"," "source": ["," "This helper function implements the training loop for finetuning. The GLUE datasets all fit entirely in GPU memory, so there is no need for \"macrobatching\" logic here. The implementation also uses a much simpler learning rate schedule and a constant batch size. The current settings are likely suboptimal, and are probably one of the first places to look if you want to get a higher GLUE score out of this model (and you don't just want to train for longer).""," ]"," },"," {"," "cell_type": "code","," "execution_count": 39,"," "id": "44ba075a-3ab8-47a0-b51c-d5faf486842c","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def finetune(bert, head, xs, ys):\n","," " \"\"\"\n","," " Fairly simple training procedure going through xs and ys for 5 epochs.\n","," " Batch size is constant; the learning rate is warmed up and decayed, but held constant within each epoch.\n","," " `bert` and `head` are modified in-place (you might not want to do this at home); this function does not return anything.\n","," " \"\"\"\n","," " batch_size = 16\n","," " total_samples = xs.shape[0]\n","," " \n","," " # Weight decay only on parameters with >= 2 dims; 1-D parameters (e.g. biases, layernorm weights) are not decayed\n","," " param_groups = [{'params': [p for p in list(bert.parameters()) + list(head.parameters()) if p.dim() >= 2], 'weight_decay': 0.01},\n","," " {'params': [p for p in list(bert.parameters()) + list(head.parameters()) if p.dim() < 2], 'weight_decay': 0}]\n","," " optimizer = optim.AdamW(param_groups, lr = 4e-5, betas = (0.9, 0.98), eps = 1e-12, fused = True)\n","," " scaler = GradScaler()\n","," " \n","," " # Poor man's warmup and decay\n","," " lrs = [1e-5, 4e-5, 4e-5, 2e-5, 1e-5]\n","," " \n","," " for epoch in tqdm(range(5)):\n","," " \n","," " for g in optimizer.param_groups:\n","," " g['lr'] = lrs[epoch]\n","," " \n","," " i = 0\n","," " while i < total_samples:\n","," "\n","," " batch_xs = xs[i:min(i+batch_size, total_samples), :]\n","," " batch_ys = ys[i:min(i+batch_size, 
total_samples)]\n","," "\n","," " optimizer.zero_grad(set_to_none = True)\n","," "\n","," " with autocast(device_type='cuda', dtype=torch.float16):\n","," " _, loss = head(bert(batch_xs), batch_ys)\n","," "\n","," " scaler.scale(loss).backward()\n","," " scaler.step(optimizer)\n","," " scaler.update()\n","," "\n","," " i += batch_size""," ]"," },"," {"," "cell_type": "markdown","," "id": "cee62797-fc05-4bf0-bb48-b027613067ad","," "metadata": {},"," "source": ["," "Once we have a finetuned model, we need to evaluate its performance on held-out data. We will use the validation datasets for that: we haven't used them for any other purpose, so they can give a fair performance estimate.""," ]"," },"," {"," "cell_type": "code","," "execution_count": 40,"," "id": "997cf039-b295-418a-b7e0-0cde43f9f946","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def cls_predict(bert, cls_head, xs):\n","," " \"\"\"Take a trained BERT and CLSHead and generate predictions for the inputs xs.\"\"\"\n","," " pred = []\n","," " for i in tqdm(range(xs.shape[0])):\n","," " with torch.no_grad():\n","," " logits, _ = cls_head(bert(xs[i:i+1]))\n","," " pred.append(torch.argmax(logits))\n","," " return torch.LongTensor(pred).to(device)""," ]"," },"," {"," "cell_type": "code","," "execution_count": 41,"," "id": "90840e68-ca60-4f52-ac82-32247cc6032f","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def reg_predict(bert, reg_head, xs):\n","," " \"\"\"Take a trained BERT and RegHead and generate predictions for the inputs xs.\"\"\"\n","," " pred = []\n","," " for i in tqdm(range(xs.shape[0])):\n","," " with torch.no_grad():\n","," " y_hat, _ = reg_head(bert(xs[i:i+1]))\n","," " pred.append(y_hat)\n","," " return torch.tensor(pred, device = device)""," ]"," },"," {"," "cell_type": "markdown","," "id": "f0fa9b3f-64b9-4591-88fd-b322e0d890af","," "metadata": {},"," "source": ["," "And once we have predictions for our validation data, we need a way to 
quantify how good those predictions are. Following the original BERT paper, we report Spearman correlation for STS-B (GLUE also reports Pearson correlation for this task), and plain old accuracy for SST-2.""," ]"," },"," {"," "cell_type": "code","," "execution_count": 42,"," "id": "7a8a669a-a400-4279-a6c2-db58373f73ae","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def accuracy(pred, true):\n","," " \"\"\"Calculate accuracy from predictions and ground truth.\"\"\"\n","," " return (torch.sum(pred == true) / pred.shape[0]).item()""," ]"," },"," {"," "cell_type": "code","," "execution_count": 43,"," "id": "224ca233-f31e-4370-85a2-178b2ea8eae8","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def spearman(pred, true):\n","," " \"\"\"Return Spearman correlation for predictions and ground truth.\"\"\"\n","," " return scipy.stats.spearmanr(np.array(pred.cpu()), np.array(true.cpu())).correlation""," ]"," },"," {"," "cell_type": "markdown","," "id": "a36d486d-8aee-49b0-a4ff-1cb47742247d","," "metadata": {},"," "source": ["," "We now have all the components we need to evaluate the performance of our pretrained model after finetuning:""," ]"," },"," {"," "cell_type": "code","," "execution_count": 44,"," "id": "abec54ab-d8ff-43ab-b5e9-b2f06f02bd40","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def eval_stsb(bert, bpe, length = 192):\n","," " \"\"\"Take a pre-trained BERT, finetune on STS-B, and return performance.\"\"\"\n","," " reg_head_stsb = RegHead(config).to(device)\n","," " \n","," " stsb_train = load_dataset(\"glue\", \"stsb\", split = \"train\")\n","," " stsb_train_xs, stsb_train_ys = prep_data([s['sentence1'] for s in stsb_train],\n","," " [s['sentence2'] for s in stsb_train],\n","," " [s['label'] for s in stsb_train],\n","," " bpe,\n","," " length = length,\n","," " classification_target = False)\n","," "\n","," " finetune(bert, reg_head_stsb, stsb_train_xs, stsb_train_ys)\n","," "\n","," " stsb_val = load_dataset(\"glue\", \"stsb\", split = \"validation\")\n","," " stsb_val_xs, stsb_val_ys = prep_data([s['sentence1'] for s in stsb_val],\n","," " [s['sentence2'] for s in stsb_val],\n","," " [s['label'] for s in stsb_val],\n","," " bpe,\n","," " length = length,\n","," " classification_target = False)\n","," "\n","," " return spearman(reg_predict(bert, reg_head_stsb, stsb_val_xs), stsb_val_ys)""," ]"," },"," {"," "cell_type": "code","," "execution_count": 45,"," "id": "ff3d665c-4080-42e6-a7b3-dac57ec0facc","," "metadata": {"," "tags": []"," },"," "outputs": [],"," "source": ["," "def eval_sst2(bert, bpe):\n","," " \"\"\"Take a pre-trained BERT, finetune on SST-2, and return performance.\"\"\"\n","," " cls_head_sst2 = CLSHead(config, 2).to(device)\n","," " \n","," " sst2_train = load_dataset(\"glue\", \"sst2\", split = \"train\")\n","," " sst2_train_xs, sst2_train_ys = prep_data([s['sentence'] for s in sst2_train],\n","," " ['' for s in sst2_train],\n","," " [s['label'] for s in sst2_train],\n","," " bpe)\n","," "\n","," " finetune(bert, cls_head_sst2, sst2_train_xs, sst2_train_ys)\n","," "\n","," " sst2_val = load_dataset(\"glue\", \"sst2\", split = \"validation\")\n","," " sst2_val_xs, sst2_val_ys = prep_data([s['sentence'] for s in sst2_val],\n","," " ['' for s in sst2_val],\n","," " [s['label'] for s in sst2_val],\n","," " bpe)\n","," "\n","," " return accuracy(cls_predict(bert, cls_head_sst2, sst2_val_xs), sst2_val_ys)""," ]"," },"," {"," "cell_type": "code","," "execution_count": 47,"," "id": "3ab3865f-c6cb-45d9-8802-d1d597b17b50","," "metadata": {"," "tags": []"," },"," "outputs": ["," {"," "name": "stdout","," "output_type": "stream","," "text": ["," "number of parameters: 110164992\n","," "BERT.weights -> Starting STS-B…\n""," ]"," },"," {"," "name": "stderr","," "output_type": "stream","," "text": ["," "Found cached dataset glue (/home/sam/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n""," ]"," },"," {"," "name": "stdout","," 
"output_type": "stream","," "text": ["," "Skipped 0 samples (0.0%)\n""," ]"," },"," {"," "name": "stderr","," "output_type": "stream","," "text": ["," "100%|█████████████████████████████████████████████| 5/5 [05:18<00:00, 63.77s/it]\n","," "Found cached dataset glue (/home/sam/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n""," ]"," },"," {"," "name": "stdout","," "output_type": "stream","," "text": ["," "Skipped 0 samples (0.0%)\n""," ]"," },"," {"," "name": "stderr","," "output_type": "stream","," "text": ["," "100%|██████████████████████████████████████| 1500/1500 [00:13<00:00, 108.07it/s]\n""," ]"," },"," {"," "name": "stdout","," "output_type": "stream","," "text": ["," "BERT.weights -> STS-B score: 0.8353077198403909\n","," "BERT.weights -> Starting SST-2…\n""," ]"," },"," {"," "name": "stderr","," "output_type": "stream","," "text": ["," "Found cached dataset glue (/home/sam/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n""," ]"," },"," {"," "name": "stdout","," "output_type": "stream","," "text": ["," "Skipped 0 samples (0.0%)\n""," ]"," },"," {"," "name": "stderr","," "output_type": "stream","," "text": ["," "100%|████████████████████████████████████████████| 5/5 [48:37<00:00, 583.49s/it]\n","," "Found cached dataset glue (/home/sam/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n""," ]"," },"," {"," "name": "stdout","," "output_type": "stream","," "text": ["," "Skipped 0 samples (0.0%)\n""," ]"," },"," {"," "name": "stderr","," "output_type": "stream","," "text": ["," "100%|████████████████████████████████████████| 872/872 [00:05<00:00, 147.42it/s]""," ]"," },"," {"," "name": "stdout","," "output_type": "stream","," "text": ["," "BERT.weights -> SST-2 score: 0.8841742873191833\n""," ]"," },"," {"," "name": "stderr","," "output_type": "stream","," "text": ["," "\n"","
]"," }"," ],"," "source": ["," "results = []\n","," "bert = BERT(config).to(device)\n","," "\n","," "ws = \"BERT.weights\"\n","," "\n","," "print(f\"{ws} -> Starting STS-B…\")\n","," "bert.load_state_dict(torch.load(ws), strict = False)\n","," "stsb_score = eval_stsb(bert, bert_bpe)\n","," "print(f\"{ws} -> STS-B score: {stsb_score}\")\n","," "\n","," "print(f\"{ws} -> Starting SST-2…\")\n","," "# Reload the pretrained weights, since finetune() modified bert in place during STS-B\n","," "bert.load_state_dict(torch.load(ws), strict = False)\n","," "sst2_score = eval_sst2(bert, bert_bpe)\n","," "print(f\"{ws} -> SST-2 score: {sst2_score}\")""," ]"," },"," {"," "cell_type": "markdown","," "id": "41a47354-4d55-4160-9f37-b34d0b6175ab","," "metadata": {},"," "source": ["," "By GLUE convention these scores are multiplied by 100, so we scored 83.5 on STS-B and 88.4 on SST-2. For reference, the original BERT-base scored 85.8 on STS-B and 93.5 on SST-2. Not bad! (Note, though, that BERT was evaluated on the held-out test set, while we are evaluating on the validation set, so the comparison should be taken with some caution.)""," ]"," },"," {"," "cell_type": "markdown","," "id": "618c5738-e9b9-4de4-8b10-364cfebea774","," "metadata": {},"," "source": ["," "### Tinker time\n","," "\n","," "That's it: you've now seen the whole process of training and evaluating a BERT lookalike. Training is fast enough that you can do some interesting experimentation even on a laptop. 
Here are some results from variants that I've tried:\n","," "\n","," "| Variant | Tokens seen | MLM loss | MNLI m | MNLI mm | QQP | QNLI | SST-2 | CoLA | CoLA run 2 | STS-B | MRPC | RTE | RTE run 2 | Average |\n","," "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n","," "| % samples longer than 128 tokens | | | 0.32 | 0.34 | 0.02 | 0.55 | 0 | 0 | 0 | 0 | 0 | 12.6 | 12.6 | |\n","," "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n","," "| Absolute position embeddings | 2^30 | 2.07 | 75.8 | 75.3 | 84.2 | 82.3 | 88.2 | 36.9 | 35.7 | 81 | 82.2 | 52.5 | 50 | 73.0 |\n","," "| Relative position embeddings | 2^30 | 1.99 | 76.5 | 77.4 | 86.2 | 85.2 | 87.6 | 37.2 | 37.3 | 83.5 | 84.2 | 53.1 | 57 | 74.8 |\n","," "| Relative position embeddings | 2^31 | 1.82 | 77.8 | 77.5 | 86.2 | 86.3 | 88.2 | 42.2 | 45.2 | 84.7 | 85.5 | 52.3 | 50.9 | 75.7 |\n","," "| Relative position embeddings, [Sophia optimizer](https://arxiv.org/abs/2305.14342) | 2^30 | 1.92 | 76.4 | 76.4 | 85.2 | 84.1 | 88 | 44.4 | 25.2 | 80.2 | 85.1 | 46.2 | 56 | 73.5 |\n","," "| Relative position embeddings, Sophia, span objective | 2^30 | 3.74 | 75.6 | 76.5 | 84.6 | 83.9 | 84.7 | 29.5 | 30.4 | 83.9 | 83.4 | 58.1 | 63.5 | 73.7 |\n","," "| Relative position embeddings, Sophia, span objective | 2^31 | 3.58 | 76.1 | 76.2 | 85.2 | 85.1 | 87.4 | 37.3 | 42.9 | 83.7 | 87 | 64.6 | 53.8 | 75.6 |\n","," "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n","," "| Cramming results on 2080 Ti (arXiv version) | 2^32-ish | 1.84 | 82.8 | 83.4 | 87.2 | 89 | 91.5 | 47.2 | – | 83.1 | 86.2 | 54 | – | 78.3 |\n","," "\n","," "As mentioned above, the finetuning process could probably be improved quite a bit. In particular, MNLI scores seem low compared to what's reported in the Cramming paper (even for runs with poor MLM loss, where the Cramming paper still manages to obtain good MNLI 
performance). It might still be possible to do better with Sophia: in my tests, Sophia improved MLM loss, but that did not translate into better GLUE performance. However, I didn't really do any hyperparameter optimization, and used only a single set of (mostly default) settings.\n","," "\n","," "For Cramming, the authors also implemented \"sparse token prediction\", which improves efficiency by only generating token predictions for masked tokens. This wouldn't affect the accuracy of the model, but it would make it faster to train. Similarly, something like [FlashAttention](https://arxiv.org/abs/2205.14135) could bring some welcome performance gains that make laptop training more feasible. Both of these changes would add some complexity, however.\n","," "\n","," "What will you try?""," ]"," },"," {"," "cell_type": "markdown","," "id": "67c96f0c-366a-4ef4-ab78-29f92f4a150a","," "metadata": {"," "tags": []"," },"," "source": ["," "### Bloopers\n","," "\n","," "A variety of things went wrong as I went through the process of getting this BERT to train. I learned a lot from making, finding and solving these mistakes, so they seem worth at least mentioning:\n","," "* `torch.optim.Adam` and `torch.optim.AdamW` are *not* the same thing, and `Adam` (without `W`) actually fails to converge on this model. The difference is in how weight decay is applied: `Adam` adds it to the gradient as an L2 penalty, whereas `AdamW` decouples it from the gradient update and shrinks the weights directly.\n","," "* Similarly, I had convergence issues when I forgot to apply the layernorm right after the embedding layer. This modification to the original BERT architecture is mentioned in the Cramming paper as improving training stability, so it was interesting to see that play out.\n","," "* I also faced issues when I used Torch default random initializations for most of the model weights. 
The current initialization scheme comes from [nanoGPT](https://github.com/karpathy/nanoGPT), and again it made the difference between convergence and divergence.\n","," "\n","," "Needless to say, these seemingly subtle differences can be hard to identify: [this blog post](http://karpathy.github.io/2019/04/25/recipe/) by Andrej Karpathy was very helpful for structuring the debugging process.""," ]"," },"," {"," "cell_type": "markdown","," "id": "a321cbe6-169c-4a55-9bb4-8527e4500772","," "metadata": {},"," "source": ["," "### Have fun, and good luck!""," ]"," }"," ],"," "metadata": {"," "kernelspec": {"," "display_name": "Python 3 (ipykernel)","," "language": "python","," "name": "python3""," },"," "language_info": {"," "codemirror_mode": {"," "name": "ipython","," "version": 3"," },"," "file_extension": ".py","," "mimetype": "text/x-python","," "name": "python","," "nbconvert_exporter": "python","," "pygments_lexer": "ipython3","," "version": "3.8.10""," }"," },"," "nbformat": 4,"," "nbformat_minor": 
5","}"]}},"title":"bert-for-laptops/BERT_for_laptops.ipynb at main · samvher/bert-for-laptops"}
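The pretraining loop divides each minibatch loss by `batch_size // minibatch_size`. It also calls `scaler.unscale_(optimizer)` before `clip_grad_norm_`. Both ideas can be checked numerically with the pure-Python sketch below. It is not the notebook's code: it assumes a hypothetical one-parameter model `y_hat = w * x` with a mean-squared-error loss, and `grad_mse`, `clip_by_norm` and the toy data are illustrative names. The scale factor `2 ** 16` is assumed to match `GradScaler`'s default initial loss scale.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of the mean squared error of the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def clip_by_norm(grads, max_norm):
    """Rescale grads so their global L2 norm is at most max_norm (the rule clip_grad_norm_ applies)."""
    norm = math.sqrt(sum(g * g for g in grads))
    factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * factor for g in grads]


# Hypothetical toy data for a one-parameter model
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5
batch_size, minibatch_size = 4, 2

# Gradient of the mean loss over the whole batch
full_grad = grad_mse(w, xs, ys)

# Gradient accumulated over minibatches, each divided by the number of minibatches
num_minibatches = batch_size // minibatch_size
accumulated = 0.0
for start in range(0, batch_size, minibatch_size):
    end = start + minibatch_size
    accumulated += grad_mse(w, xs[start:end], ys[start:end]) / num_minibatches

print(full_grad, accumulated)  # -22.5 -22.5

# Loss scaling multiplies every gradient by the scale factor
scale = 2.0 ** 16  # assumed to match GradScaler's default initial scale
max_norm = 0.5     # clipping threshold used in the pretraining loop
scaled = [full_grad * scale]

# Unscale first, then clip: the clipped gradient ends up with norm max_norm
unscale_then_clip = clip_by_norm([g / scale for g in scaled], max_norm)
# Clip the scaled gradient, then unscale (what the optimizer step sees without unscale_)
clip_then_unscale = [g / scale for g in clip_by_norm(scaled, max_norm)]

print(unscale_then_clip)  # approximately [-0.5]
print(clip_then_unscale)  # roughly 65536 times smaller in magnitude
```

With equal-sized minibatches, dividing each minibatch loss by the number of minibatches reproduces the full-batch mean gradient. Clipping before unscaling shrinks the effective gradient by the loss-scale factor, which matches the "fail spectacularly" warning in the loop's comment.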

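`prep_data` builds each sample as `[CLS] + sentence1 + [SEP] + sentence2 + [PAD] * x` and skips samples that do not fit the target length. The sketch below isolates that packing logic. It assumes a hypothetical toy vocabulary and already-tokenized inputs in place of the notebook's `bpe.token_mapping()` and `encode_sentence`, and `pack_pair` is an illustrative name.

```python
def pack_pair(left_tokens, right_tokens, tok2idx, length):
    """Build [CLS] left [SEP] right [PAD]... as token ids, or return None if it doesn't fit."""
    n_pad = length - len(left_tokens) - len(right_tokens) - 2  # 2 slots for [CLS] and [SEP]
    x = ([tok2idx["[CLS]"]]
         + [tok2idx[t] for t in left_tokens]
         + [tok2idx["[SEP]"]]
         + [tok2idx[t] for t in right_tokens]
         # A negative repeat count gives [], so over-long pairs end up longer than `length`
         + [tok2idx["[PAD]"]] * n_pad)
    return x if len(x) == length else None


# Hypothetical toy vocabulary standing in for the trained BPE token mapping
toy_vocab = ["[PAD]", "[CLS]", "[SEP]", "_people", "_men", "_are", "_playing", "_cricket", "."]
tok2idx = {tok: i for i, tok in enumerate(toy_vocab)}

left = ["_people", "_are", "_playing", "_cricket", "."]
right = ["_men", "_are", "_playing", "_cricket", "."]

print(pack_pair(left, right, tok2idx, length=16))
# [1, 3, 5, 6, 7, 8, 2, 4, 5, 6, 7, 8, 0, 0, 0, 0]

# Single-sentence tasks such as SST-2 pass an empty second sentence
print(pack_pair(left, [], tok2idx, length=8))
# [1, 3, 5, 6, 7, 8, 2, 0]

# 5 + 5 + 2 = 12 tokens do not fit in 8, so this pair would be skipped
print(pack_pair(left, right, tok2idx, length=8))
# None
```

The `len(x) == length` check doubles as the overflow test. Passing a larger `length` (as `eval_stsb` does with 192) is therefore what lets longer pairs through, which the notebook relies on its relative position embeddings to handle.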
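The notebook's `spearman` helper delegates to `scipy.stats.spearmanr`. Spearman correlation is the Pearson correlation of the ranks, with tied values sharing their average rank (as `scipy.stats.spearmanr` does). The dependency-free sketch below reimplements it next to accuracy and the GLUE ×100 reporting convention. It is an illustrative reimplementation, not what the notebook runs: the `_py` suffixes avoid clashing with the notebook's functions, and the example scores are made up.

```python
def average_ranks(values):
    """Return 1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the run while the next sorted value ties with the current one
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        shared_rank = (start + end) / 2 + 1  # mean of 1-based positions start+1 .. end+1
        for k in range(start, end + 1):
            ranks[order[k]] = shared_rank
        start = end + 1
    return ranks


def pearson(a, b):
    """Plain Pearson correlation coefficient of two equal-length sequences."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / (var_a * var_b) ** 0.5


def spearman_py(pred, true):
    """Spearman correlation: Pearson correlation computed on (average) ranks."""
    return pearson(average_ranks(pred), average_ranks(true))


def accuracy_py(pred, true):
    """Fraction of predictions that exactly match the labels."""
    return sum(p == t for p, t in zip(pred, true)) / len(pred)


# Hypothetical STS-B-style scores: the values differ, but the ordering is identical
true_scores = [3.2, 0.0, 5.0, 1.4]
pred_scores = [2.9, 0.3, 4.1, 2.0]
print(spearman_py(pred_scores, true_scores))  # 1.0, because only the ordering matters
print(average_ranks([10, 20, 20, 30]))        # [1.0, 2.5, 2.5, 4.0]

# Hypothetical SST-2-style labels
print(accuracy_py([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75

# GLUE convention: report scores multiplied by 100
print(round(100 * 0.8353077198403909, 1))  # 83.5
```

Because Spearman correlation only looks at rankings, a regression head whose outputs are systematically too low or too high is not penalized as long as it orders the sentence pairs correctly.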