{ "cells": [ { "cell_type": "markdown", "id": "vocational-baseball", "metadata": {}, "source": [ "# 2_Understanding Megatron core - mpu\n", "---\n", "NVIDIA's Megatron-LM makes it practical to train very large language models (up to a trillion parameters) for your own language. Megatron-LM's core mpu (model parallelism unit) is the foundation for many subsequent efforts to train very large models, such as [DeepSpeed](https://www.deepspeed.ai/features/#model-parallelism).\n", "\n", "However, a common pitfall is that a Megatron GPT training run with an insufficient configuration can easily suffer from low GPU utilization. The figure below shows an example training run whose insufficient configuration results in low or inconsistent GPU utilization.\n", "\n", "Therefore, our very first step is to understand how the core of Megatron-LM's mpu works. After that, we take a closer look at the performance of Megatron-LM training runs.\n", "\n", "![a training run with low or inconsistent GPU utilization](./Megatron-LM/pics/naive_run.JPG)\n", "\n", "## Learning Objectives\n", "- **The goal of this lab is to understand how the core of Megatron's mpu (model parallelism unit) works**\n", "    - How Megatron groups GPUs according to the model-parallel configuration (pipeline parallel | tensor parallel)\n", "    - Tensor Parallel : Column Parallel\n", "    - Tensor Parallel : Row Parallel\n" ] },
{ "cell_type": "markdown", "id": "democratic-disclosure", "metadata": {}, "source": [ "---------------------------------------------------------------------------\n", "## How Megatron groups GPUs per model/data-parallel configuration\n", "\n", "Parallelism : Model & Data\n", "- p = Pipeline Model Parallel\n", "- t = Tensor Model Parallel\n", "- d = Data Parallel\n", "- n = Total number of GPUs used in the training\n", "\n", "#### Megatron requires p * t * d = n\n", "\n", "Parallel groups - process groups for torch.distributed (NCCL)\n", "- num_tensor_model_parallel_groups = n / t\n", "- num_pipeline_model_parallel_groups = n / p\n", "- num_data_parallel_groups = n / d\n", "\n", "Assume the following configuration:\n", "\n", "tensor_model_parallel_size_=2\n", "\n", "pipeline_model_parallel_size_= 4\n", "\n", "Let's say we have a total of 16 GPUs, denoted g0 ... g15. Therefore, world_size=16 and the data-parallel size is d = n / (p * t) = 16 / (4 * 2) = 2.\n", "\n", "Then, according to Megatron's [initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/initialize.py), we should see the following:\n", "\n", "    8 data_parallel groups:\n", "        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]\n", "    8 tensor model-parallel groups:\n", "        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]\n", "    4 pipeline model-parallel groups:\n", "        [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]\n", "\n", "Note that for efficiency, the caller should make sure adjacent ranks\n", "are on the same DGX box.\n", "For example, if we are using 2 DGX-1 boxes\n", "with a total of 16 GPUs, ranks 0 to 7 belong to the first box and\n", "ranks 8 to 15 belong to the second box."
] }, { "cell_type": "code", "execution_count": 1, "id": "passive-dependence", "metadata": {}, "outputs": [], "source": [ "import itertools\n", "def ensure_divisibility(numerator, denominator):\n", "    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n", "    assert numerator % denominator == 0, '{} is not divisible by {}'.format(\n", "        numerator, denominator)\n", "def initialize_model_parallel(tensor_model_parallel_size_=2,\n", "                              pipeline_model_parallel_size_= 4,\n", "                              world_size=16):\n", "    print(' ---------- world size is set to : {} ---------- '.format(world_size))\n", "    print('> initializing tensor model parallel with size {}'.format(tensor_model_parallel_size_))\n", "    print('> initializing pipeline model parallel with size {}'.format(pipeline_model_parallel_size_))\n", "\n", "    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)\n", "    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)\n", "\n", "    # make sure world_size is divisible by t * p\n", "    ensure_divisibility(world_size,tensor_model_parallel_size * pipeline_model_parallel_size)\n", "\n", "    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n", "    print(\"> data parallel size is set to : \", data_parallel_size)\n", "    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size\n", "    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size\n", "    num_data_parallel_groups = world_size // data_parallel_size\n", "    print(\"---------- parallel groups ----------\")\n", "    print(\"num_tensor_model_parallel_groups : \", num_tensor_model_parallel_groups)\n", "    print(\"num_pipeline_model_parallel_groups : \", num_pipeline_model_parallel_groups)\n", "    print(\"num_data_parallel_groups : \",num_data_parallel_groups )\n", "    _DATA_PARALLEL_GROUP = []\n", "    _MODEL_PARALLEL_GROUP = []\n", "    _TENSOR_MODEL_PARALLEL_GROUP = []\n", "    _PIPE_MODEL_PARALLEL_GROUP = []\n", "\n", "    # Build the data-parallel groups: within each pipeline stage (a block of\n", "    # world_size // p consecutive ranks), ranks t apart hold the same tensor shard.\n", "    all_data_parallel_group_ranks = []\n", "    for i in range(pipeline_model_parallel_size):\n", "        start_rank = i * num_pipeline_model_parallel_groups\n", "        end_rank = (i + 1) * num_pipeline_model_parallel_groups\n", "        for j in range(tensor_model_parallel_size):\n", "            ranks = range(start_rank + j, end_rank,\n", "                          tensor_model_parallel_size)\n", "            all_data_parallel_group_ranks.append(list(ranks))\n", "    _DATA_PARALLEL_GROUP = all_data_parallel_group_ranks\n", "\n", "    # Build the pipeline groups: ranks world_size // p apart form one pipeline.\n", "    for i in range(num_pipeline_model_parallel_groups):\n", "        ranks = range(i, world_size,\n", "                      num_pipeline_model_parallel_groups)\n", "        _PIPE_MODEL_PARALLEL_GROUP.append(list(ranks))\n", "\n", "    # Build the model-parallel groups: all ranks that together hold one full model copy.\n", "    for i in range(data_parallel_size):\n", "        ranks = [data_parallel_group_ranks[i]\n", "                 for data_parallel_group_ranks in all_data_parallel_group_ranks]\n", "        _MODEL_PARALLEL_GROUP.append(ranks)\n", "\n", "    # Build the tensor-model-parallel groups: t adjacent ranks split each layer's weights.\n", "    for i in range(num_tensor_model_parallel_groups):\n", "        ranks = range(i * tensor_model_parallel_size,\n", "                      (i + 1) * tensor_model_parallel_size)\n", "        _TENSOR_MODEL_PARALLEL_GROUP.append(list(ranks))\n", "    print(\"-----\"*20)\n", "    print(\"_DATA_PARALLEL_GROUP \\n :\", _DATA_PARALLEL_GROUP)\n", "    print(\"-----\"*20)\n", "    print(\"_TENSOR_MODEL_PARALLEL_GROUP \\n :\", _TENSOR_MODEL_PARALLEL_GROUP)\n", "    print(\"-----\"*20)\n", "    print(\"_PIPE_MODEL_PARALLEL_GROUP \\n :\", _PIPE_MODEL_PARALLEL_GROUP)\n", "    print(\"-----\"*20)\n", "    print(\"Total :{} full models being partitioned into :{} GPUs \".format(len(_MODEL_PARALLEL_GROUP),world_size))\n", "    for idx, m in enumerate(_MODEL_PARALLEL_GROUP):\n", "        m=[str(l) for l in m]\n", "        print(\"model {} : is partitioned into gpus :{}\".format(str(idx),','.join(m)))\n" ] },
{ "cell_type": "markdown", "id": "natural-genetics", "metadata": {}, "source": 
[ "---\n", "## Sanity check: verify that the output below matches [megatron/mpu/initialize.py](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/initialize.py#L63)\n", "    Let's say we have a total of 16 GPUs denoted by g0 ... g15\n", "    8 data_parallel groups:\n", "        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]\n", "    8 tensor model-parallel groups:\n", "        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]\n", "    4 pipeline model-parallel groups:\n", "        [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]\n", "    Note that for efficiency, the caller should make sure adjacent ranks\n", "    are on the same DGX box. For example, if we are using 2 DGX-1 boxes\n", "    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and\n", "    ranks 8 to 15 belong to the second box.\n" ] },
{ "cell_type": "code", "execution_count": 2, "id": "residential-action", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ---------- world size is set to : 16 ---------- \n", "> initializing tensor model parallel with size 2\n", "> initializing pipeline model parallel with size 4\n", "> data parallel size is set to : 2\n", "---------- parallel groups ----------\n", "num_tensor_model_parallel_groups : 8\n", "num_pipeline_model_parallel_groups : 4\n", "num_data_parallel_groups : 8\n", "----------------------------------------------------------------------------------------------------\n", "_DATA_PARALLEL_GROUP \n", " : [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]\n", "----------------------------------------------------------------------------------------------------\n", "_TENSOR_MODEL_PARALLEL_GROUP \n", " : [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]\n", "----------------------------------------------------------------------------------------------------\n", "_PIPE_MODEL_PARALLEL_GROUP \n", " : [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]\n", "----------------------------------------------------------------------------------------------------\n", "Total :2 full models being partitioned into :16 GPUs \n", "model 0 : is partitioned into gpus :0,1,4,5,8,9,12,13\n", "model 1 : is partitioned into gpus :2,3,6,7,10,11,14,15\n" ] } ], "source": [ "initialize_model_parallel(tensor_model_parallel_size_=2,\n", "                          pipeline_model_parallel_size_= 4,\n", "                          world_size=16)" ] },
{ "cell_type": "code", "execution_count": 3, "id": "likely-visiting", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ---------- world size is set to : 16 ---------- \n", "> initializing tensor model parallel with size 4\n", "> initializing pipeline model parallel with size 2\n", "> data parallel size is set to : 2\n", "---------- parallel groups ----------\n", "num_tensor_model_parallel_groups : 4\n", "num_pipeline_model_parallel_groups : 8\n", "num_data_parallel_groups : 8\n", "----------------------------------------------------------------------------------------------------\n", "_DATA_PARALLEL_GROUP \n", " : [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13], [10, 14], [11, 15]]\n", "----------------------------------------------------------------------------------------------------\n", "_TENSOR_MODEL_PARALLEL_GROUP \n", " : [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]\n", "----------------------------------------------------------------------------------------------------\n", "_PIPE_MODEL_PARALLEL_GROUP \n", " : [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]\n", "----------------------------------------------------------------------------------------------------\n", "Total :2 full models being partitioned into :16 GPUs \n", "model 0 : is partitioned into gpus :0,1,2,3,8,9,10,11\n", "model 1 : is partitioned into gpus :4,5,6,7,12,13,14,15\n" ] } ], "source": [ "tensor_model_parallel_size= 4 # try a different tensor_model_parallel_size_\n", "pipeline_model_parallel_size= 2 # try a different pipeline_model_parallel_size_\n", "world_size=16\n", "assert world_size%(tensor_model_parallel_size * pipeline_model_parallel_size)==0,'please make sure world_size is divisible by tensor_model_parallel_size * pipeline_model_parallel_size'\n", "\n", "initialize_model_parallel(tensor_model_parallel_size_=tensor_model_parallel_size,pipeline_model_parallel_size_= pipeline_model_parallel_size,world_size=world_size)" ] },
{ "cell_type": "markdown", "id": "bright-companion", "metadata": {}, "source": [ "----------------------------------------------------------------------\n", "## Megatron Column Parallel\n", "[ColumnParallel reference](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/layers.py#L201)\n", "![ColumnParallel](./Megatron-LM/pics/ColumnParallel.JPG)" ] },
{ "cell_type": "code", "execution_count": 5, "id": "closing-louisville", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"./Megatron-LM\")\n", "from megatron.mpu import layers\n", "from megatron.mpu.layers import _initialize_affine_weight_gpu\n", "from megatron.mpu.mappings import (copy_to_tensor_model_parallel_region,\n", "                                   gather_from_tensor_model_parallel_region,\n", "                                   reduce_from_tensor_model_parallel_region,\n", "                                   scatter_to_tensor_model_parallel_region)\n", "from torch.nn.parameter import Parameter\n", "import torch.nn.functional as F\n", "import torch.nn.init as init\n", "import torch\n", "import random\n", "from megatron import *\n", "from megatron.mpu.tests import *\n", "tensor_model_parallel_size=4\n", "from megatron.mpu.utils import *  # provides divide()" ] },
{ "cell_type": "code", "execution_count": 6, "id": "bizarre-remedy", "metadata": {}, "outputs": [], "source": [ "## The class below is slightly modified from the original Megatron repo to skip distributed initialization:\n", "## instead of querying the tensor-model-parallel world size, it splits the weight by the global world_size (16).\n", "global world_size\n", "world_size = 16\n", "class myColumnParallelLinear(torch.nn.Module):\n", "    \"\"\"Linear layer with column parallelism.\n", "\n", "    The linear layer is defined as Y = XA + b. A is parallelized along\n", "    its second dimension as A = [A_1, ..., A_p].\n", "\n", "    Arguments:\n", "        input_size: first dimension of matrix A.\n", "        output_size: second dimension of matrix A.\n", "        bias: If true, add bias\n", "        gather_output: If true, call all-gather on output and make Y available\n", "                       to all GPUs, otherwise, every GPU will have its output\n", "                       which is Y_i = XA_i\n", "        init_method: method to initialize weights. Note that bias is always set\n", "                     to zero.\n", "        stride: For the strided linear layers.\n", "        keep_master_weight_for_test: This was added for testing and should be\n", "                                     set to False. It returns the master weights\n", "                                     used for initialization.\n", "        skip_bias_add: This was added to enable performance optimizations where bias\n", "                       can be fused with other elementwise operations. We skip\n", "                       adding bias but instead return it.\n", "    \"\"\"\n", "\n", "    def __init__(self, input_size, output_size, bias=True, gather_output=True,\n", "                 init_method=init.xavier_normal_, stride=1,\n", "                 keep_master_weight_for_test=False,\n", "                 skip_bias_add=False):\n", "        super(myColumnParallelLinear, self).__init__()\n", "\n", "        # Keep input parameters\n", "        self.input_size = input_size\n", "        self.output_size = output_size\n", "        self.gather_output = gather_output\n", "        # Divide the weight matrix along the last dimension\n", "        # (world_size stands in for the tensor-model-parallel size here).\n", "        self.output_size_per_partition = divide(output_size, world_size)\n", "        self.skip_bias_add = skip_bias_add\n", "\n", "        # Parameters.\n", "        # Note: torch.nn.functional.linear performs XA^T + b and as a result\n", "        # we allocate the transpose.\n", "        # Initialize weight.\n", "        use_cpu_initialization=True # hard coded to use cpu\n", "        params_dtype = torch.float # skipping need of args\n", "\n", "        if use_cpu_initialization:\n", "            self.weight = Parameter(torch.empty(self.output_size_per_partition,\n", "                                                self.input_size,\n", "                                                dtype=params_dtype))\n", "            self.master_weight = m_initialize_affine_weight_cpu(\n", "                self.weight, self.output_size, self.input_size,\n", "                self.output_size_per_partition, 0, init_method,\n", "                stride=stride, return_master_weight=keep_master_weight_for_test)\n", "        else:\n", "            self.weight = Parameter(torch.empty(\n", "                self.output_size_per_partition, self.input_size,\n", "                device=torch.cuda.current_device(), dtype=params_dtype))\n", "            _initialize_affine_weight_gpu(self.weight, init_method,\n", "                                          partition_dim=0, stride=stride)\n", "\n", "        if bias:\n", "            if use_cpu_initialization:\n", "                self.bias = Parameter(torch.empty(\n", "                    self.output_size_per_partition, dtype=params_dtype))\n", "            else:\n", "                self.bias = Parameter(torch.empty(\n", "                    self.output_size_per_partition,\n", "                    device=torch.cuda.current_device(),\n", "                    dtype=params_dtype))\n", "            # Always initialize bias to zero.\n", "            with torch.no_grad():\n", "                self.bias.zero_()\n", "        else:\n", "            self.register_parameter('bias', None)\n", "\n", "    def forward(self, input_):\n", "        # Set up backprop all-reduce.\n", "        print(\"in Column parallel forward\")\n", "        input_parallel = copy_to_tensor_model_parallel_region(input_)\n", "        # Matrix multiply.\n", "        bias = self.bias if not self.skip_bias_add else None\n", "        output_parallel = F.linear(input_parallel, self.weight, bias)\n", "        if self.gather_output:\n", "            # All-gather across the partitions.\n", "            output = gather_from_tensor_model_parallel_region(output_parallel)\n", "        else:\n", "            output = output_parallel\n", "        output_bias = self.bias if self.skip_bias_add else None\n", "        return output, output_bias" ] },
{ "cell_type": "code", "execution_count": 7, "id": "recent-cradle", "metadata": {}, "outputs": 
[], "source": [ "def get_weight_list(master_weight,tensor_model_parallel_gp):\n", "    \"\"\"Slice master_weight into world_size partitions A_0, A_1, ... and print their shapes.\n", "\n", "    Uses the globals world_size and which_model_parallel ('col' or 'row'). The slices are\n", "    strided (rows gp, gp + world_size, ...), which gives the same per-partition shapes as\n", "    Megatron's contiguous torch.split, but not the same elements.\n", "    \"\"\"\n", "    my_weight_list=[]\n", "    print(\"A = [\")\n", "    tensor_model_parallel_gp=list(itertools.chain(*tensor_model_parallel_gp))\n", "    cnt=0\n", "    for gp in tensor_model_parallel_gp:\n", "        if which_model_parallel=='col':\n", "            # A is the transpose of master_weight, so slicing rows of master_weight slices columns of A\n", "            temp=master_weight[gp::world_size].T\n", "            if cnt < world_size -1:\n", "                print(\"A{}=\".format(str(cnt)), temp.size(), end = ',')\n", "            else:\n", "                print(\"A{}=\".format(str(cnt)), temp.size())\n", "        elif which_model_parallel =='row':\n", "            # slice the rows of A = master_weight.T\n", "            temp=master_weight.T\n", "            temp=temp[gp::world_size]\n", "            if cnt < world_size -1:\n", "                print(\"A{}=\".format(str(cnt)), temp.size(),',')\n", "            else:\n", "                print(\"A{}=\".format(str(cnt)), temp.size())\n", "        else:\n", "            raise ValueError(\"set which_model_parallel to 'col' or 'row'\")\n", "        cnt+=1\n", "        my_weight_list.append(temp)\n", "\n", "    print(\" ]\")\n", "    print(len(my_weight_list))\n", "    return my_weight_list\n", "\n", "def m_initialize_affine_weight_cpu(weight, output_size, input_size,\n", "                                   per_partition_size, partition_dim,\n", "                                   init_method, stride=1,\n", "                                   return_master_weight=False):\n", "    \"\"\"Initialize affine weight for model parallel.\n", "\n", "    Build the master weight on all processes and scatter\n", "    the relevant chunk.\"\"\"\n", "    params_dtype = torch.float\n", "    # Initialize master weight\n", "    master_weight = torch.empty(output_size, input_size,\n", "                                dtype=torch.float,\n", "                                requires_grad=False)\n", "    init_method(master_weight)  # fill the master weight, as the original Megatron code does\n", "    master_weight = master_weight.to(dtype=params_dtype)\n", "    # Split and copy. weight_list mirrors Megatron's contiguous split; this demo\n", "    # inspects the strided partitions returned by get_weight_list instead.\n", "    per_partition_per_stride_size = divide(per_partition_size, stride)\n", "    print(\"per_partition_per_stride_size \",per_partition_per_stride_size)\n", "    weight_list = torch.split(master_weight, per_partition_per_stride_size,\n", "                              dim=partition_dim)\n", "\n", "    #print(\"weight_list\", [wl.size() for wl in weight_list] , len(weight_list))\n", "    #print(\"----\"*5)\n", "    tensor_model_parallel_gp=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]\n", "    my_weight_list = get_weight_list(master_weight,tensor_model_parallel_gp)\n", "\n", "    with torch.no_grad():\n", "        # Single-process demo: keep only partition 0 (the slice rank 0 would hold),\n", "        # transposed back to the (output, input) layout of weight, rather than\n", "        # concatenating all world_size partitions into one resized weight.\n", "        weight.copy_(my_weight_list[0].T)\n", "    if return_master_weight:\n", "        return master_weight\n", "    return None" ] },
{ "cell_type": "markdown", "id": "divine-morris", "metadata": {}, "source": [ "## Peek inside Column Parallel Class" ] },
{ "cell_type": "code", "execution_count": 8, "id": "sound-privacy", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "this is how A is sliced column-wise ...\n", "\n", "per_partition_per_stride_size 32\n", "A = [\n", "A0= torch.Size([1024, 32]),A1= torch.Size([1024, 32]),A2= torch.Size([1024, 32]),A3= torch.Size([1024, 32]),A4= torch.Size([1024, 32]),A5= torch.Size([1024, 32]),A6= torch.Size([1024, 32]),A7= torch.Size([1024, 32]),A8= torch.Size([1024, 32]),A9= torch.Size([1024, 32]),A10= torch.Size([1024, 32]),A11= torch.Size([1024, 32]),A12= torch.Size([1024, 32]),A13= torch.Size([1024, 32]),A14= torch.Size([1024, 32]),A15= torch.Size([1024, 32])\n", " ]\n", "16\n" ] } ], "source": [ "tensor_model_parallel_size= 4 # informational only: the demo classes split by world_size (16)\n", "pipeline_model_parallel_size= 2\n", "input_size = 1024 # 1024 rows\n", "output_size = 512 # 512 columns\n", "which_model_parallel='col'\n", "print(\"this is how A is sliced column-wise ...\\n\")\n", "testCol=myColumnParallelLinear(input_size, output_size, bias=True, gather_output=True,\n", "                               init_method=init.xavier_normal_, stride=1,\n", "                               keep_master_weight_for_test=False,\n", "                               skip_bias_add=False)" ] },
{ "cell_type": "code", "execution_count": 9, "id": "continental-quarter", "metadata": {}, "outputs": [], "source": [ "# 16 partitions x 32 columns each = output_size (512)\n", "per_partition_per_stride_size=32\n", "assert 16* per_partition_per_stride_size == 512" ] },
{ "cell_type": "code", "execution_count": 10, "id": "integral-cause", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "__main__.myColumnParallelLinear" ]
}, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(testCol)" ] },
{ "cell_type": "code", "execution_count": 11, "id": "impressive-watts", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1024, 512)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "testCol.input_size, testCol.output_size" ] },
{ "cell_type": "markdown", "id": "explicit-mother", "metadata": {}, "source": [ "----------------------------------------------------------------------\n", "## Megatron Row Parallel\n", "[RowParallel reference](https://github.com/NVIDIA/Megatron-LM/blob/90e0a0dd08159e1c95f4f9d99bb8687f327d36c3/megatron/mpu/layers.py#L294)\n", "![RowParallel](./Megatron-LM/pics/RowParallel.JPG)" ] },
{ "cell_type": "code", "execution_count": 53, "id": "owned-pixel", "metadata": {}, "outputs": [], "source": [ "class myRowParallelLinear(torch.nn.Module):\n", "    \"\"\"Linear layer with row parallelism.\n", "\n", "    The linear layer is defined as Y = XA + b. A is parallelized along\n", "    its first dimension and X along its second dimension as:\n", "               -   -\n", "              | A_1 |\n", "              | .   |\n", "          A = | .   |        X = [X_1, ..., X_p]\n", "              | .   |\n", "              | A_p |\n", "               -   -\n", "    Arguments:\n", "        input_size: first dimension of matrix A.\n", "        output_size: second dimension of matrix A.\n", "        bias: If true, add bias. Note that bias is not parallelized.\n", "        input_is_parallel: If true, we assume that the input is already\n", "                           split across the GPUs and we do not split\n", "                           again.\n", "        init_method: method to initialize weights. Note that bias is always set\n", "                     to zero.\n", "        stride: For the strided linear layers.\n", "        keep_master_weight_for_test: This was added for testing and should be\n", "                                     set to False. It returns the master weights\n", "                                     used for initialization.\n", "        skip_bias_add: This was added to enable performance optimizations where bias\n", "                       can be fused with other elementwise operations. We skip\n", "                       adding bias but instead return it.\n", "    \"\"\"\n", "\n", "    def __init__(self, input_size, output_size, bias=True,\n", "                 input_is_parallel=False,\n", "                 init_method=init.xavier_normal_, stride=1,\n", "                 keep_master_weight_for_test=False,\n", "                 skip_bias_add=False):\n", "        super(myRowParallelLinear, self).__init__()\n", "\n", "        # Keep input parameters\n", "        self.input_size = input_size\n", "        self.output_size = output_size\n", "        self.input_is_parallel = input_is_parallel\n", "        # Divide the weight matrix along the last dimension\n", "        # (world_size stands in for the tensor-model-parallel size here).\n", "        self.input_size_per_partition = divide(input_size, world_size)\n", "        self.skip_bias_add = skip_bias_add\n", "        print(\"input_size_per_partition \", self.input_size_per_partition)\n", "\n", "        # Parameters.\n", "        # Note: torch.nn.functional.linear performs XA^T + b and as a result\n", "        # we allocate the transpose.\n", "        # Initialize weight.\n", "        use_cpu_initialization=True # hard coded to use cpu\n", "        params_dtype = torch.float # skipping need of args\n", "        if use_cpu_initialization:\n", "            self.weight = Parameter(torch.empty(self.output_size,\n", "                                                self.input_size_per_partition,\n", "                                                dtype=params_dtype))\n", "            self.master_weight = m_initialize_affine_weight_cpu(\n", "                self.weight, self.output_size, self.input_size,\n", "                self.input_size_per_partition, 1, init_method,\n", "                stride=stride, return_master_weight=keep_master_weight_for_test)\n", "        else:\n", "            self.weight = Parameter(torch.empty(\n", "                self.output_size, self.input_size_per_partition,\n", "                device=torch.cuda.current_device(), dtype=params_dtype))\n", "            _initialize_affine_weight_gpu(self.weight, init_method,\n", "                                          partition_dim=1, stride=stride)\n", "        if bias:\n", "            if use_cpu_initialization:\n", "                self.bias = Parameter(torch.empty(self.output_size,\n", "                                                  dtype=params_dtype))\n", "            else:\n", "                self.bias = Parameter(torch.empty(\n", "                    self.output_size, device=torch.cuda.current_device(),\n", "                    dtype=params_dtype))\n", "            # Always initialize bias to zero.\n", "            with torch.no_grad():\n", "                self.bias.zero_()\n", "        else:\n", "            self.register_parameter('bias', None)\n", "\n", "    def forward(self, input_):\n", "        # Set up backprop all-reduce.\n", "        if self.input_is_parallel:\n", "            input_parallel = input_\n", "        else:\n", "            input_parallel = scatter_to_tensor_model_parallel_region(input_)\n", "        # Matrix multiply.\n", "        output_parallel = F.linear(input_parallel, self.weight)\n", "        # All-reduce across all the partitions.\n", "        output_ = reduce_from_tensor_model_parallel_region(output_parallel)\n", "        if not self.skip_bias_add:\n", "            output = output_ + self.bias if self.bias is not None else output_\n", "            output_bias = None\n", "        else:\n", "            output = output_\n", "            output_bias = self.bias\n", "        return output, output_bias" ] },
{ "cell_type": "code", "execution_count": 74, "id": "sudden-tokyo", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "this is how A is sliced row-wise ...\n", "\n", "input_size_per_partition 64\n", "per_partition_per_stride_size 64\n", "A = [\n", "A0= torch.Size([64, 512]) ,\n", "A1= torch.Size([64, 512]) ,\n", "A2= torch.Size([64, 512]) ,\n", "A3= torch.Size([64, 512]) ,\n", "A4= torch.Size([64, 512]) ,\n", "A5= torch.Size([64, 512]) ,\n", "A6= torch.Size([64, 512]) ,\n", "A7= torch.Size([64, 512]) ,\n", "A8= torch.Size([64, 512]) ,\n", "A9= torch.Size([64, 512]) ,\n", "A10= torch.Size([64, 512]) ,\n", "A11= torch.Size([64, 512]) ,\n", "A12= torch.Size([64, 512]) ,\n", "A13= torch.Size([64, 512]) ,\n", "A14= torch.Size([64, 512]) ,\n", "A15= torch.Size([64, 512])\n", " ]\n", "16\n" ] } ], "source": [ "tensor_model_parallel_size= 4 # informational only: the demo classes split by world_size (16)\n", "pipeline_model_parallel_size= 2\n", "input_size = 1024 # first dimension of the matrix\n", "output_size = 512 # 2nd dimension of the matrix\n", "print(\"this is how A is sliced row-wise ...\\n\")\n", "which_model_parallel='row'\n", "testRow=myRowParallelLinear(input_size,output_size, bias=True,\n", "                            input_is_parallel=False,\n", "                            init_method=init.xavier_normal_, stride=1,\n", "                            keep_master_weight_for_test=False,\n", "                            skip_bias_add=False)" ] },
{ "cell_type": "code", "execution_count": 72, "id": "prompt-victoria", "metadata": {}, "outputs": [], "source": [ "# 16 partitions x 64 rows each = input_size (1024)\n", "per_partition_per_stride_size=64\n", "assert 16* per_partition_per_stride_size == 1024" ] },
{ "cell_type": "code", "execution_count": 58, "id": "mysterious-bishop", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1024, 512)" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "testRow.input_size, testRow.output_size" ] },
{ "cell_type": "markdown", "id": "effective-premises", "metadata": {}, "source": [ "---\n", "## Up Next\n", "[About GPT's tokenizer](./Day2-3_GPT_vocab_merge_files.ipynb)\n", "## Back To Start Menu\n", "[start menu](../Start_Here.ipynb)" ] },
{ "cell_type": "markdown", "id": "surprising-utilization", "metadata": {}, "source": [ "-----\n", "\n", "\n", "## Licensing\n", "\n", "This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }