# app.py

import pickle
import re
from collections import Counter

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet50_Weights
import gradio as gr
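
# Note: the ResNet50_Weights enum requires torchvision >= 0.13.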

class Vocabulary:
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        # Token names must match train.py; this project uses bare names
        # rather than the more common bracketed forms:
        # self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        text = text.lower()
        tokens = re.findall(r"\w+", text)
        return tokens

    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            frequencies.update(tokens)
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        tokens = self.tokenizer(text)
        # Fall back to the "unk" index for out-of-vocabulary tokens
        # ("<unk>" is not a key in stoi and would raise a KeyError).
        return [self.stoi.get(token, self.stoi["unk"]) for token in tokens]
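
# Illustrative usage (indices depend on the training corpus, so the
# values below are hypothetical):
#   vocab.numericalize("a dog runs")  ->  e.g. [12, 845, 301]
#   vocab.numericalize("zzyzx")       ->  [3]   (out-of-vocabulary -> "unk")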

# -----------------------------------------------------------------
# 1. CONFIG (these must match train.py)
# -----------------------------------------------------------------
EMBED_DIM = 256
HIDDEN_DIM = 512
MAX_SEQ_LENGTH = 25
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = "cpu"

# Where train.py saved the model
# MODEL_SAVE_PATH = "best_checkpoint.pth"
MODEL_SAVE_PATH = "final_model.pth"

with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)
vocab_size = len(vocab)
print(f"Loaded vocabulary with {vocab_size} entries")
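
# Assumes vocab.pkl holds a Vocabulary pickled by train.py: pickle
# re-creates the object against the class defined in this module, so the
# Vocabulary definition above has to match the one used at training time.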

# -----------------------------------------------------------------
# 2. Model (must match structure in train.py)
# -----------------------------------------------------------------
class ResNetEncoder(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad = True
        # Drop the final classification layer, keep the pooled features.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        # no_grad keeps the backbone gradient-free; only fc and
        # batch_norm would receive gradients during training.
        with torch.no_grad():
            features = self.resnet(images)  # (batch_size, 2048, 1, 1)
        features = features.view(features.size(0), -1)
        features = self.fc(features)
        features = self.batch_norm(features)
        return features
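
# Shape trace for a single 224x224 RGB image:
#   (1, 3, 224, 224) -> backbone -> (1, 2048, 1, 1)
#   -> flatten -> (1, 2048) -> fc -> (1, EMBED_DIM) -> batch_norm -> (1, EMBED_DIM)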

class DecoderLSTM(nn.Module):
    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions, states):
        # Prepend the image features as the first "token" of the sequence.
        emb = self.embedding(captions)                  # (batch, seq_len, embed_dim)
        features = features.unsqueeze(1)                # (batch, 1, embed_dim)
        lstm_input = torch.cat((features, emb), dim=1)  # (batch, seq_len + 1, embed_dim)
        outputs, returned_states = self.lstm(lstm_input, states)
        logits = self.fc(outputs)                       # (batch, seq_len + 1, vocab_size)
        return logits, returned_states
    def generate(self, features, max_len=20):
        """
        Greedy generation with the image features as initial context.
        Each step re-runs the full prefix through the LSTM with a fresh
        state, mirroring how forward() consumes sequences during
        training; decoding stops at the end token or after max_len steps.
        """
        generated_captions = []
        start_idx = 1  # startofseq
        end_idx = 2    # endofseq
        current_tokens = [start_idx]
        for _ in range(max_len):
            input_tokens = (
                torch.LongTensor(current_tokens).to(features.device).unsqueeze(0)
            )
            logits, _ = self.forward(features, input_tokens, None)
            # Take the distribution over the last position only.
            predicted = logits[0, -1].argmax().item()
            generated_captions.append(predicted)
            current_tokens.append(predicted)
            if predicted == end_idx:
                break
        return generated_captions

class ImageCaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=MAX_SEQ_LENGTH):
        features = self.encoder(images)
        return self.decoder.generate(features, max_len=max_len)
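
# Note: the decoder's generate() builds one token sequence at a time, so
# this pipeline assumes a batch of exactly one image. The return value is
# a flat Python list of token indices, ending with 2 ("endofseq") when
# decoding terminates before max_len.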

# -----------------------------------------------------------------
# 3. LOAD THE TRAINED MODEL
# -----------------------------------------------------------------
def load_trained_model():
    encoder = ResNetEncoder(embed_dim=EMBED_DIM)
    decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
    model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
    # train.py saves a checkpoint dict; the weights live under
    # the "model_state_dict" key.
    state_dict = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict["model_state_dict"])
    model.eval()
    return model

model = load_trained_model()

# -----------------------------------------------------------------
# 4. INFERENCE FUNCTION (FOR GRADIO)
# -----------------------------------------------------------------
transform_inference = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet statistics, matching the pretrained backbone.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def generate_caption_for_image(img):
    """
    Gradio callback: takes a PIL image, returns a string caption.
    """
    pil_img = img.convert("RGB")
    img_tensor = transform_inference(pil_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output_indices = model.generate(img_tensor, max_len=MAX_SEQ_LENGTH)
    # output_indices is a flat list of token indices for the single image.
    result_words = []
    end_token_idx = vocab.stoi["endofseq"]
    for idx in output_indices:
        if idx == end_token_idx:
            break
        word = vocab.itos.get(idx, "unk")
        # skip special tokens in the final output
        if word not in ["startofseq", "pad", "endofseq"]:
            result_words.append(word)
    return " ".join(result_words)
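
# Quick sanity check without the UI (the file name is illustrative):
#   from PIL import Image
#   print(generate_caption_for_image(Image.open("example.jpg")))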

# -----------------------------------------------------------------
# 5. BUILD GRADIO INTERFACE
# -----------------------------------------------------------------
def main():
    iface = gr.Interface(
        fn=generate_caption_for_image,
        inputs=gr.Image(type="pil"),
        outputs="text",
        title="Image Captioning (ResNet + LSTM)",
        description="Upload an image to get a generated caption from the trained model.",
    )
    # share=True also creates a temporary public gradio.live link.
    iface.launch(share=True)

if __name__ == "__main__":
    print("Loaded model. Starting Gradio interface...")
    main()
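
# Run with: python app.py
# Expects final_model.pth and vocab.pkl next to this file, per the
# config section above.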