#!/usr/bin/env python
# coding: utf-8
# %%
# ## Introduction
#
# This example demonstrates how to implement an autoregressive language model
# using a miniature version of the GPT model.
# The model consists of a single Transformer block with causal masking
# in its attention layer.
#
#
# **References:**
#
# - [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)
# - [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)
# - [GPT-3](https://arxiv.org/abs/2005.14165)
# ## Setup
import argparse
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
import numpy as np
import re
import string
import random
import os
import sys
import time
# Disable TensorFlow warnings, info messages, etc.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


### Prepare the data for word-level language modelling
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
    args = parser.parse_args()
    return args
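

# Example invocation (illustrative, not part of the original script):
#   python GPT.py --batch-size 128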


# ## Implement a Transformer block as a layer
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """
    Mask the upper half of the dot product matrix in self attention.
    This prevents flow of information from future tokens to current token.
    1's in the lower triangle, counting from the lower right corner.
    """
    i = tf.range(n_dest)[:, None]
    j = tf.range(n_src)
    m = i >= j - n_src + n_dest
    mask = tf.cast(m, dtype)
    mask = tf.reshape(mask, [1, n_dest, n_src])
    mult = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
    )
    return tf.tile(mask, mult)
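

# Illustrative helper (a minimal sketch, not part of the original script):
# materialise the mask for a single 4-token sequence to see the pattern.
# Token i may only attend to tokens 0..i, so the matrix is lower-triangular:
# [[[1 0 0 0]
#   [1 1 0 0]
#   [1 1 1 0]
#   [1 1 1 1]]]
def _demo_causal_mask():
    return causal_attention_mask(1, 4, 4, tf.int32)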


class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TransformerBlock, self).__init__()
        self.att = layers.MultiHeadAttention(num_heads, embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
        attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attention_output = self.dropout1(attention_output)
        out1 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)
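

# Minimal usage sketch (an assumption, not part of the original script): a
# TransformerBlock maps a (batch, seq_len, embed_dim) tensor to a tensor of
# the same shape, attending only to past positions.
def _demo_transformer_block():
    block = TransformerBlock(embed_dim=256, num_heads=2, ff_dim=256)
    dummy = tf.random.uniform((4, 80, 256))  # (batch, seq_len, embed_dim)
    return block(dummy)  # -> shape (4, 80, 256)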


# ## Implement an embedding layer
#
# Create two separate embedding layers: one for tokens and one for token index
# (positions).
class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super(TokenAndPositionEmbedding, self).__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions
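

# Minimal usage sketch (an assumption, not part of the original script):
# integer token ids of shape (batch, seq_len) become embeddings of shape
# (batch, seq_len, embed_dim); the position embedding is broadcast over the
# batch dimension.
def _demo_token_position_embedding():
    emb = TokenAndPositionEmbedding(maxlen=80, vocab_size=20000, embed_dim=256)
    token_ids = tf.zeros((4, 80), dtype=tf.int32)  # (batch, seq_len)
    return emb(token_ids)  # -> shape (4, 80, 256)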


def create_model():
    inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)
    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
    x = embedding_layer(inputs)
    transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
    x = transformer_block(x)
    outputs = layers.Dense(vocab_size)(x)
    model = keras.Model(inputs=inputs, outputs=[outputs, x])
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        "adam", loss=[loss_fn, None]
    )  # No loss is applied to the second output (the transformer block's hidden states)
    return model
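

# Illustrative sketch (an assumption, not part of the original script): greedy
# next-token decoding with a trained model. `model` is the model returned by
# create_model(), while `vectorize_layer` and `vocab` are the TextVectorization
# layer and vocabulary built in main() below; `maxlen` is the global set there.
def greedy_generate(model, vectorize_layer, vocab, prompt, num_tokens=20):
    # Tokenize the prompt and drop padding zeros.
    tokens = [int(t) for t in vectorize_layer([prompt])[0].numpy() if t != 0]
    for _ in range(num_tokens):
        window = tokens[-maxlen:]
        padded = window + [0] * (maxlen - len(window))
        logits, _ = model.predict(np.array([padded], dtype="int32"), verbose=0)
        # Read the logits at the last real position and pick the argmax token.
        next_id = int(np.argmax(logits[0][len(window) - 1]))
        tokens.append(next_id)
    # Map token ids back to words; ids outside the vocabulary map to [UNK].
    return " ".join(vocab[t] if t < len(vocab) else "[UNK]" for t in tokens)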


def main():
    args = parse_args()
    global g_args
    g_args = args
    batch_size = args.batch_size
    print("Batch size: " + str(batch_size))

    ### Implement the miniature GPT model
    global vocab_size
    vocab_size = 20000  # Only consider the top 20k words
    global maxlen
    maxlen = 80  # Max sequence size
    global embed_dim
    embed_dim = 256  # Embedding size for each token
    global num_heads
    num_heads = 2  # Number of attention heads
    global feed_forward_dim
    feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer

    # The dataset is split across plain-text files in the directories below.
    # Create a list of all files.
    filenames = []
    directories = [
        "/workspace/python/source_code/Data/wikitext-2"
    ]
    for directory in directories:
        for f in os.listdir(directory):
            filenames.append(os.path.join(directory, f))
    # print(f"{len(filenames)} files")

    # Create a dataset from text files
    random.shuffle(filenames)
    text_ds = tf.data.TextLineDataset(filenames)
    text_ds = text_ds.shuffle(buffer_size=256)
    text_ds = text_ds.batch(batch_size)

    def custom_standardization(input_string):
        """ Remove html line-break tags and handle punctuation """
        lowercased = tf.strings.lower(input_string)
        stripped_html = tf.strings.regex_replace(lowercased, "<br />", " ")
        return tf.strings.regex_replace(
            stripped_html, f"([{string.punctuation}])", r" \1"
        )
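
    # Illustrative example (not part of the original script): punctuation is
    # split off into separate tokens and "<br />" tags are removed, e.g.
    #   custom_standardization(tf.constant("Hello, world!<br />"))
    #   # -> b"hello , world ! "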

    # Create a vectorization layer and adapt it to the text
    vectorize_layer = TextVectorization(
        standardize=custom_standardization,
        max_tokens=vocab_size - 1,
        output_mode="int",
        output_sequence_length=maxlen + 1,
    )
    vectorize_layer.adapt(text_ds)
    vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices

    def prepare_lm_inputs_labels(text):
        """
        Shift word sequences by 1 position so that the target for position (i) is
        word at position (i+1). The model will use all words up till position (i)
        to predict the next word.
        """
        text = tf.expand_dims(text, -1)
        tokenized_sentences = vectorize_layer(text)
        x = tokenized_sentences[:, :-1]
        y = tokenized_sentences[:, 1:]
        return x, y
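
    # Worked example (illustrative, not part of the original script): with
    # output_sequence_length = maxlen + 1, each line is tokenized to maxlen + 1
    # ids, so after the shift both x and y hold exactly maxlen tokens, e.g.
    #   tokenized: [12, 7, 95, 4, 0, ..., 0]   (length maxlen + 1)
    #   x = tokenized[:-1]                     (length maxlen)
    #   y = tokenized[1:]                      (length maxlen)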

    text_ds = text_ds.map(prepare_lm_inputs_labels)
    text_ds = text_ds.prefetch(tf.data.experimental.AUTOTUNE)

    # ## Implement a Keras callback for reporting epoch time and throughput
    # %%
    class PrintLR(tf.keras.callbacks.Callback):
        def __init__(self, total_images=0):
            super(PrintLR, self).__init__()
            self.total_images = total_images

        def on_epoch_begin(self, epoch, logs=None):
            self.epoch_start_time = time.time()

        def on_epoch_end(self, epoch, logs=None):
            epoch_time = time.time() - self.epoch_start_time
            print('Epoch time : {}'.format(epoch_time))
            images_per_sec = round(self.total_images / epoch_time, 2)
            print('Units/sec: {}'.format(images_per_sec))

    model = create_model()
    model.fit(text_ds, verbose=1, epochs=3, callbacks=[PrintLR(total_images=44880)])


if __name__ == "__main__":
    main()