{ "cells": [ { "cell_type": "markdown", "id": "45398736-7e89-4263-89c8-92153baff553", "metadata": { "id": "45398736-7e89-4263-89c8-92153baff553" }, "source": [ "\n", "\n", "\n", "\n", "
\n", "\n", "
This notebook is an adapted version of https://github.com/rasbt/LLMs-from-scratch\n", "
\n", "
" ] }, { "cell_type": "markdown", "id": "66dd524e-864c-4012-b0a2-ccfc56e80024", "metadata": { "id": "66dd524e-864c-4012-b0a2-ccfc56e80024" }, "source": [ "# Building and training a GPT-2 like model to generate text" ] }, { "cell_type": "markdown", "id": "0a3bdf9e-2ff0-4a57-abab-ede2d955a237", "metadata": { "id": "0a3bdf9e-2ff0-4a57-abab-ede2d955a237" }, "source": [ "- GPT (Generative Pre-trained Transformer) generates words sequentially and is based on the decoder part of the original transformer architecture.\n", "- Therefore, this LLM are often referred to as \"decoder-only\" or \"decoder-like\" LLM\n", "- LLMs are larger, mainly due to their vast number of parameters, not the amount of code\n", "\n", "---\n", "In this class:\n", "- Build and train GPT-2 like model\n", "- We'll specifically code the architecture of the smallest GPT-2 model (124 million parameters), as outlined in Radford et al.'s [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)\n", "- Evaluate the model\n", "- Load openly available pretrained weights from OpenAI into our model" ] }, { "cell_type": "code", "source": [ "pip install tiktoken" ], "metadata": { "id": "fwW6nfaqDJ4b" }, "id": "fwW6nfaqDJ4b", "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "id": "92b989e9-da36-4159-b212-799184764dd9", "metadata": { "id": "92b989e9-da36-4159-b212-799184764dd9" }, "outputs": [], "source": [ "from importlib.metadata import version\n", "\n", "pkgs = [\"matplotlib\",\n", " \"numpy\",\n", " \"tiktoken\",\n", " \"torch\",\n", " \"tensorflow\" # For OpenAI's pretrained weights\n", " ]\n", "for p in pkgs:\n", " print(f\"{p} version: {version(p)}\")" ] }, { "cell_type": "markdown", "source": [ "" ], "metadata": { "id": "vEXv1XM8eN1w" }, "id": "vEXv1XM8eN1w" }, { "cell_type": "markdown", "id": "bdc1cf3f-82d8-46c7-9ecc-58979ce87cdd", "metadata": { "id": 
"bdc1cf3f-82d8-46c7-9ecc-58979ce87cdd" }, "source": [ "## Building GPT model" ] }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn" ], "metadata": { "id": "KxLv-_E8fXEi" }, "id": "KxLv-_E8fXEi", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "GPT_CONFIG_124M = {\n", " \"vocab_size\": 50257, # Vocabulary size\n", " \"context_length\": 256, # Context length\n", " \"emb_dim\": 768, # Embedding dimension\n", " \"n_heads\": 12, # Number of attention heads\n", " \"n_layers\": 12, # Number of layers\n", " \"drop_rate\": 0.1, # Dropout rate\n", " \"qkv_bias\": False # Query-Key-Value bias\n", "}" ], "metadata": { "id": "tlN0Xe26dS3l" }, "id": "tlN0Xe26dS3l", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "- We use short variable names to avoid long lines of code later\n", "- `\"vocab_size\"` indicates a vocabulary size of 50,257 words, supported by the BPE tokenizer\n", "- `\"context_length\"` represents the model's maximum input token count, as enabled by positional embeddings covered\n", "- `\"emb_dim\"` is the embedding size for token inputs, converting each input token into a 768-dimensional vector\n", "- `\"n_heads\"` is the number of attention heads in the multi-head attention mechanism\n", "- `\"n_layers\"` is the number of transformer blocks within the model, which we'll implement in upcoming sections\n", "- `\"drop_rate\"` is the dropout mechanism's intensity; 0.1 means dropping 10% of hidden units during training to mitigate overfitting\n", "- `\"qkv_bias\"` decides if the `Linear` layers in the multi-head attention mechanism should include a bias vector when computing query (Q), key (K), and value (V) tensors; we'll disable this option, which is standard practice in modern LLMs;" ], "metadata": { "id": "843lA0NsdWA9" }, "id": "843lA0NsdWA9" }, { "cell_type": "markdown", "source": [ "" ], "metadata": { "id": "WUrPj9C1dfjd" }, "id": "WUrPj9C1dfjd" }, { "cell_type": 
"markdown", "source": [ "NOTE: We will not code the GPT backbone. If you want to have a look at it, then check the implementation [here](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch04/01_main-chapter-code/ch04.ipynb)" ], "metadata": { "id": "J-zQjpcHeb74" }, "id": "J-zQjpcHeb74" }, { "cell_type": "markdown", "source": [ "### Layer normalization\n", "\n", "- Layer normalization, also known as LayerNorm ([Ba et al. 2016](https://arxiv.org/abs/1607.06450)), centers the activations of a neural network layer around a mean of 0 and normalizes their variance to 1\n", "- This stabilizes training and enables faster convergence to effective weights\n", "- Layer normalization is applied both before and after the multi-head attention module within the transformer block, which we will implement later; it's also applied before the final output layer" ], "metadata": { "id": "qfoRsTQAfGeZ" }, "id": "qfoRsTQAfGeZ" }, { "cell_type": "code", "source": [ "class LayerNorm(nn.Module):\n", " def __init__(self, emb_dim):\n", " super().__init__()\n", " self.eps = 1e-5\n", " self.scale = nn.Parameter(torch.ones(emb_dim))\n", " self.shift = nn.Parameter(torch.zeros(emb_dim))\n", "\n", " def forward(self, x):\n", " mean = x.mean(dim=-1, keepdim=True)\n", " var = x.var(dim=-1, keepdim=True, unbiased=False)\n", " norm_x = (x - mean) / torch.sqrt(var + self.eps)\n", " return self.scale * norm_x + self.shift" ], "metadata": { "id": "E34ACo3lflBK" }, "id": "E34ACo3lflBK", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "**Scale and shift**\n", "\n", "- Note that in addition to performing the normalization by subtracting the mean and dividing by the variance, we added two trainable parameters, a `scale` and a `shift` parameter\n", "- The initial `scale` (multiplying by 1) and `shift` (adding 0) values don't have any effect; however, `scale` and `shift` are trainable parameters that the LLM automatically adjusts during training if it is determined that 
doing so would improve the model's performance on its training task\n", "- This allows the model to learn appropriate scaling and shifting that best suit the data it is processing\n", "- Note that we also add a smaller value (`eps`) before computing the square root of the variance; this is to avoid division-by-zero errors if the variance is 0\n", "\n", "**Biased variance**\n", "- In the variance calculation above, setting `unbiased=False` means using the formula $\\frac{\\sum_i (x_i - \\bar{x})^2}{n}$ to compute the variance where n is the sample size (here, the number of features or columns); this formula does not include Bessel's correction (which uses `n-1` in the denominator), thus providing a biased estimate of the variance\n", "- For LLMs, where the embedding dimension `n` is very large, the difference between using n and `n-1`\n", " is negligible" ], "metadata": { "id": "d6gT786UfxO5" }, "id": "d6gT786UfxO5" }, { "cell_type": "markdown", "source": [ "" ], "metadata": { "id": "bGbaRtXzgCg-" }, "id": "bGbaRtXzgCg-" }, { "cell_type": "markdown", "source": [ "### GELU activation\n", "\n", "- GELU ([Hendrycks and Gimpel 2016](https://arxiv.org/abs/1606.08415)) can be implemented in several ways; the exact version is defined as GELU(x)=x⋅Φ(x), where Φ(x) is the cumulative distribution function of the standard Gaussian distribution.\n", "- In practice, it's common to implement a computationally cheaper approximation: $\\text{GELU}(x) \\approx 0.5 \\cdot x \\cdot \\left(1 + \\tanh\\left[\\sqrt{\\frac{2}{\\pi}} \\cdot \\left(x + 0.044715 \\cdot x^3\\right)\\right]\\right)\n", "$ (the original GPT-2 model was also trained with this approximation)" ], "metadata": { "id": "wYE1jqbwgKbo" }, "id": "wYE1jqbwgKbo" }, { "cell_type": "code", "source": [ "class GELU(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def forward(self, x):\n", " return 0.5 * x * (1 + torch.tanh(\n", " torch.sqrt(torch.tensor(2.0 / torch.pi)) *\n", " (x + 0.044715 * 
torch.pow(x, 3))\n", "        ))" ], "metadata": { "id": "nUBkKnRRgUqy" }, "id": "nUBkKnRRgUqy", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import matplotlib.pyplot as plt\n", "\n", "gelu, relu = GELU(), nn.ReLU()\n", "\n", "# Some sample data\n", "x = torch.linspace(-3, 3, 100)\n", "y_gelu, y_relu = gelu(x), relu(x)\n", "\n", "plt.figure(figsize=(8, 3))\n", "for i, (y, label) in enumerate(zip([y_gelu, y_relu], [\"GELU\", \"ReLU\"]), 1):\n", "    plt.subplot(1, 2, i)\n", "    plt.plot(x, y)\n", "    plt.title(f\"{label} activation function\")\n", "    plt.xlabel(\"x\")\n", "    plt.ylabel(f\"{label}(x)\")\n", "    plt.grid(True)\n", "\n", "plt.tight_layout()\n", "plt.show()" ], "metadata": { "id": "5HfDPKo4gc9E" }, "id": "5HfDPKo4gc9E", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "- As we can see, ReLU is a piecewise linear function that outputs the input directly if it is positive; otherwise, it outputs zero\n", "- GELU is a smooth, non-linear function that approximates ReLU but has a non-zero gradient for almost all negative values (the exception is around $x \\approx -0.75$, where the gradient is zero)" ], "metadata": { "id": "aXoF1nu-glPk" }, "id": "aXoF1nu-glPk" }, { "cell_type": "markdown", "source": [ "### Feed-forward neural network\n", "\n", "- The feed-forward module expands the embedding dimension by a factor of 4 (768 to 3,072), applies GELU, and projects back to the embedding dimension" ], "metadata": { "id": "FBhr0rtNgr9i" }, "id": "FBhr0rtNgr9i" }, { "cell_type": "code", "source": [ "class FeedForward(nn.Module):\n", "    def __init__(self, cfg):\n", "        super().__init__()\n", "        self.layers = nn.Sequential(\n", "            nn.Linear(cfg[\"emb_dim\"], 4 * cfg[\"emb_dim\"]),\n", "            GELU(),\n", "            nn.Linear(4 * cfg[\"emb_dim\"], cfg[\"emb_dim\"]),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.layers(x)" ], "metadata": { "id": "lv92SN0dg0QV" }, "id": "lv92SN0dg0QV", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "" ], "metadata": { "id": "dMjZHvZ-g472" }, "id": "dMjZHvZ-g472" }, { "cell_type": "markdown", "source": [ "#### Adding shortcut connections\n", "\n", "- Shortcut 
connections are also called skip or residual connections\n", "- A shortcut connection creates an alternative, shorter path for the gradient to flow through the network\n", "- This is achieved by adding the output of one layer to the output of a later layer, usually skipping one or more layers in between\n", "- Let's illustrate this idea with a small example network:\n", "\n", "" ], "metadata": { "id": "hAn-g5WdhM3c" }, "id": "hAn-g5WdhM3c" }, { "cell_type": "markdown", "source": [ "### Connecting attention and linear layers in a transformer block" ], "metadata": { "id": "c3dYT9s4hsWH" }, "id": "c3dYT9s4hsWH" }, 
{ "cell_type": "code", "execution_count": null, "id": "86000d74-624a-48f0-86da-f41926cb9e04", "metadata": { "id": "86000d74-624a-48f0-86da-f41926cb9e04" }, "outputs": [], "source": [ "class MultiHeadAttention(nn.Module):\n", "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n", "        super().__init__()\n", "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", "\n", "        self.d_out = d_out\n", "        self.num_heads = num_heads\n", "        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim\n", "\n", "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n", "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n", "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n", "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n", "        self.dropout = nn.Dropout(dropout)\n", "        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", "\n", "    def forward(self, x):\n", "        b, num_tokens, d_in = x.shape\n", "\n", "        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)\n", "        queries = self.W_query(x)\n", "        values = self.W_value(x)\n", "\n", "        # We implicitly split the matrix by adding a `num_heads` dimension\n", "        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n", "        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n", "        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n", "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n", "\n", "        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n", "        keys = keys.transpose(1, 2)\n", "        queries = queries.transpose(1, 2)\n", "        values = values.transpose(1, 2)\n", "\n", "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n", "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n", "\n", "        # Original mask truncated to the number of tokens and converted to boolean\n", "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", "\n", "        # Use the mask to fill attention scores\n", "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n", "\n", "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", "        attn_weights = self.dropout(attn_weights)\n", "\n", "        # Shape: (b, num_tokens, num_heads, head_dim)\n", "        context_vec = (attn_weights @ values).transpose(1, 2)\n", "\n", "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", "        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n", "        context_vec = self.out_proj(context_vec)  # optional projection\n", "\n", "        return context_vec\n", "\n", "\n", "class TransformerBlock(nn.Module):\n", "    def __init__(self, cfg):\n", "        super().__init__()\n", "        self.att = MultiHeadAttention(\n", "            d_in=cfg[\"emb_dim\"],\n", "            d_out=cfg[\"emb_dim\"],\n", "            context_length=cfg[\"context_length\"],\n", "            num_heads=cfg[\"n_heads\"],\n", "            dropout=cfg[\"drop_rate\"],\n", "            qkv_bias=cfg[\"qkv_bias\"])\n", "        self.ff = FeedForward(cfg)\n", "        self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n", "        self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n", "        self.drop_shortcut = nn.Dropout(cfg[\"drop_rate\"])\n", "\n", "    def forward(self, x):\n", "        # Shortcut connection for attention block\n", "        shortcut = x\n", "        x = self.norm1(x)\n", "        x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]\n", "        x = self.drop_shortcut(x)\n", "        x = x + shortcut  # Add the original input back\n", "\n", "        # Shortcut connection for feed forward block\n", "        shortcut = x\n", "        x = self.norm2(x)\n", "        x = self.ff(x)\n", "        x = self.drop_shortcut(x)\n", "        x = x + shortcut  # Add the original input back\n", "\n", "        return x\n" ] }, 
{ "cell_type": "markdown", "source": [ "" ], "metadata": { "id": "0dD7RX35ihvX" }, "id": "0dD7RX35ihvX" }, { "cell_type": "markdown", "id": "09c6cf0f-7458-48a2-97fd-aa5068d65e8c", "metadata": { "id": "09c6cf0f-7458-48a2-97fd-aa5068d65e8c" }, "source": [ "- We use a dropout rate of 0.1 above, but it's relatively common to train LLMs without dropout nowadays\n", "- Modern LLMs also typically don't use bias vectors in the `nn.Linear` layers for the query, key, and value matrices (unlike earlier GPT models), which we achieve by setting `\"qkv_bias\": False`\n", "- We reduce the context length (`context_length`) to only 256 tokens to lower the computational resource requirements for training the model, whereas the original 124-million-parameter GPT-2 model used 1024 tokens\n", "    - This is so that more readers will be able to follow and execute the code examples on their laptop computers\n", "    - However, please feel free to increase the `context_length` to 1024 tokens (this does not require any code changes)\n", "    - Later, we will also load a model with a 1024-token `context_length` from pretrained weights" ] }, { "cell_type": "markdown", "source": [ "" ], "metadata": { "id": "lXQdLo6riqr_" }, "id": "lXQdLo6riqr_" }, { "cell_type": "markdown", "source": [ "### Coding the GPT model\n", "\n", "- We are almost there: now let's plug in the transformer block\n", "- Note that the transformer block is repeated multiple times; in the case of the smallest 124M GPT-2 model, we repeat it 12 times:" ], "metadata": { "id": "D1fD1FJXjBpv" }, "id": "D1fD1FJXjBpv" }, { "cell_type": "markdown", "source": [ "" ], 
"metadata": { "id": "Si7fNw3djOvv" }, "id": "Si7fNw3djOvv" }, { "cell_type": "markdown", "source": [ "- The corresponding code implementation, where `cfg[\"n_layers\"] = 12`:" ], "metadata": { "id": "RTYiSF1FjSCh" }, "id": "RTYiSF1FjSCh" }, { "cell_type": "code", "source": [ "class GPTModel(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n", " self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n", " self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n", "\n", " self.trf_blocks = nn.Sequential(\n", " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", "\n", " self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n", " self.out_head = nn.Linear(\n", " cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False\n", " )\n", "\n", " def forward(self, in_idx):\n", " batch_size, seq_len = in_idx.shape\n", " tok_embeds = self.tok_emb(in_idx)\n", " pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n", " x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]\n", " x = self.drop_emb(x)\n", " x = self.trf_blocks(x)\n", " x = self.final_norm(x)\n", " logits = self.out_head(x)\n", " return logits" ], "metadata": { "id": "4YvGZU7rjVLW" }, "id": "4YvGZU7rjVLW", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "torch.manual_seed(123)\n", "model = GPTModel(GPT_CONFIG_124M)\n", "\n", "\n", "total_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Total number of parameters: {total_params:,}\")" ], "metadata": { "id": "d6nYc8eBjqyi" }, "id": "d6nYc8eBjqyi", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "- As we see above, this model has 163M, not 124M parameters; why?\n", "- In the original GPT-2 paper, the researchers applied weight tying, which means that they reused the token embedding layer (`tok_emb`) as the output layer, which means setting `self.out_head.weight = 
self.tok_emb.weight`\n", "- The token embedding layer projects the 50,257-dimensional one-hot encoded input tokens to a 768-dimensional embedding representation\n", "- The output layer projects 768-dimensional embeddings back into a 50,257-dimensional representation so that we can convert these back into words (more about that in the next section)\n", "- So, the embedding and output layer have the same number of weight parameters, as we can see based on the shape of their weight matrices\n", "- However, a quick note about its size: we previously referred to it as a 124M parameter model; we can double check this number as follows:" ], "metadata": { "id": "xjVbFi43j4Db" }, "id": "xjVbFi43j4Db" }, { "cell_type": "code", "source": [ "print(\"Token embedding layer shape:\", model.tok_emb.weight.shape)\n", "print(\"Output layer shape:\", model.out_head.weight.shape)" ], "metadata": { "id": "OsS1_HTyj6mh" }, "id": "OsS1_HTyj6mh", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Generating text\n", "\n" ], "metadata": { "id": "OlpQH4HOj5Ri" }, "id": "OlpQH4HOj5Ri" }, { "cell_type": "markdown", "source": [ "- LLMs like the GPT model we implemented above are used to generate one word at a time" ], "metadata": { "id": "eUpgOlLzkgAj" }, "id": "eUpgOlLzkgAj" }, { "cell_type": "markdown", "source": [ "\n" ], "metadata": { "id": "M6jGaFw4nfRa" }, "id": "M6jGaFw4nfRa" }, { "cell_type": "markdown", "id": "59f80895-be35-4bb5-81cb-f357ef7367fe", "metadata": { "id": "59f80895-be35-4bb5-81cb-f357ef7367fe" }, "source": [ "- Next, we use the `generate_text_simple` function to generate text\n", "- In addition, we define two convenience functions, `text_to_token_ids` and `token_ids_to_text`, for converting between token and text representations" ] }, { "cell_type": "markdown", "id": "741881f3-cee0-49ad-b11d-b9df3b3ac234", "metadata": { "id": "741881f3-cee0-49ad-b11d-b9df3b3ac234" }, "source": [ "" ] }, { "cell_type": "code", "execution_count": null, 
"id": "5e062b82-3540-48ce-8eb4-009686d0d16c", "metadata": { "id": "5e062b82-3540-48ce-8eb4-009686d0d16c" }, "outputs": [], "source": [ "import tiktoken\n", "def generate_text_simple(model, idx, max_new_tokens, context_size):\n", " # idx is (batch, n_tokens) array of indices in the current context\n", " for _ in range(max_new_tokens):\n", "\n", " # Crop current context if it exceeds the supported context size\n", " # E.g., if LLM supports only 5 tokens, and the context size is 10\n", " # then only the last 5 tokens are used as context\n", " idx_cond = idx[:, -context_size:]\n", "\n", " # Get the predictions\n", " with torch.no_grad():\n", " logits = model(idx_cond)\n", "\n", " # Focus only on the last time step\n", " # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)\n", " logits = logits[:, -1, :]\n", "\n", " # Apply softmax to get probabilities\n", " probas = torch.softmax(logits, dim=-1) # (batch, vocab_size)\n", "\n", " # Get the idx of the vocab entry with the highest probability value\n", " idx_next = torch.argmax(probas, dim=-1, keepdim=True) # (batch, 1)\n", "\n", " # Append sampled index to the running sequence\n", " idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)\n", "\n", " return idx\n", "\n", "def text_to_token_ids(text, tokenizer):\n", " encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})\n", " encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension\n", " return encoded_tensor\n", "\n", "def token_ids_to_text(token_ids, tokenizer):\n", " flat = token_ids.squeeze(0) # remove batch dimension\n", " return tokenizer.decode(flat.tolist())\n", "\n", "start_context = \"Every effort moves you\"\n", "tokenizer = tiktoken.get_encoding(\"gpt2\")\n", "\n", "token_ids = generate_text_simple(\n", " model=model,\n", " idx=text_to_token_ids(start_context, tokenizer),\n", " max_new_tokens=10,\n", " context_size=GPT_CONFIG_124M[\"context_length\"]\n", ")\n", "\n", "print(\"Output text:\\n\", 
token_ids_to_text(token_ids, tokenizer))" ] }, { "cell_type": "markdown", "id": "e4d3249b-b2a0-44c4-b589-ae4b403b8305", "metadata": { "id": "e4d3249b-b2a0-44c4-b589-ae4b403b8305" }, "source": [ "- As we can see above, the model does not produce good text because it has not been trained yet\n", "- How do we measure what \"good text\" is in numeric form, so that we can track it during training?\n", "- The next subsection introduces a loss metric for the generated outputs that we can use to measure training progress" ] }, { "cell_type": "markdown", "id": "955f9e1a-7bf7-40d8-b1fa-eacabdee8d8e", "metadata": { "id": "955f9e1a-7bf7-40d8-b1fa-eacabdee8d8e" }, "source": [ "
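A side note on the greedy decoding step in `generate_text_simple`: within each row, softmax preserves the ordering of the logits. Taking the `argmax` over the probabilities therefore selects the same token ID as taking the `argmax` over the raw logits. The minimal sketch below checks this on random logits. These logits are a hypothetical assumption for illustration, not outputs of our `GPTModel`.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy logits with shape (batch, vocab_size); not produced by GPTModel
logits = torch.randn(2, 5)
probas = torch.softmax(logits, dim=-1)

# Greedy choice from the probabilities vs. from the raw logits
idx_from_probas = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)
idx_from_logits = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

# Softmax preserves the ordering within each row, so both give the same token IDs
print(torch.equal(idx_from_probas, idx_from_logits))  # True
print(idx_from_logits.shape)  # torch.Size([2, 1])
```

For greedy decoding, the softmax call in `generate_text_simple` does not change which token is chosen. It mainly makes the probability interpretation explicit, which is the view the loss computation in the next subsection builds on.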
" ] }, { "cell_type": "markdown", "id": "0f3d7ea2-637f-4490-bc76-e361fc81ae98", "metadata": { "id": "0f3d7ea2-637f-4490-bc76-e361fc81ae98" }, "source": [ "### Calculating the text generation loss: cross-entropy and perplexity" ] }, { "cell_type": "markdown", "id": "9e1ba8aa-fb03-4d25-957f-fe8778762440", "metadata": { "id": "9e1ba8aa-fb03-4d25-957f-fe8778762440" }, "source": [ "- Suppose we have an `inputs` tensor containing the token IDs for 2 training examples (rows)\n", "- Corresponding to the `inputs`, the `targets` contain the desired token IDs that we want the model to generate\n", "- Notice that the `targets` are the `inputs` shifted by 1 position" ] }, { "cell_type": "code", "execution_count": null, "id": "6b5402f8-ec0c-4a44-9892-18a97779ee4f", "metadata": { "id": "6b5402f8-ec0c-4a44-9892-18a97779ee4f" }, "outputs": [], "source": [ "inputs = torch.tensor([[16833, 3626, 6100], # [\"every effort moves\",\n", " [40, 1107, 588]]) # \"I really like\"]\n", "\n", "targets = torch.tensor([[3626, 6100, 345 ], # [\" effort moves you\",\n", " [1107, 588, 11311]]) # \" really like chocolate\"]" ] }, { "cell_type": "markdown", "id": "33dc0645-ac2c-4973-9b40-6da40515bede", "metadata": { "id": "33dc0645-ac2c-4973-9b40-6da40515bede" }, "source": [ "- Feeding the `inputs` to the model, we obtain the logits vector for the 2 input examples that consist of 3 tokens each\n", "- Each of the tokens is a 50,257-dimensional vector corresponding to the size of the vocabulary\n", "- Applying the softmax function, we can turn the logits tensor into a tensor of the same dimension containing probability scores" ] }, { "cell_type": "code", "execution_count": null, "id": "e7b6ec51-6f8c-49bd-a349-95ba38b46fb6", "metadata": { "id": "e7b6ec51-6f8c-49bd-a349-95ba38b46fb6" }, "outputs": [], "source": [ "with torch.no_grad():\n", " logits = model(inputs)\n", "\n", "probas = torch.softmax(logits, dim=-1) # Probability of each token in vocabulary\n", "print(probas.shape) # Shape: (batch_size, 
num_tokens, vocab_size)\n", "print(probas)" ] }, { "cell_type": "markdown", "id": "5c36a382-b5e2-4de6-9e65-0b69b685013b", "metadata": { "id": "5c36a382-b5e2-4de6-9e65-0b69b685013b" }, "source": [ "- The figure below, using a very small vocabulary for illustration purposes, outlines how we convert the probability scores back into text" ] }, { "cell_type": "markdown", "id": "384d86a9-0013-476c-bb6b-274fd5f20b29", "metadata": { "id": "384d86a9-0013-476c-bb6b-274fd5f20b29" }, "source": [ "" ] }, { "cell_type": "markdown", "id": "e8480efd-d419-4954-9ecc-2876055334bd", "metadata": { "id": "e8480efd-d419-4954-9ecc-2876055334bd" }, "source": [ "- We can apply the `argmax` function to convert the probability scores into predicted token IDs\n", "- The softmax function above produced a 50,257-dimensional vector for each token; the `argmax` function returns the position of the highest probability score in this vector, which is the predicted token ID for the given token" ] }, { "cell_type": "markdown", "id": "f3b84c9f-dd08-482e-b903-a86fe44e1144", "metadata": { "id": "f3b84c9f-dd08-482e-b903-a86fe44e1144" }, "source": [ "- Since we have 2 input examples with 3 tokens each, we obtain 2 by 3 predicted token IDs (`keepdim=True` keeps the reduced vocabulary dimension with size 1):" ] }, { "cell_type": "code", "execution_count": null, "id": "34ebd76a-16ec-4c17-8958-8a135735cc1c", "metadata": { "id": "34ebd76a-16ec-4c17-8958-8a135735cc1c" }, "outputs": [], "source": [ "token_ids = torch.argmax(probas, dim=-1, keepdim=True)\n", "print(\"Token IDs:\\n\", token_ids)" ] }, { "cell_type": "markdown", "id": "cee4072c-21ed-4df7-8721-dd2535362573", "metadata": { "id": "cee4072c-21ed-4df7-8721-dd2535362573" }, "source": [ "- If we decode these tokens, we find that they are quite different from the tokens we want the model to predict, namely the target tokens:" ] }, { "cell_type": "code", "execution_count": null, "id": "c990ead6-53cd-49a7-a6d1-14d8c1518249", "metadata": { "id": "c990ead6-53cd-49a7-a6d1-14d8c1518249" }, "outputs": [], "source": [ 
"print(f\"Targets batch 1: {token_ids_to_text(targets[0], tokenizer)}\")\n", "print(f\"Outputs batch 1: {token_ids_to_text(token_ids[0].flatten(), tokenizer)}\")" ] }, { "cell_type": "markdown", "id": "a53eb8a7-070e-46d6-930c-314ba55a6ff2", "metadata": { "id": "a53eb8a7-070e-46d6-930c-314ba55a6ff2" }, "source": [ "- That's because the model wasn't trained yet\n", "- To train the model, we need to know how far it is away from the correct predictions (targets)" ] }, { "cell_type": "markdown", "id": "ad90592f-0d5d-4ec8-9ff5-e7675beab10e", "metadata": { "id": "ad90592f-0d5d-4ec8-9ff5-e7675beab10e" }, "source": [ "" ] }, { "cell_type": "markdown", "id": "c7251bf5-a079-4782-901d-68c9225d3157", "metadata": { "id": "c7251bf5-a079-4782-901d-68c9225d3157" }, "source": [ "- The token probabilities corresponding to the target indices are as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "54aef09c-d6e3-4238-8653-b3a1b0a1077a", "metadata": { "id": "54aef09c-d6e3-4238-8653-b3a1b0a1077a" }, "outputs": [], "source": [ "text_idx = 0\n", "target_probas_1 = probas[text_idx, [0, 1, 2], targets[text_idx]]\n", "print(\"Text 1:\", target_probas_1)\n", "\n", "text_idx = 1\n", "target_probas_2 = probas[text_idx, [0, 1, 2], targets[text_idx]]\n", "print(\"Text 2:\", target_probas_2)" ] }, { "cell_type": "markdown", "id": "a0e89a19-73c2-4e49-93b4-861f699f1cbf", "metadata": { "id": "a0e89a19-73c2-4e49-93b4-861f699f1cbf" }, "source": [ "- We want to maximize all these values, bringing them close to a probability of 1\n", "- In mathematical optimization, it is easier to maximize the logarithm of the probability score than the probability score itself; this is out of the scope of this book, but I have recorded a lecture with more details here: [L8.2 Logistic Regression Loss Function](https://www.youtube.com/watch?v=GxJe0DZvydM)" ] }, { "cell_type": "code", "execution_count": null, "id": "31402a67-a16e-4aeb-977e-70abb9c9949b", "metadata": { "id": 
"31402a67-a16e-4aeb-977e-70abb9c9949b" }, "outputs": [], "source": [ "# Compute logarithm of all token probabilities\n", "log_probas = torch.log(torch.cat((target_probas_1, target_probas_2)))\n", "print(log_probas)" ] }, { "cell_type": "markdown", "id": "c4261441-a511-4633-9c4c-67998af31b84", "metadata": { "id": "c4261441-a511-4633-9c4c-67998af31b84" }, "source": [ "- Next, we compute the average log probability:" ] }, { "cell_type": "code", "execution_count": null, "id": "9b003797-161b-4d98-81dc-e68320e09fec", "metadata": { "id": "9b003797-161b-4d98-81dc-e68320e09fec" }, "outputs": [], "source": [ "# Calculate the average probability for each token\n", "avg_log_probas = torch.mean(log_probas)\n", "print(avg_log_probas)" ] }, { "cell_type": "markdown", "id": "36d51994-ad17-4ba3-a6ec-f588b4b13585", "metadata": { "id": "36d51994-ad17-4ba3-a6ec-f588b4b13585" }, "source": [ "- The goal is to make this average log probability as large as possible by optimizing the model weights\n", "- Due to the log, the largest possible value is 0, and we are currently far away from 0" ] }, { "cell_type": "markdown", "id": "3de388a1-8a0a-4c94-8894-9041dc6ad514", "metadata": { "id": "3de388a1-8a0a-4c94-8894-9041dc6ad514" }, "source": [ "- In deep learning, instead of maximizing the average log-probability, it's a standard convention to minimize the *negative* average log-probability value; in our case, instead of maximizing -10.7722 so that it approaches 0, in deep learning, we would minimize 10.7722 so that it approaches 0\n", "- The value negative of -10.7722, i.e., 10.7722, is also called cross-entropy loss in deep learning" ] }, { "cell_type": "code", "execution_count": null, "id": "176ddf35-1c5f-4d7c-bf17-70f3e7069bd4", "metadata": { "id": "176ddf35-1c5f-4d7c-bf17-70f3e7069bd4" }, "outputs": [], "source": [ "neg_avg_log_probas = avg_log_probas * -1\n", "print(neg_avg_log_probas)" ] }, { "cell_type": "markdown", "id": "84eeb868-abd8-4028-82db-107546bf7c2c", "metadata": { "id": 
"84eeb868-abd8-4028-82db-107546bf7c2c" }, "source": [ "- PyTorch already implements a `cross_entropy` function that carries out the previous steps" ] }, { "cell_type": "markdown", "id": "5bd24b7f-b760-47ad-bc84-86d13794aa54", "metadata": { "id": "5bd24b7f-b760-47ad-bc84-86d13794aa54" }, "source": [ "" ] }, { "cell_type": "markdown", "id": "e8aaf9dd-3ee6-42bf-a63f-6e93dbfb989d", "metadata": { "id": "e8aaf9dd-3ee6-42bf-a63f-6e93dbfb989d" }, "source": [ "- Before we apply the `cross_entropy` function, let's check the shape of the logits and targets" ] }, { "cell_type": "code", "execution_count": null, "id": "695d6f64-5084-4c23-aea4-105c9e38cfe4", "metadata": { "id": "695d6f64-5084-4c23-aea4-105c9e38cfe4" }, "outputs": [], "source": [ "# Logits have shape (batch_size, num_tokens, vocab_size)\n", "print(\"Logits shape:\", logits.shape)\n", "\n", "# Targets have shape (batch_size, num_tokens)\n", "print(\"Targets shape:\", targets.shape)" ] }, { "cell_type": "markdown", "id": "1d3d65f0-6566-4865-93e4-0c0bcb10cd06", "metadata": { "id": "1d3d65f0-6566-4865-93e4-0c0bcb10cd06" }, "source": [ "- For the `cross_entropy` function in PyTorch, we want to flatten these tensors by combining them over the batch dimension:" ] }, { "cell_type": "code", "execution_count": null, "id": "0e17e027-ab9f-4fb5-ac9b-a009b831c122", "metadata": { "id": "0e17e027-ab9f-4fb5-ac9b-a009b831c122" }, "outputs": [], "source": [ "logits_flat = logits.flatten(0, 1)\n", "targets_flat = targets.flatten()\n", "\n", "print(\"Flattened logits:\", logits_flat.shape)\n", "print(\"Flattened targets:\", targets_flat.shape)" ] }, { "cell_type": "markdown", "id": "4921a57f-3a79-473e-a863-6d63b495010f", "metadata": { "id": "4921a57f-3a79-473e-a863-6d63b495010f" }, "source": [ "- Note that the targets are the token IDs, which also represent the index positions in the logits tensors that we want to maximize\n", "- The `cross_entropy` function in PyTorch will automatically take care of applying the softmax and 
log-probability computation internally over those token indices in the logits that are to be maximized" ] }, { "cell_type": "code", "execution_count": null, "id": "62d0816e-b29a-4c8f-a9a5-a167562de978", "metadata": { "id": "62d0816e-b29a-4c8f-a9a5-a167562de978" }, "outputs": [], "source": [ "loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)\n", "print(loss)" ] }, { "cell_type": "markdown", "id": "0f15ce17-fd7b-4d8e-99da-b237523a7a80", "metadata": { "id": "0f15ce17-fd7b-4d8e-99da-b237523a7a80" }, "source": [ "- A concept related to the cross-entropy loss is the perplexity of an LLM\n", "- The perplexity is simply the exponential of the cross-entropy loss" ] }, { "cell_type": "code", "execution_count": null, "id": "168952a1-b964-4aa7-8e49-966fa26add54", "metadata": { "id": "168952a1-b964-4aa7-8e49-966fa26add54" }, "outputs": [], "source": [ "perplexity = torch.exp(loss)\n", "print(perplexity)" ] }, { "cell_type": "markdown", "id": "71ae26dd-d77e-41fd-b924-6bd103dd4ee7", "metadata": { "id": "71ae26dd-d77e-41fd-b924-6bd103dd4ee7" }, "source": [ "- The perplexity is often considered more interpretable because it can be understood as the effective vocabulary size that the model is uncertain about at each step (in the example above, that'd be 48,725 words or tokens)\n", "- In other words, perplexity provides a measure of how well the probability distribution predicted by the model matches the actual distribution of the words in the dataset\n", "- Similar to the loss, a lower perplexity indicates that the model predictions are closer to the actual distribution" ] }, { "cell_type": "markdown", "id": "2ec6c217-e429-40c7-ad71-5d0a9da8e487", "metadata": { "id": "2ec6c217-e429-40c7-ad71-5d0a9da8e487" }, "source": [ "### Calculating the training and validation set losses" ] }, { "cell_type": "markdown", "id": "530da89e-2448-436c-8f1b-28e8a31ef85c", "metadata": { "id": "530da89e-2448-436c-8f1b-28e8a31ef85c" }, "source": [ "- We use a relatively small dataset for 
training the LLM (in fact, only one short story)\n", "- The reasons are:\n", "  - You can run the code examples in a few minutes on a laptop computer without a suitable GPU\n", "  - The training finishes relatively fast (minutes instead of weeks), which is good for educational purposes\n", "  - We use a text from the public domain, which can be included in this GitHub repository without violating any usage rights or bloating the repository size\n", "\n", "\n", "- For example, Llama 2 7B required 184,320 GPU hours on A100 GPUs to be trained on 2 trillion tokens\n", "  - At the time of this writing, the hourly cost of an 8xA100 cloud server at AWS is approximately \\\\$30\n", "  - So, via a back-of-the-envelope calculation, training this LLM would cost 184,320 / 8 * \\\\$30 ≈ \\\\$690,000\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff", "metadata": { "id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff" }, "outputs": [], "source": [ "import os\n", "import urllib.request\n", "\n", "file_path = \"the-verdict.txt\"\n", "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt\"\n", "\n", "if not os.path.exists(file_path):\n", "    with urllib.request.urlopen(url) as response:\n", "        text_data = response.read().decode('utf-8')\n", "    with open(file_path, \"w\", encoding=\"utf-8\") as file:\n", "        file.write(text_data)\n", "else:\n", "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n", "        text_data = file.read()" ] },
{ "cell_type": "markdown", "id": "379330f1-80f4-4e34-8724-41d892b04cee", "metadata": { "id": "379330f1-80f4-4e34-8724-41d892b04cee" }, "source": [ "- A quick check that the text loaded correctly by printing its first and last 100 characters" ] },
{ "cell_type": "code", "execution_count": null, "id": "6kgJbe4ehI4q", "metadata": { "id": "6kgJbe4ehI4q" }, "outputs": [], "source": [ "# First 100 characters\n", "print(text_data[:100])" ] },
{ "cell_type": "code", "execution_count": 
null, "id": "j2XPde_ThM_e", "metadata": { "id": "j2XPde_ThM_e" }, "outputs": [], "source": [ "# Last 100 characters\n", "print(text_data[-100:])" ] },
{ "cell_type": "code", "execution_count": null, "id": "6b46a952-d50a-4837-af09-4095698f7fd1", "metadata": { "id": "6b46a952-d50a-4837-af09-4095698f7fd1" }, "outputs": [], "source": [ "total_characters = len(text_data)\n", "total_tokens = len(tokenizer.encode(text_data))\n", "\n", "print(\"Characters:\", total_characters)\n", "print(\"Tokens:\", total_tokens)" ] },
{ "cell_type": "markdown", "id": "a8830cb9-90f6-4e7c-8620-beeabc2d39f7", "metadata": { "id": "a8830cb9-90f6-4e7c-8620-beeabc2d39f7" }, "source": [ "- With 5,145 tokens, the text is very short for training an LLM, but again, it's for educational purposes (we will also load pretrained weights later)" ] },
{ "cell_type": "markdown", "id": "bedcad87-a0e8-4b9d-ac43-4e927ccbb50f", "metadata": { "id": "bedcad87-a0e8-4b9d-ac43-4e927ccbb50f" }, "source": [ "- Next, we divide the dataset into a training and a validation set and use the data loaders to prepare the batches for LLM training\n", "- For visualization purposes, the figure below assumes a `max_length=6`, but for the training loader, we set the `max_length` equal to the context length that the LLM supports\n", "- The figure below only shows the input tokens for simplicity\n", "  - Since we train the LLM to predict the next word in the text, the targets look the same as these inputs, except that the targets are shifted by one position" ] },
{ "cell_type": "markdown", "id": "46bdaa07-ba96-4ac1-9d71-b3cc153910d9", "metadata": { "id": "46bdaa07-ba96-4ac1-9d71-b3cc153910d9" }, "source": [ "" ] },
{ "cell_type": "code", "execution_count": null, "id": "0959c855-f860-4358-8b98-bc654f047578", "metadata": { "id": "0959c855-f860-4358-8b98-bc654f047578" }, "outputs": [], "source": [ "import tiktoken\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "def create_dataloader_v1(txt, batch_size=4, max_length=256,\n", "                         stride=128, shuffle=True, drop_last=True, num_workers=0):\n", "    # Initialize the tokenizer\n", "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n", "\n", "    # Create dataset\n", "    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n", "\n", "    # Create dataloader\n", "    dataloader = DataLoader(\n", "        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)\n", "\n", "    return dataloader\n", "\n", "class GPTDatasetV1(Dataset):\n", "    def __init__(self, txt, tokenizer, max_length, stride):\n", "        self.input_ids = []\n", "        self.target_ids = []\n", "\n", "        # Tokenize the entire text\n", "        token_ids = tokenizer.encode(txt, allowed_special={\"<|endoftext|>\"})\n", "\n", "        # Use a sliding window to chunk the book into overlapping sequences of max_length\n", "        for i in range(0, len(token_ids) - max_length, stride):\n", "            input_chunk = token_ids[i:i + max_length]\n", "            target_chunk = token_ids[i + 1: i + max_length + 1]\n", "            self.input_ids.append(torch.tensor(input_chunk))\n", "            self.target_ids.append(torch.tensor(target_chunk))\n", "\n", "    def __len__(self):\n", "        return len(self.input_ids)\n", "\n", "    def __getitem__(self, idx):\n", "        return self.input_ids[idx], self.target_ids[idx]\n", "\n", "\n", "# Train/validation ratio\n", "train_ratio = 0.90\n", "split_idx = int(train_ratio * len(text_data))\n", "train_data = text_data[:split_idx]\n", "val_data = text_data[split_idx:]\n", "\n", "\n", "torch.manual_seed(123)\n", "\n", "train_loader = create_dataloader_v1(\n", "    train_data,\n", "    batch_size=2,\n", "    max_length=GPT_CONFIG_124M[\"context_length\"],\n", "    stride=GPT_CONFIG_124M[\"context_length\"],\n", "    drop_last=True,\n", "    shuffle=True,\n", "    num_workers=0\n", ")\n", "\n", "val_loader = create_dataloader_v1(\n", "    val_data,\n", "    batch_size=2,\n", "    max_length=GPT_CONFIG_124M[\"context_length\"],\n", "    stride=GPT_CONFIG_124M[\"context_length\"],\n", "    drop_last=False,\n", "    shuffle=False,\n", "    num_workers=0\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "id": "f37b3eb0-854e-4895-9898-fa7d1e67566e", "metadata": { "id": "f37b3eb0-854e-4895-9898-fa7d1e67566e" }, "outputs": [], "source": [ "# Sanity check\n", "\n", "if total_tokens * (train_ratio) < GPT_CONFIG_124M[\"context_length\"]:\n", "    print(\"Not enough tokens for the training loader. \"\n", "          \"Try to lower the `GPT_CONFIG_124M['context_length']` or \"\n", "          \"increase the `train_ratio`\")\n", "\n", "if total_tokens * (1-train_ratio) < GPT_CONFIG_124M[\"context_length\"]:\n", "    print(\"Not enough tokens for the validation loader. \"\n", "          \"Try to lower the `GPT_CONFIG_124M['context_length']` or \"\n", "          \"decrease the `train_ratio`\")" ] },
{ "cell_type": "markdown", "id": "e7ac3296-a4d1-4303-9ac5-376518960c33", "metadata": { "id": "e7ac3296-a4d1-4303-9ac5-376518960c33" }, "source": [ "- We use a relatively small batch size to reduce the computational resource demand, and because the dataset is very small to begin with\n", "- Llama 2 7B was trained with a batch size of 1024, for example" ] },
{ "cell_type": "markdown", "id": "a8e0514d-b990-4dc0-9afb-7721993284a0", "metadata": { "id": "a8e0514d-b990-4dc0-9afb-7721993284a0" }, "source": [ "- An optional check that the data was loaded correctly:" ] },
{ "cell_type": "code", "execution_count": null, "id": "ca0116d0-d229-472c-9fbf-ebc229331c3e", "metadata": { "id": "ca0116d0-d229-472c-9fbf-ebc229331c3e" }, "outputs": [], "source": [ "print(\"Train loader:\")\n", "for x, y in train_loader:\n", "    print(x.shape, y.shape)\n", "\n", "print(\"\\nValidation loader:\")\n", "for x, y in val_loader:\n", "    print(x.shape, y.shape)" ] },
{ "cell_type": "markdown", "id": "f7b9b1a4-863d-456f-a8dd-c07fb5c024ed", "metadata": { "id": "f7b9b1a4-863d-456f-a8dd-c07fb5c024ed" }, "source": [ "- Another optional check that the token counts are in the expected ballpark:" ] },
{ "cell_type": "code", "execution_count": null, "id": "eb860488-5453-41d7-9870-23b723f742a0", "metadata": { "id": "eb860488-5453-41d7-9870-23b723f742a0" }, "outputs": [], "source": [ "train_tokens = 0\n", "for input_batch, target_batch in train_loader:\n", "    train_tokens += input_batch.numel()\n", "\n", "val_tokens = 0\n", "for input_batch, target_batch in val_loader:\n", "    val_tokens += input_batch.numel()\n", "\n", "print(\"Training tokens:\", train_tokens)\n", "print(\"Validation tokens:\", val_tokens)\n", "print(\"All tokens:\", train_tokens + val_tokens)" ] },
{ "cell_type": "markdown", "id": "5c3085e8-665e-48eb-bb41-cdde61537e06", "metadata": { "id": "5c3085e8-665e-48eb-bb41-cdde61537e06" }, "source": [ "- Next, we implement a utility function to calculate the cross-entropy loss of a given batch\n", "- In addition, we implement a second utility function to compute the loss for a user-specified number of batches in a data loader" ] },
{ "cell_type": "code", "execution_count": null, "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc", "metadata": { "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc" }, "outputs": [], "source": [ "def calc_loss_batch(input_batch, target_batch, model, device):\n", "    input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", "    logits = model(input_batch)\n", "    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())\n", "    return loss\n", "\n", "\n", "def calc_loss_loader(data_loader, model, device, num_batches=None):\n", "    total_loss = 0.\n", "    if len(data_loader) == 0:\n", "        return float(\"nan\")\n", "    elif num_batches is None:\n", "        num_batches = len(data_loader)\n", "    else:\n", "        # Reduce the number of batches to match the total number of batches in the data loader\n", "        # if num_batches exceeds the number of batches in the data loader\n", "        num_batches = min(num_batches, len(data_loader))\n", "    for i, (input_batch, target_batch) in enumerate(data_loader):\n", "        if i < num_batches:\n", "            loss = calc_loss_batch(input_batch, target_batch, model, device)\n", "            total_loss += loss.item()\n", "        else:\n", "            break\n", "    return total_loss / num_batches" ] },
{ "cell_type": "code", "execution_count": null, "id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a", "metadata": { "id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a" }, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "\n", "model.to(device)  # no need to reassign (model = model.to(device)) for nn.Module objects\n", "\n", "\n", "torch.manual_seed(123)  # For reproducibility due to the shuffling in the data loader\n", "\n", "with torch.no_grad():  # Disable gradient tracking for efficiency because we are not training yet\n", "    train_loss = calc_loss_loader(train_loader, model, device)\n", "    val_loss = calc_loss_loader(val_loader, model, device)\n", "\n", "print(\"Training loss:\", train_loss)\n", "print(\"Validation loss:\", val_loss)" ] },
{ "cell_type": "markdown", "id": "43875e95-190f-4b17-8f9a-35034ba649ec", "metadata": { "id": "43875e95-190f-4b17-8f9a-35034ba649ec" }, "source": [ "" ] },
{ "cell_type": "markdown", "id": "b9339f8d-00cb-4206-af67-58c32bd72055", "metadata": { "id": "b9339f8d-00cb-4206-af67-58c32bd72055" }, "source": [ "## Training an LLM" ] },
{ "cell_type": "markdown", "id": "652a4cf4-e98f-46d9-bdec-60e7ccb8d6bd", "metadata": { "id": "652a4cf4-e98f-46d9-bdec-60e7ccb8d6bd" }, "source": [ "- In this section, we finally implement the code for training the LLM" ] },
{ "cell_type": "code", "execution_count": null, "id": "Mtp4gY0ZO-qq", "metadata": { "id": "Mtp4gY0ZO-qq" }, "outputs": [], "source": [ "def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,\n", "                       eval_freq, eval_iter, start_context, tokenizer):\n", "    # Initialize lists to track losses and tokens seen\n", "    train_losses, val_losses, track_tokens_seen = [], [], []\n", "    tokens_seen, global_step = 0, -1\n", "\n", "    # Main training loop\n", "    for epoch in range(num_epochs):\n", "        model.train()  # Set model to training mode\n", "\n", "        for input_batch, target_batch in train_loader:\n", "            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration\n", "            loss = calc_loss_batch(input_batch, target_batch, model, device)\n", "            loss.backward()  # Calculate loss gradients\n", "            optimizer.step()  # Update model weights using loss gradients\n", "            tokens_seen += input_batch.numel()\n", "            global_step += 1\n", "\n", "            # Optional evaluation step\n", "            if global_step % eval_freq == 0:\n", "                train_loss, val_loss = evaluate_model(\n", "                    model, train_loader, val_loader, device, eval_iter)\n", "                train_losses.append(train_loss)\n", "                val_losses.append(val_loss)\n", "                track_tokens_seen.append(tokens_seen)\n", "                print(f\"Ep {epoch+1} (Step {global_step:06d}): \"\n", "                      f\"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}\")\n", "\n", "        # Print a sample text after each epoch\n", "        generate_and_print_sample(\n", "            model, tokenizer, device, start_context\n", "        )\n", "\n", "    return train_losses, val_losses, track_tokens_seen\n", "\n", "\n", "def evaluate_model(model, train_loader, val_loader, device, eval_iter):\n", "    model.eval()\n", "    with torch.no_grad():\n", "        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)\n", "        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)\n", "    model.train()\n", "    return train_loss, val_loss\n", "\n", "\n", "def generate_and_print_sample(model, tokenizer, device, start_context):\n", "    model.eval()\n", "    context_size = model.pos_emb.weight.shape[0]\n", "    encoded = text_to_token_ids(start_context, tokenizer).to(device)\n", "    with torch.no_grad():\n", "        token_ids = generate_text_simple(\n", "            model=model, idx=encoded,\n", "            max_new_tokens=50, context_size=context_size\n", "        )\n", "    decoded_text = token_ids_to_text(token_ids, tokenizer)\n", "    print(decoded_text.replace(\"\\n\", \" \"))  # Compact print format\n", "    model.train()" ] },
{ "cell_type": "markdown", "id": "a301b333-b9d4-4eeb-a212-3a9874e3ac47", "metadata": { "id": "a301b333-b9d4-4eeb-a212-3a9874e3ac47" }, "source": [ "- Now, let's train the LLM using the training function defined above:" ] },
{ "cell_type": "code", "execution_count": null, "id": "3422000b-7aa2-485b-92df-99372cd22311", "metadata": { "id": "3422000b-7aa2-485b-92df-99372cd22311" }, "outputs": [], "source": [ "# Measure the execution time of the training run\n", "import time\n", "start_time = time.time()\n", "\n", "torch.manual_seed(123)\n", "model = GPTModel(GPT_CONFIG_124M)\n", "model.to(device)\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)\n", "\n", "num_epochs = 25\n", "train_losses, val_losses, tokens_seen = train_model_simple(\n", "    model, train_loader, val_loader, optimizer, device,\n", "    num_epochs=num_epochs, eval_freq=5, eval_iter=5,\n", "    start_context=\"Every effort moves you\", tokenizer=tokenizer\n", ")\n", "\n", "# Report the execution time\n", "end_time = time.time()\n", "execution_time_minutes = (end_time - start_time) / 60\n", "print(f\"Training completed in {execution_time_minutes:.2f} minutes.\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "0WSRu2i0iHJE", "metadata": { "id": "0WSRu2i0iHJE" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib.ticker import MaxNLocator\n", "\n", "\n", "def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):\n", " fig, ax1 = plt.subplots(figsize=(5, 3))\n", "\n", " # Plot training and validation loss against epochs\n", " ax1.plot(epochs_seen, train_losses, label=\"Training loss\")\n", " ax1.plot(epochs_seen, val_losses, linestyle=\"-.\", label=\"Validation loss\")\n", " ax1.set_xlabel(\"Epochs\")\n", " ax1.set_ylabel(\"Loss\")\n", " ax1.legend(loc=\"upper right\")\n", " ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) # only show integer labels on x-axis\n", "\n", " # Create a second x-axis for tokens seen\n", " ax2 = ax1.twiny() # Create
a second x-axis that shares the same y-axis\n", " ax2.plot(tokens_seen, train_losses, alpha=0) # Invisible plot for aligning ticks\n", " ax2.set_xlabel(\"Tokens seen\")\n", "\n", " fig.tight_layout() # Adjust layout to make room\n", " plt.savefig(\"loss-plot.pdf\")\n", " plt.show()\n", "\n", "epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n", "plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)" ] },
{ "cell_type": "markdown", "id": "8bc83ded-5f80-4e1c-bf4d-ccb59999d995", "metadata": { "id": "8bc83ded-5f80-4e1c-bf4d-ccb59999d995" }, "source": [ "- Looking at the results above, we can see that the model starts out generating incomprehensible strings of words, whereas towards the end, it's able to produce grammatically more or less correct sentences\n", "- However, based on the training and validation set losses, we can see that the model starts overfitting\n", "- If we were to check a few passages it writes towards the end, we would find that they are contained in the training set verbatim; it simply memorizes the training data\n", "- Later, we will cover decoding strategies that can mitigate this memorization to a certain degree\n", "- Note that the overfitting here occurs because we have a very, very small training set, and we iterate over it so many times\n", "  - The LLM training here primarily serves educational purposes; we mainly want to see that the model can learn to produce coherent text\n", "  - Instead of spending weeks or months on training this model on vast amounts of expensive hardware, we load pretrained weights later" ] },
{ "cell_type": "markdown", "id": "eb380c42-b31c-4ee1-b8b9-244094537272", "metadata": { "id": "eb380c42-b31c-4ee1-b8b9-244094537272" }, "source": [ "" ] },
{ "cell_type": "markdown", "id": "699f45fc-bf78-42f2-bd24-2355db41b28f", "metadata": { "id": "699f45fc-bf78-42f2-bd24-2355db41b28f" }, "source": [ "## Decoding strategies to control randomness" ] },
{ "cell_type": "markdown", "id": "6be9086e-2c27-41da-97d0-49137d0ba3c7", "metadata": { "id": "6be9086e-2c27-41da-97d0-49137d0ba3c7" }, "source": [ "- Inference with a relatively small LLM such as the GPT model we trained above is cheap, so we can run it on the CPU even if a GPU was used for training\n", "- Using the `generate_text_simple` function that we used earlier inside the simple training function, we can generate new text one word (or token) at a time\n", "- The next generated token is the token corresponding to the largest probability score among all tokens in the vocabulary" ] },
{ "cell_type": "code", "execution_count": null, "id": "2734cee0-f6f9-42d5-b71c-fa7e0ef28b6d", "metadata": { "id": "2734cee0-f6f9-42d5-b71c-fa7e0ef28b6d" }, "outputs": [], "source": [ "model.to(\"cpu\")\n", "model.eval()\n", "\n", "tokenizer = tiktoken.get_encoding(\"gpt2\")\n", "\n", "token_ids = generate_text_simple(\n", "    model=model,\n", "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n", "    max_new_tokens=25,\n", "    context_size=GPT_CONFIG_124M[\"context_length\"]\n", ")\n", "\n", "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" ] },
{ "cell_type": "markdown", "id": "d25dbe31-bb7c-4893-b25b-47d0492d4aa4", "metadata": { "id": "d25dbe31-bb7c-4893-b25b-47d0492d4aa4" }, "source": [ "- Even if we execute the `generate_text_simple` function above multiple times, the LLM will always generate the same outputs\n", "- We now introduce two concepts, so-called decoding strategies, to modify `generate_text_simple`: *temperature scaling* and *top-k* sampling\n", "- These allow us to control the randomness and diversity of the generated text" ] },
{ "cell_type": "markdown", "id": "4bb6f380-a798-4fd9-825c-17b7cd29a994", "metadata": { "id": "4bb6f380-a798-4fd9-825c-17b7cd29a994" }, "source": [ "### Temperature scaling" ] },
{ "cell_type": "markdown", "id": "a7f4f53c-0612-43d3-aa82-52447eac50fa", "metadata": { "id": "a7f4f53c-0612-43d3-aa82-52447eac50fa" }, "source": [ "- Previously, we always selected the token with the highest probability as the next token using `torch.argmax`\n", "- To add variety, we can instead sample the next token from the probability distribution using `torch.multinomial(probs, num_samples=1)`\n", "- Here, each index's chance of being picked corresponds to its probability in the input tensor" ] },
{ "cell_type": "markdown", "id": "e7531bae-d5de-44c0-bc78-78fed077e22a", "metadata": { "id": "e7531bae-d5de-44c0-bc78-78fed077e22a" }, "source": [ "- Here's a little recap of generating the next token, assuming a very small vocabulary for illustration purposes:" ] },
{ "cell_type": "code", "execution_count": null, "id": "01a5ce39-3dc8-4c35-96bc-6410a1e42412", "metadata": { "id": "01a5ce39-3dc8-4c35-96bc-6410a1e42412" }, "outputs": [], "source": [ "vocab = {\n", "    \"closer\": 0,\n", "    \"every\": 1,\n", "    \"effort\": 2,\n", "    \"forward\": 3,\n", "    \"inches\": 4,\n", "    \"moves\": 5,\n", "    \"pizza\": 6,\n", "    \"toward\": 7,\n", "    \"you\": 8,\n", "}\n", "\n", "inverse_vocab = {v: k for k, v in vocab.items()}\n", "\n", "# Suppose input is \"every effort moves you\", and the LLM\n", "# returns the following logits for the next token:\n", "next_token_logits = torch.tensor(\n", "    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]\n", ")\n", "\n", "probas = torch.softmax(next_token_logits, dim=0)\n", "next_token_id = torch.argmax(probas).item()\n", "\n", "# The next generated token is then as follows:\n", "print(inverse_vocab[next_token_id])" ] },
{ "cell_type": "code", "execution_count": null, "id": "6400572f-b3c8-49e2-95bc-433e55c5b3a1", "metadata": { "id": "6400572f-b3c8-49e2-95bc-433e55c5b3a1" }, "outputs": [], "source": [ "torch.manual_seed(123)\n", "next_token_id = torch.multinomial(probas, num_samples=1).item()\n", "print(inverse_vocab[next_token_id])" ] },
{ "cell_type": "markdown", "id": "c63d0a27-830b-42b5-9986-6d1a7de04dd9", "metadata": { "id": "c63d0a27-830b-42b5-9986-6d1a7de04dd9" }, "source": [ "- Instead of determining the most likely token via `torch.argmax`, we use `torch.multinomial(probas, num_samples=1)` to sample the next token from the softmax distribution\n", "- For illustration purposes, let's see what happens when we sample the next token 1,000 times using the original softmax probabilities:" ] },
{ "cell_type": "code", "execution_count": null, "id": "b23b863e-252a-403c-b5b1-62bc0a42319f", "metadata": { "id": "b23b863e-252a-403c-b5b1-62bc0a42319f" }, "outputs": [], "source": [ "def print_sampled_tokens(probas):\n", "    torch.manual_seed(123)  # Manual seed for reproducibility\n", "    sample = [torch.multinomial(probas, num_samples=1).item() for i in range(1_000)]\n", "    sampled_ids = torch.bincount(torch.tensor(sample))\n", "    for i, freq in enumerate(sampled_ids):\n", "        print(f\"{freq} x {inverse_vocab[i]}\")\n", "\n", "print_sampled_tokens(probas)" ] },
{ "cell_type": "markdown", "id": "32e7d9cf-a26d-4d9a-8664-4af1efa73832", "metadata": { "id": "32e7d9cf-a26d-4d9a-8664-4af1efa73832" }, "source": [ "- We can control the distribution and selection process via a concept called temperature scaling\n", "- \"Temperature scaling\" is just a fancy word for dividing the logits by a number greater than 0\n", "- Temperatures greater than 1 will result in more uniformly distributed token probabilities after applying the softmax\n", "- Temperatures smaller than 1 will result in more confident (sharper or more peaky) distributions after applying the softmax" ] },
{ "cell_type": "code", "execution_count": null, "id": "0759e4c8-5362-467c-bec6-b0a19d1ba43d", "metadata": { "id": "0759e4c8-5362-467c-bec6-b0a19d1ba43d" }, "outputs": [], "source": [ "def softmax_with_temperature(logits, temperature):\n", "    scaled_logits = logits / temperature\n", "    return torch.softmax(scaled_logits, dim=0)\n", "\n", "# Temperature values\n", "temperatures = [1, 0.1, 5]  # Original, higher confidence, and lower confidence\n", "\n", "# Calculate scaled probabilities\n", 
"scaled_probas = [softmax_with_temperature(next_token_logits, T) for T in temperatures]" ] }, { "cell_type": "code", "execution_count": null, "id": "2e66e613-4aca-4296-a984-ddd0d80c6578", "metadata": { "id": "2e66e613-4aca-4296-a984-ddd0d80c6578" }, "outputs": [], "source": [ "# Plotting\n", "x = torch.arange(len(vocab))\n", "bar_width = 0.15\n", "\n", "fig, ax = plt.subplots(figsize=(5, 3))\n", "for i, T in enumerate(temperatures):\n", " rects = ax.bar(x + i * bar_width, scaled_probas[i], bar_width, label=f'Temperature = {T}')\n", "\n", "ax.set_ylabel('Probability')\n", "ax.set_xticks(x)\n", "ax.set_xticklabels(vocab.keys(), rotation=90)\n", "ax.legend()\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"temperature-plot.pdf\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "d750e989-842a-4cfa-a44b-cf44d6e49163", "metadata": { "id": "d750e989-842a-4cfa-a44b-cf44d6e49163" }, "source": [ "- We can see that the rescaling via temperature 0.1 results in a sharper distribution, approaching `torch.argmax`, such that the most likely word is almost always selected:" ] }, { "cell_type": "code", "execution_count": null, "id": "e4600713-c51e-4f53-bf58-040a6eb362b8", "metadata": { "id": "e4600713-c51e-4f53-bf58-040a6eb362b8" }, "outputs": [], "source": [ "print_sampled_tokens(scaled_probas[1])" ] }, { "cell_type": "markdown", "id": "526e93cb-8e2a-42a1-b1ba-4fd5fe64c26b", "metadata": { "id": "526e93cb-8e2a-42a1-b1ba-4fd5fe64c26b" }, "source": [ "- The rescaled probabilities via temperature 5 are more uniformly distributed:" ] }, { "cell_type": "code", "execution_count": null, "id": "9dfb48f0-bc3f-46a5-9844-33b6c9b0f4df", "metadata": { "id": "9dfb48f0-bc3f-46a5-9844-33b6c9b0f4df" }, "outputs": [], "source": [ "print_sampled_tokens(scaled_probas[2])" ] }, { "cell_type": "markdown", "id": "0c83f0c4-3774-4375-ad7f-96440ba5fef7", "metadata": { "id": "0c83f0c4-3774-4375-ad7f-96440ba5fef7" }, "source": [ "- Assuming an LLM input \"every effort moves you\", using the approach 
above can sometimes result in nonsensical texts, such as \"every effort moves you pizza\", 3.2% of the time (32 out of 1000 times)" ] }, { "cell_type": "markdown", "id": "c6e4873e-07e4-4abb-85df-bdaedcc1a6f7", "metadata": { "id": "c6e4873e-07e4-4abb-85df-bdaedcc1a6f7" }, "source": [ "### Top-k sampling" ] }, { "cell_type": "markdown", "id": "6d4da95a-8bb2-4f69-a9b0-a643531db5df", "metadata": { "id": "6d4da95a-8bb2-4f69-a9b0-a643531db5df" }, "source": [ "- To be able to use higher temperatures to increase output diversity and to reduce the probability of nonsensical sentences, we can restrict the sampled tokens to the top-k most likely tokens:" ] }, { "cell_type": "markdown", "id": "7ae6fffd-2730-4abe-a2d3-781fc4836f17", "metadata": { "id": "7ae6fffd-2730-4abe-a2d3-781fc4836f17" }, "source": [ "\n", "\n", "- (Please note that the numbers in this figure are truncated to two\n", "digits after the decimal point to reduce visual clutter. The values in the Softmax row should add up to 1.0.)" ] }, { "cell_type": "markdown", "id": "0ba12da5-6ff1-4008-91b8-d2d537cbc14c", "metadata": { "id": "0ba12da5-6ff1-4008-91b8-d2d537cbc14c" }, "source": [ "- In code, we can implement this as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "2a7f908a-e9ec-446a-b407-fb6dbf05c806", "metadata": { "id": "2a7f908a-e9ec-446a-b407-fb6dbf05c806" }, "outputs": [], "source": [ "top_k = 3\n", "top_logits, top_pos = torch.topk(next_token_logits, top_k)\n", "\n", "print(\"Top logits:\", top_logits)\n", "print(\"Top positions:\", top_pos)" ] }, { "cell_type": "code", "execution_count": null, "id": "753865ed-79c5-48b1-b9f2-ccb132ff1d2f", "metadata": { "id": "753865ed-79c5-48b1-b9f2-ccb132ff1d2f" }, "outputs": [], "source": [ "new_logits = torch.where(\n", " condition=next_token_logits < top_logits[-1],\n", " input=torch.tensor(float(\"-inf\")),\n", " other=next_token_logits\n", ")\n", "\n", "print(new_logits)" ] }, { "cell_type": "markdown", "id": 
"dfa6fa49-6e99-459d-a517-d7d0f51c4f00", "metadata": { "id": "dfa6fa49-6e99-459d-a517-d7d0f51c4f00" }, "source": [ "> NOTE:\n", ">\n", "> An alternative, slightly more efficient implementation of the previous code cell is the following:\n", ">\n", "> ```python\n", "> new_logits = torch.full_like(  # create a tensor filled with -inf values\n", ">     next_token_logits, -torch.inf\n", "> )\n", "> new_logits[top_pos] = next_token_logits[top_pos]  # copy the top-k values into the -inf tensor\n", "> ```\n", ">\n", "> For more details, see https://github.com/rasbt/LLMs-from-scratch/discussions/326\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "4844f000-c329-4e7e-aa89-16a2c4ebee43", "metadata": { "id": "4844f000-c329-4e7e-aa89-16a2c4ebee43" }, "outputs": [], "source": [ "topk_probas = torch.softmax(new_logits, dim=0)\n", "print(topk_probas)" ] },
{ "cell_type": "markdown", "id": "56056503-a15d-4315-a3ff-46647a4c7c45", "metadata": { "id": "56056503-a15d-4315-a3ff-46647a4c7c45" }, "source": [ "### Modifying the text generation function" ] },
{ "cell_type": "markdown", "id": "34770423-473d-46f6-a5fa-6b2979564d26", "metadata": { "id": "34770423-473d-46f6-a5fa-6b2979564d26" }, "source": [ "- The previous two subsections introduced temperature sampling and top-k sampling\n", "- Let's combine these two techniques to modify the `generate_simple` function we used earlier to generate text with the LLM, creating a new `generate` function:" ] },
{ "cell_type": "code", "execution_count": null, "id": "8e318891-bcc0-4d71-b147-33ce55febfa3", "metadata": { "id": "8e318891-bcc0-4d71-b147-33ce55febfa3" }, "outputs": [], "source": [ "def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n", "\n", "    # For-loop is the same as before: get logits and only focus on the last time step\n", "    for _ in range(max_new_tokens):\n", "        idx_cond = idx[:, -context_size:]\n", "        with torch.no_grad():\n", "            logits = model(idx_cond)\n", "        logits = logits[:, -1, :]  # (batch_size, vocab_size)\n", "\n", "        # New: filter logits with top-k sampling\n", "        if top_k is not None:\n", "            # Keep only the top_k largest logits in each row; set all others to -inf\n", "            top_logits, _ = torch.topk(logits, top_k)\n", "            min_val = top_logits[:, -1:]  # (batch_size, 1), so the threshold broadcasts row-wise\n", "            logits = torch.where(logits < min_val, torch.tensor(float(\"-inf\")).to(logits.device), logits)\n", "\n", "        # New: apply temperature scaling\n", "        if temperature > 0.0:\n", "            logits = logits / temperature\n", "\n", "            # Apply softmax to get probabilities\n", "            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)\n", "\n", "            # Sample from the distribution\n", "            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)\n", "\n", "        # Otherwise same as before: get idx of the vocab entry with the highest logits value\n", "        else:\n", "            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)\n", "\n", "        # Stop generating early if eos_id is specified and every sequence in the batch produced it\n", "        if eos_id is not None and torch.all(idx_next == eos_id):\n", "            break\n", "\n", "        # Same as before: append sampled index to the running sequence\n", "        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)\n", "\n", "    return idx" ] },
{ "cell_type": "code", "execution_count": null, "id": "aa2a0d7d-0457-42d1-ab9d-bd67683e7ed8", "metadata": { "id": "aa2a0d7d-0457-42d1-ab9d-bd67683e7ed8" }, "outputs": [], "source": [ "torch.manual_seed(123)\n", "\n", "token_ids = generate(\n", "    model=model,\n", "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n", "    max_new_tokens=15,\n", "    context_size=GPT_CONFIG_124M[\"context_length\"],\n", "    top_k=25,\n", "    temperature=1.4\n", ")\n", "\n", "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" ] },
{ "cell_type": "markdown", "id": "4e2002ca-f4c1-48af-9e0a-88bfc163ba0b", "metadata": { "id": "4e2002ca-f4c1-48af-9e0a-88bfc163ba0b" }, "source": [ "## Loading and saving model weights in PyTorch" ] },
{ "cell_type": "markdown", "id": "0fc52676-f026-4566-a226-2a90269f9d53", "metadata": { "id": "0fc52676-f026-4566-a226-2a90269f9d53" }, "source": [ "- Training LLMs is computationally expensive, so it's crucial to be able to save and load LLM weights\n", "\n", "" ] },
{ "cell_type": "markdown", "id": "10e4c7f9-592f-43d6-a00e-598fa01dfb82", "metadata": { "id": "10e4c7f9-592f-43d6-a00e-598fa01dfb82" }, "source": [ "- The recommended way in PyTorch is to save the model weights, the so-called `state_dict`, by applying the `torch.save` function to the dictionary returned by the `.state_dict()` method:" ] }, { "cell_type": "code", 
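"execution_count": null, "id": "b7e4c2a9-3f1d-4c8e-9a6b-2d5f7e1c0a43", "metadata": { "id": "b7e4c2a9-3f1d-4c8e-9a6b-2d5f7e1c0a43" }, "outputs": [], "source": [
"# Optional illustration, a minimal sketch that assumes a small toy network (not GPTModel):\n",
"# a state_dict is an ordered mapping from parameter names to tensors, and torch.save/torch.load\n",
"# can round-trip it. An in-memory io.BytesIO buffer stands in for the \"model.pth\" file here,\n",
"# and make_toy_model is a hypothetical helper defined only for this example.\n",
"import io\n",
"import torch\n",
"\n",
"\n",
"def make_toy_model():\n",
"    # Linear -> ReLU -> Linear; the ReLU has no parameters, so it has no state_dict entries\n",
"    return torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.ReLU(), torch.nn.Linear(2, 1))\n",
"\n",
"\n",
"toy = make_toy_model()\n",
"toy_keys = list(toy.state_dict().keys())\n",
"print(toy_keys)  # ['0.weight', '0.bias', '2.weight', '2.bias']\n",
"\n",
"buffer = io.BytesIO()\n",
"torch.save(toy.state_dict(), buffer)\n",
"buffer.seek(0)  # rewind before reading the saved bytes back\n",
"\n",
"toy_copy = make_toy_model()  # fresh, randomly initialized weights\n",
"toy_copy.load_state_dict(torch.load(buffer, weights_only=True))\n",
"\n",
"toy_x = torch.randn(4, 3)\n",
"with torch.no_grad():\n",
"    same_output = torch.equal(toy(toy_x), toy_copy(toy_x))\n",
"print(same_output)  # True"
] }, { "cell_type": "code", 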
"execution_count": null, "id": "3d67d869-ac04-4382-bcfb-c96d1ca80d47", "metadata": { "id": "3d67d869-ac04-4382-bcfb-c96d1ca80d47" }, "outputs": [], "source": [ "torch.save(model.state_dict(), \"model.pth\")" ] },
{ "cell_type": "markdown", "id": "90e889e0-07bf-43e5-8f92-5c5c7aeaad9e", "metadata": { "id": "90e889e0-07bf-43e5-8f92-5c5c7aeaad9e" }, "source": [ "- Then we can load the model weights into a new `GPTModel` instance as follows:" ] },
{ "cell_type": "code", "execution_count": null, "id": "9d57d914-60a3-47f1-b499-5352f4c457cb", "metadata": { "id": "9d57d914-60a3-47f1-b499-5352f4c457cb" }, "outputs": [], "source": [ "model = GPTModel(GPT_CONFIG_124M)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))\n", "model.eval();" ] },
{ "cell_type": "markdown", "id": "caa81aec-9c72-4f46-8ae2-4a4fde3edbc1", "metadata": { "id": "caa81aec-9c72-4f46-8ae2-4a4fde3edbc1" }, "source": [ "- It's common to train LLMs with adaptive optimizers like Adam or AdamW instead of regular SGD\n", "- These adaptive optimizers store additional state for each model weight (for Adam and AdamW, running averages of the gradients and of the squared gradients), so it makes sense to save the optimizer state as well in case we plan to continue pretraining later:" ] },
{ "cell_type": "code", "execution_count": null, "id": "bbd175bb-edf4-450e-a6de-d3e8913c6532", "metadata": { "id": "bbd175bb-edf4-450e-a6de-d3e8913c6532" }, "outputs": [], "source": [ "torch.save({\n", "    \"model_state_dict\": model.state_dict(),\n", "    \"optimizer_state_dict\": optimizer.state_dict(),\n", "    },\n", "    \"model_and_optimizer.pth\"\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "id": "8a0c7295-c822-43bf-9286-c45abc542868", "metadata": { "id": "8a0c7295-c822-43bf-9286-c45abc542868" }, "outputs": [], "source": [ "checkpoint = torch.load(\"model_and_optimizer.pth\", map_location=device, weights_only=True)\n", "\n", "model = GPTModel(GPT_CONFIG_124M)\n", 
"model.load_state_dict(checkpoint[\"model_state_dict\"])\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)\n", "optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n", "model.train();" ] }, { "cell_type": "markdown", "id": "4194350e-0409-4a63-8ffd-d3a896509032", "metadata": { "id": "4194350e-0409-4a63-8ffd-d3a896509032" }, "source": [ "## Loading pretrained weights from OpenAI" ] }, { "cell_type": "markdown", "id": "83eb6c38-7278-40e0-bd9f-8a2b1feac3ec", "metadata": { "id": "83eb6c38-7278-40e0-bd9f-8a2b1feac3ec" }, "source": [ "- Previously, we only trained a small GPT-2 model using a very small short-story book for educational purposes\n", "- Interested readers can also find a longer pretraining run on the complete Project Gutenberg book corpus in [../03_bonus_pretraining_on_gutenberg](../03_bonus_pretraining_on_gutenberg)\n", "- Fortunately, we don't have to spend tens to hundreds of thousands of dollars to pretrain the model on a large pretraining corpus but can load the pretrained weights provided by OpenAI" ] }, { "cell_type": "markdown", "id": "127ddbdb-3878-4669-9a39-d231fbdfb834", "metadata": { "id": "127ddbdb-3878-4669-9a39-d231fbdfb834" }, "source": [ "- For an alternative way to load the weights from the Hugging Face Hub, see [../02_alternative_weight_loading](../02_alternative_weight_loading)" ] }, { "cell_type": "markdown", "id": "75cab892-a165-4f43-9601-f517bc212ab6", "metadata": { "id": "75cab892-a165-4f43-9601-f517bc212ab6" }, "source": [ "- First, some boilerplate code to download the files from OpenAI and load the weights into Python\n", "- Since OpenAI used [TensorFlow](https://www.tensorflow.org/), we will have to install and use TensorFlow for loading the weights; [tqdm](https://github.com/tqdm/tqdm) is a progress bar library\n", "- Uncomment and run the next cell to install the required libraries" ] }, { "cell_type": "code", "execution_count": null, "id": 
"fb9fdf02-972a-444e-bf65-8ffcaaf30ce8", "metadata": { "id": "fb9fdf02-972a-444e-bf65-8ffcaaf30ce8" }, "outputs": [], "source": [ "# pip install tensorflow tqdm" ] },
{ "cell_type": "code", "execution_count": null, "id": "a0747edc-559c-44ef-a93f-079d60227e3f", "metadata": { "id": "a0747edc-559c-44ef-a93f-079d60227e3f" }, "outputs": [], "source": [ "print(\"TensorFlow version:\", version(\"tensorflow\"))\n", "print(\"tqdm version:\", version(\"tqdm\"))" ] },
{ "cell_type": "code", "execution_count": null, "id": "c5bc89eb-4d39-4287-9b0c-e459ebe7f5ed", "metadata": { "id": "c5bc89eb-4d39-4287-9b0c-e459ebe7f5ed" }, "outputs": [], "source": [ "import json\n", "import os\n", "import urllib.request\n", "\n", "import numpy as np\n", "import tensorflow as tf\n", "from tqdm import tqdm\n", "\n", "\n", "def download_and_load_gpt2(model_size, models_dir):\n", "    # Validate model size\n", "    allowed_sizes = (\"124M\", \"355M\", \"774M\", \"1558M\")\n", "    if model_size not in allowed_sizes:\n", "        raise ValueError(f\"Model size not in {allowed_sizes}\")\n", "\n", "    # Define paths\n", "    model_dir = os.path.join(models_dir, model_size)\n", "    base_url = \"https://openaipublic.blob.core.windows.net/gpt-2/models\"\n", "    filenames = [\n", "        \"checkpoint\", \"encoder.json\", \"hparams.json\",\n", "        \"model.ckpt.data-00000-of-00001\", \"model.ckpt.index\",\n", "        \"model.ckpt.meta\", \"vocab.bpe\"\n", "    ]\n", "\n", "    # Download files\n", "    os.makedirs(model_dir, exist_ok=True)\n", "    for filename in filenames:\n", "        # Build the URL with forward slashes; os.path.join would insert backslashes on Windows\n", "        file_url = f\"{base_url}/{model_size}/{filename}\"\n", "        file_path = os.path.join(model_dir, filename)\n", "        download_file(file_url, file_path)\n", "\n", "    # Load settings and params\n", "    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)\n", "    with open(os.path.join(model_dir, \"hparams.json\")) as settings_file:\n", "        settings = json.load(settings_file)\n", "    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)\n", "\n", "    return settings, params\n", "\n", "\n", "def download_file(url, destination):\n", "    # Send a GET request to download the file\n", 
"\n", " try:\n", " with urllib.request.urlopen(url) as response:\n", " # Get the total file size from headers, defaulting to 0 if not present\n", " file_size = int(response.headers.get(\"Content-Length\", 0))\n", "\n", " # Check if file exists and has the same size\n", " if os.path.exists(destination):\n", " file_size_local = os.path.getsize(destination)\n", " if file_size == file_size_local:\n", " print(f\"File already exists and is up-to-date: {destination}\")\n", " return\n", "\n", " # Define the block size for reading the file\n", " block_size = 1024 # 1 Kilobyte\n", "\n", " # Initialize the progress bar with total file size\n", " progress_bar_description = os.path.basename(url) # Extract filename from URL\n", " with tqdm(total=file_size, unit=\"iB\", unit_scale=True, desc=progress_bar_description) as progress_bar:\n", " # Open the destination file in binary write mode\n", " with open(destination, \"wb\") as file:\n", " # Read the file in chunks and write to destination\n", " while True:\n", " chunk = response.read(block_size)\n", " if not chunk:\n", " break\n", " file.write(chunk)\n", " progress_bar.update(len(chunk)) # Update progress bar\n", " except urllib.error.HTTPError:\n", " s = (\n", " f\"The specified URL ({url}) is incorrect, the internet connection cannot be established,\"\n", " \"\\nor the requested file is temporarily unavailable.\\nPlease visit the following website\"\n", " \" for help: https://github.com/rasbt/LLMs-from-scratch/discussions/273\")\n", " print(s)\n", "\n", "\n", "def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):\n", " # Initialize parameters dictionary with empty blocks for each layer\n", " params = {\"blocks\": [{} for _ in range(settings[\"n_layer\"])]}\n", "\n", " # Iterate over each variable in the checkpoint\n", " for name, _ in tf.train.list_variables(ckpt_path):\n", " # Load the variable and remove singleton dimensions\n", " variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))\n", "\n", " # Process the 
variable name to extract relevant parts\n", " variable_name_parts = name.split(\"/\")[1:] # Skip the 'model/' prefix\n", "\n", " # Identify the target dictionary for the variable\n", " target_dict = params\n", " if variable_name_parts[0].startswith(\"h\"):\n", " layer_number = int(variable_name_parts[0][1:])\n", " target_dict = params[\"blocks\"][layer_number]\n", "\n", " # Recursively access or create nested dictionaries\n", " for key in variable_name_parts[1:-1]:\n", " target_dict = target_dict.setdefault(key, {})\n", "\n", " # Assign the variable array to the last key\n", " last_key = variable_name_parts[-1]\n", " target_dict[last_key] = variable_array\n", "\n", " return params" ] }, { "cell_type": "markdown", "id": "ff76a736-6f9f-4328-872e-f89a7b70a2cc", "metadata": { "id": "ff76a736-6f9f-4328-872e-f89a7b70a2cc" }, "source": [ "- We can then download the model weights for the 124 million parameter model as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "76271dd7-108d-4f5b-9c01-6ae0aac4b395", "metadata": { "id": "76271dd7-108d-4f5b-9c01-6ae0aac4b395" }, "outputs": [], "source": [ "settings, params = download_and_load_gpt2(model_size=\"124M\", models_dir=\"gpt2\")" ] }, { "cell_type": "code", "execution_count": null, "id": "b1a31951-d971-4a6e-9c43-11ee1168ec6a", "metadata": { "id": "b1a31951-d971-4a6e-9c43-11ee1168ec6a" }, "outputs": [], "source": [ "print(\"Settings:\", settings)" ] }, { "cell_type": "code", "execution_count": null, "id": "857c8331-130e-46ba-921d-fa35d7a73cfe", "metadata": { "id": "857c8331-130e-46ba-921d-fa35d7a73cfe" }, "outputs": [], "source": [ "print(\"Parameter dictionary keys:\", params.keys())" ] }, { "cell_type": "code", "execution_count": null, "id": "c48dac94-8562-4a66-84ef-46c613cdc4cd", "metadata": { "id": "c48dac94-8562-4a66-84ef-46c613cdc4cd" }, "outputs": [], "source": [ "print(params[\"wte\"])\n", "print(\"Token embedding weight tensor dimensions:\", params[\"wte\"].shape)" ] }, { "cell_type": "markdown", 
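"id": "c3a9e1f2-6b4d-4e7a-8f2c-5d1b9a0e7c64", "metadata": { "id": "c3a9e1f2-6b4d-4e7a-8f2c-5d1b9a0e7c64" }, "source": [
"- To make the mapping from TensorFlow variable names to the nested `params` dictionary more concrete, the next cell re-implements the splitting loop of `load_gpt2_params_from_tf_ckpt` on a few made-up `(name, array)` pairs, so no checkpoint is needed\n",
"- This is a minimal sketch: the helper `nest_tf_variables`, the dummy shapes, and the exact variable names are assumptions for illustration; the names follow the `model/h0/attn/c_attn/w` pattern implied by the keys that `load_weights_into_gpt` reads later\n",
"- Note how `model/ln_f/g` becomes the top-level key `g`: the loop only walks the middle name parts (`[1:-1]`), so the scope name of a variable outside the `h*` blocks is dropped, which is why the final layer norm weights are later read from `params[\"g\"]` and `params[\"b\"]`"
] }, { "cell_type": "code", "execution_count": null, "id": "d4b8f0e3-7c5a-4f8b-9e3d-6e2c0b1f8d75", "metadata": { "id": "d4b8f0e3-7c5a-4f8b-9e3d-6e2c0b1f8d75" }, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"\n",
"def nest_tf_variables(named_arrays, n_layer):\n",
"    # Hypothetical helper: the same splitting logic as load_gpt2_params_from_tf_ckpt,\n",
"    # but it takes (name, array) pairs instead of reading a TensorFlow checkpoint\n",
"    nested = {\"blocks\": [{} for _ in range(n_layer)]}\n",
"    for name, array in named_arrays:\n",
"        parts = name.split(\"/\")[1:]  # drop the leading \"model/\" scope\n",
"        target = nested\n",
"        if parts[0].startswith(\"h\"):  # e.g. \"h0\" -> transformer block 0\n",
"            target = nested[\"blocks\"][int(parts[0][1:])]\n",
"        for key in parts[1:-1]:  # walk (or create) the intermediate scopes\n",
"            target = target.setdefault(key, {})\n",
"        target[parts[-1]] = np.squeeze(array)  # remove singleton dimensions\n",
"    return nested\n",
"\n",
"\n",
"# Made-up variable names and shapes, for illustration only\n",
"dummy_vars = [\n",
"    (\"model/wte\", np.zeros((5, 3))),\n",
"    (\"model/h0/attn/c_attn/w\", np.zeros((1, 3, 9))),\n",
"    (\"model/ln_f/g\", np.ones(3)),\n",
"]\n",
"dummy_params = nest_tf_variables(dummy_vars, n_layer=1)\n",
"print(dummy_params[\"blocks\"][0][\"attn\"][\"c_attn\"][\"w\"].shape)  # (3, 9)\n",
"print(sorted(k for k in dummy_params if k != \"blocks\"))  # ['g', 'wte']"
] }, { "cell_type": "markdown", 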
"id": "466e100c-294e-4afc-a70a-2f398ac4c104", "metadata": { "id": "466e100c-294e-4afc-a70a-2f398ac4c104" }, "source": [ "- Alternatively, \"355M\", \"774M\", and \"1558M\" are also supported `model_size` arguments\n", "- The difference between these differently sized models is summarized in the figure below:" ] },
{ "cell_type": "markdown", "id": "20f19d32-5aae-4176-9f86-f391672c8f0d", "metadata": { "id": "20f19d32-5aae-4176-9f86-f391672c8f0d" }, "source": [ "" ] },
{ "cell_type": "markdown", "id": "ea6e5076-f08d-41fc-bd8b-1cfe53538f41", "metadata": { "id": "ea6e5076-f08d-41fc-bd8b-1cfe53538f41" }, "source": [ "- Above, we loaded the 124M GPT-2 model weights into Python; however, we still need to transfer them into our `GPTModel` instance\n", "- First, we initialize a new `GPTModel` instance\n", "- Note that the original GPT model initialized the linear layers for the query, key, and value matrices in the multi-head attention module with bias vectors, which is not required or recommended; however, to load the weights correctly, we have to enable these bias vectors in our implementation as well by setting `qkv_bias` to `True`\n", "- We also use the `1024`-token context length of the original GPT-2 models" ] },
{ "cell_type": "code", "execution_count": null, "id": "9fef90dd-0654-4667-844f-08e28339ef7d", "metadata": { "id": "9fef90dd-0654-4667-844f-08e28339ef7d" }, "outputs": [], "source": [ "# Define model configurations in a dictionary for compactness\n", "model_configs = {\n", "    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n", "    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n", "    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n", "    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n", "}\n", "\n", "# Copy the base configuration and update with specific model settings\n", "model_name = \"gpt2-small (124M)\"  # Example model 
name\n", "NEW_CONFIG = GPT_CONFIG_124M.copy()\n", "NEW_CONFIG.update(model_configs[model_name])\n", "NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n", "\n", "gpt = GPTModel(NEW_CONFIG)\n", "gpt.eval();" ] }, { "cell_type": "markdown", "id": "272f29ac-8342-4b3d-a57d-9b0166ced314", "metadata": { "id": "272f29ac-8342-4b3d-a57d-9b0166ced314" }, "source": [ "- The next task is to assign the OpenAI weights to the corresponding weight tensors in our `GPTModel` instance" ] }, { "cell_type": "code", "execution_count": null, "id": "f9a92229-c002-49a6-8cfb-248297ad8296", "metadata": { "id": "f9a92229-c002-49a6-8cfb-248297ad8296" }, "outputs": [], "source": [ "def assign(left, right):\n", " if left.shape != right.shape:\n", " raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n", " return torch.nn.Parameter(torch.tensor(right))" ] }, { "cell_type": "code", "execution_count": null, "id": "f22d5d95-ca5a-425c-a9ec-fc432a12d4e9", "metadata": { "id": "f22d5d95-ca5a-425c-a9ec-fc432a12d4e9" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def load_weights_into_gpt(gpt, params):\n", " gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])\n", " gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])\n", "\n", " for b in range(len(params[\"blocks\"])):\n", " q_w, k_w, v_w = np.split(\n", " (params[\"blocks\"][b][\"attn\"][\"c_attn\"])[\"w\"], 3, axis=-1)\n", " gpt.trf_blocks[b].att.W_query.weight = assign(\n", " gpt.trf_blocks[b].att.W_query.weight, q_w.T)\n", " gpt.trf_blocks[b].att.W_key.weight = assign(\n", " gpt.trf_blocks[b].att.W_key.weight, k_w.T)\n", " gpt.trf_blocks[b].att.W_value.weight = assign(\n", " gpt.trf_blocks[b].att.W_value.weight, v_w.T)\n", "\n", " q_b, k_b, v_b = np.split(\n", " (params[\"blocks\"][b][\"attn\"][\"c_attn\"])[\"b\"], 3, axis=-1)\n", " gpt.trf_blocks[b].att.W_query.bias = assign(\n", " gpt.trf_blocks[b].att.W_query.bias, q_b)\n", " gpt.trf_blocks[b].att.W_key.bias = 
assign(\n", " gpt.trf_blocks[b].att.W_key.bias, k_b)\n", " gpt.trf_blocks[b].att.W_value.bias = assign(\n", " gpt.trf_blocks[b].att.W_value.bias, v_b)\n", "\n", " gpt.trf_blocks[b].att.out_proj.weight = assign(\n", " gpt.trf_blocks[b].att.out_proj.weight,\n", " params[\"blocks\"][b][\"attn\"][\"c_proj\"][\"w\"].T)\n", " gpt.trf_blocks[b].att.out_proj.bias = assign(\n", " gpt.trf_blocks[b].att.out_proj.bias,\n", " params[\"blocks\"][b][\"attn\"][\"c_proj\"][\"b\"])\n", "\n", " gpt.trf_blocks[b].ff.layers[0].weight = assign(\n", " gpt.trf_blocks[b].ff.layers[0].weight,\n", " params[\"blocks\"][b][\"mlp\"][\"c_fc\"][\"w\"].T)\n", " gpt.trf_blocks[b].ff.layers[0].bias = assign(\n", " gpt.trf_blocks[b].ff.layers[0].bias,\n", " params[\"blocks\"][b][\"mlp\"][\"c_fc\"][\"b\"])\n", " gpt.trf_blocks[b].ff.layers[2].weight = assign(\n", " gpt.trf_blocks[b].ff.layers[2].weight,\n", " params[\"blocks\"][b][\"mlp\"][\"c_proj\"][\"w\"].T)\n", " gpt.trf_blocks[b].ff.layers[2].bias = assign(\n", " gpt.trf_blocks[b].ff.layers[2].bias,\n", " params[\"blocks\"][b][\"mlp\"][\"c_proj\"][\"b\"])\n", "\n", " gpt.trf_blocks[b].norm1.scale = assign(\n", " gpt.trf_blocks[b].norm1.scale,\n", " params[\"blocks\"][b][\"ln_1\"][\"g\"])\n", " gpt.trf_blocks[b].norm1.shift = assign(\n", " gpt.trf_blocks[b].norm1.shift,\n", " params[\"blocks\"][b][\"ln_1\"][\"b\"])\n", " gpt.trf_blocks[b].norm2.scale = assign(\n", " gpt.trf_blocks[b].norm2.scale,\n", " params[\"blocks\"][b][\"ln_2\"][\"g\"])\n", " gpt.trf_blocks[b].norm2.shift = assign(\n", " gpt.trf_blocks[b].norm2.shift,\n", " params[\"blocks\"][b][\"ln_2\"][\"b\"])\n", "\n", " gpt.final_norm.scale = assign(gpt.final_norm.scale, params[\"g\"])\n", " gpt.final_norm.shift = assign(gpt.final_norm.shift, params[\"b\"])\n", " gpt.out_head.weight = assign(gpt.out_head.weight, params[\"wte\"])\n", "\n", "\n", "load_weights_into_gpt(gpt, params)\n", "gpt.to(device);" ] }, { "cell_type": "markdown", "id": "4f7472cb-54dc-4311-96d8-b2694f885cee", 
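"metadata": { "id": "4f7472cb-54dc-4311-96d8-b2694f885cee" }, "source": [
"- A note on the `.T` calls in `load_weights_into_gpt`: OpenAI's TensorFlow code stores each projection matrix with shape `(in_features, out_features)` and computes `x @ W`, whereas `torch.nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ W.T + b`\n",
"- The next cell is a minimal sketch that uses random matrices (not the real checkpoint) and an arbitrary embedding size of 4; it checks that splitting a fused `c_attn`-style matrix with `np.split` and assigning the transposed query slice to an `nn.Linear` layer reproduces `x @ q_w`"
] }, { "cell_type": "code", "execution_count": null, "id": "e5c9a1f4-8d6b-4a9c-8f4e-7f3d1c2a9e86", "metadata": { "id": "e5c9a1f4-8d6b-4a9c-8f4e-7f3d1c2a9e86" }, "outputs": [], "source": [
"import numpy as np\n",
"import torch\n",
"\n",
"emb_dim = 4  # arbitrary small size for this illustration\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Fused query/key/value weights in the TensorFlow layout: (in_features, 3 * out_features)\n",
"c_attn_w = rng.standard_normal((emb_dim, 3 * emb_dim)).astype(np.float32)\n",
"q_w, k_w, v_w = np.split(c_attn_w, 3, axis=-1)  # each has shape (emb_dim, emb_dim)\n",
"\n",
"# nn.Linear expects (out_features, in_features), hence the transpose\n",
"query_proj = torch.nn.Linear(emb_dim, emb_dim, bias=False)\n",
"query_proj.weight = torch.nn.Parameter(torch.tensor(q_w.T))\n",
"\n",
"x_demo = rng.standard_normal((2, emb_dim)).astype(np.float32)\n",
"expected = x_demo @ q_w  # what the TensorFlow layer computes\n",
"with torch.no_grad():\n",
"    actual = query_proj(torch.tensor(x_demo)).numpy()\n",
"print(np.allclose(actual, expected, atol=1e-5))  # True"
] }, { "cell_type": "markdown", "id": "f6d0b2a5-9e7c-4b0d-9a5f-8a4e2d3b0f97", 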
"metadata": { "id": "4f7472cb-54dc-4311-96d8-b2694f885cee" }, "source": [ "- If the model is loaded correctly, we can use it to generate new text with our previous `generate` function:" ] },
{ "cell_type": "code", "execution_count": null, "id": "1f690253-f845-4347-b7b6-43fabbd2affa", "metadata": { "id": "1f690253-f845-4347-b7b6-43fabbd2affa" }, "outputs": [], "source": [ "torch.manual_seed(123)\n", "\n", "token_ids = generate(\n", "    model=gpt,\n", "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n", "    max_new_tokens=25,\n", "    context_size=NEW_CONFIG[\"context_length\"],\n", "    top_k=50,\n", "    temperature=1.5\n", ")\n", "\n", "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" ] },
{ "cell_type": "markdown", "id": "6d079f98-a7c4-462e-8416-5a64f670861c", "metadata": { "id": "6d079f98-a7c4-462e-8416-5a64f670861c" }, "source": [ "- The coherent output is a good indication that the weights were loaded correctly: even a small mistake in the weight mapping, such as a missing transpose or a mixed-up query/key/value split, would typically make the model generate incoherent text" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }