import re
import pickle
from collections import Counter

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import ResNet50_Weights
import gradio as gr
class Vocabulary:
    """Word-level vocabulary built from caption text with a frequency threshold."""

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        # self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        text = text.lower()
        tokens = re.findall(r"\w+", text)
        return tokens

    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            frequencies.update(tokens)
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        tokens = self.tokenizer(text)
        numericalized = []
        for token in tokens:
            if token in self.stoi:
                numericalized.append(self.stoi[token])
            else:
                # Map out-of-vocabulary words to the "unk" token
                # (the "<unk>" key does not exist in this vocabulary).
                numericalized.append(self.stoi["unk"])
        return numericalized
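# For reference, a hypothetical sketch of how vocab.pkl is assumed to have been
# built and saved in train.py (the variable names below are illustrative):
#
#     vocab = Vocabulary(freq_threshold=5)
#     vocab.build_vocabulary(all_training_captions)  # list of caption strings
#     with open("vocab.pkl", "wb") as f:
#         pickle.dump(vocab, f)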
# -----------------------------------------------------------------
# 1. CONFIGURATION & VOCABULARY
# -----------------------------------------------------------------
# These values must match the ones used in train.py.
EMBED_DIM = 256
HIDDEN_DIM = 512
MAX_SEQ_LENGTH = 25
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = "cpu"

# Where you saved your model in train.py
# MODEL_SAVE_PATH = "best_checkpoint.pth"
MODEL_SAVE_PATH = "final_model.pth"

# Load the vocabulary built during training.
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

vocab_size = len(vocab)
print(f"Loaded vocabulary with {vocab_size} tokens.")
# -----------------------------------------------------------------
# 2. MODEL (must match the structure in train.py)
# -----------------------------------------------------------------
class ResNetEncoder(nn.Module):
    """ResNet-50 backbone that maps an image to a single embedding vector."""

    def __init__(self, embed_dim):
        super().__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad = True
        # Drop the final classification layer; keep the pooled 2048-d features.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        # The CNN backbone runs under no_grad; only fc and batch_norm vary here.
        with torch.no_grad():
            features = self.resnet(images)  # (batch_size, 2048, 1, 1)
        features = features.view(features.size(0), -1)
        features = self.fc(features)
        features = self.batch_norm(features)
        return features
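# Illustrative shape check (not part of the app): a batch of two 224x224 RGB
# images should map to a (2, EMBED_DIM) feature tensor.
#
#     dummy = torch.randn(2, 3, 224, 224)
#     print(ResNetEncoder(EMBED_DIM).eval()(dummy).shape)  # torch.Size([2, 256])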
class DecoderLSTM(nn.Module):
    """LSTM decoder that generates a caption conditioned on the image features."""

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions, states):
        # At inference time the full sequence generated so far is fed back in.
        emb = self.embedding(captions)
        # Prepend the image features as the first step of the LSTM input.
        features = features.unsqueeze(1)
        lstm_input = torch.cat((features, emb), dim=1)
        outputs, returned_states = self.lstm(lstm_input, states)
        logits = self.fc(outputs)
        return logits, returned_states

    def generate(self, features, max_len=20):
        """
        Greedy generation from the features as initial context.
        """
        states = None
        generated_captions = []
        start_idx = 1  # startofseq
        end_idx = 2    # endofseq
        inputs = features
        current_tokens = [start_idx]
        for _ in range(max_len):
            input_tokens = torch.LongTensor(current_tokens).to(features.device).unsqueeze(0)
            logits, states = self.forward(inputs, input_tokens, states)
            logits = logits.contiguous().view(-1, logits.size(-1))
            predicted = logits.argmax(dim=1)[-1].item()
            generated_captions.append(predicted)
            current_tokens.append(predicted)
            # Stop as soon as the end-of-sequence token is produced.
            if predicted == end_idx:
                break
        return generated_captions
class ImageCaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=MAX_SEQ_LENGTH):
        features = self.encoder(images)
        return self.decoder.generate(features, max_len=max_len)
# -----------------------------------------------------------------
# 3. LOAD THE TRAINED MODEL
# -----------------------------------------------------------------
def load_trained_model():
    encoder = ResNetEncoder(embed_dim=EMBED_DIM)
    decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
    model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
    # Load weights from disk. The checkpoint saved by train.py is expected to be
    # a dict holding the weights under the "model_state_dict" key.
    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model


model = load_trained_model()
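# Note: if train.py instead saved the weights directly with
# torch.save(model.state_dict(), MODEL_SAVE_PATH), the loading line above would
# simply be model.load_state_dict(checkpoint), without the "model_state_dict" key.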
# -----------------------------------------------------------------
# 4. INFERENCE FUNCTION (FOR GRADIO)
# -----------------------------------------------------------------
# Standard ImageNet preprocessing, matching the pretrained ResNet-50 backbone.
transform_inference = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
def generate_caption_for_image(img):
    """
    Gradio callback: takes a PIL image, returns a string caption.
    """
    pil_img = img.convert("RGB")
    img_tensor = transform_inference(pil_img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # generate() returns a flat list of token indices for the single input image.
        output_indices = model.generate(img_tensor, max_len=MAX_SEQ_LENGTH)

    result_words = []
    end_token_idx = vocab.stoi["endofseq"]
    for idx in output_indices:
        if idx == end_token_idx:
            break
        word = vocab.itos.get(idx, "unk")
        # Skip special tokens in the final output.
        if word not in ["startofseq", "pad", "endofseq"]:
            result_words.append(word)
    return " ".join(result_words)
# -----------------------------------------------------------------
# 5. BUILD GRADIO INTERFACE
# -----------------------------------------------------------------
def main():
    iface = gr.Interface(
        fn=generate_caption_for_image,
        inputs=gr.Image(type="pil"),
        outputs="text",
        title="Image Captioning (ResNet + LSTM)",
        description="Upload an image to get a generated caption from the trained model.",
    )
    iface.launch(share=True)


if __name__ == "__main__":
    print("Loaded model. Starting Gradio interface...")
    main()