How To Generate Meaningful Sentences Using a T5 Transformer


By Vatsal Saglani, Machine Learning Engineer at Quinnox




In the blog "Generating storylines using a T5 Transformer", we saw how to fine-tune a Sequence2Sequence (Text-To-Text) Transformer (T5) to generate storylines/plots from inputs like genre, director, cast, and ethnicity. In this blog, we will see how to use that trained T5 model for inference. Later, we will also see how to deploy it using gunicorn and Flask.

 

How to do Model Inference?

 

  • Let’s set up the script with the imports
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
  • Set the SEED value and load the model and tokenizer
torch.manual_seed(3007)
model = T5ForConditionalGeneration.from_pretrained('./outputs/model_files')
tokenizer = T5Tokenizer.from_pretrained('./outputs/model_files')
  • Use the model.generate function to generate sequences
text = "generate plot for genre: horror"
input_ids = tokenizer.encode(text, return_tensors="pt")
# With the default settings (do_sample=False, num_beams=1), generate() decodes greedily
greedyOp = model.generate(input_ids, max_length=100)
print(tokenizer.decode(greedyOp[0], skip_special_tokens=True))

Note: Read this excellent Hugging Face blog post on how to use different decoding strategies for text generation with Transformers.
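To make the difference between decoding strategies a little more concrete, here is a toy sketch in plain Python. It does not use the T5 model at all: TOY_LM is a hypothetical hand-written table of next-token probabilities. Greedy decoding picks the single most likely token at every step, while beam search keeps the `num_beams` best partial sequences, so it can find a sequence with a higher overall probability that greedy misses. (In `model.generate`, beam search is turned on by passing `num_beams` greater than 1.)

```python
import math

# Hypothetical toy "language model": maps a prefix (tuple of tokens) to
# next-token probabilities. Not a real model; for illustration only.
TOY_LM = {
    (): {"the": 0.5, "a": 0.4, "one": 0.1},
    ("the",): {"dog": 0.4, "cat": 0.3, "house": 0.3},
    ("a",): {"ghost": 0.9, "car": 0.1},
    ("one",): {"night": 1.0},
}

def next_probs(prefix):
    # Prefixes missing from the table end the sequence.
    return TOY_LM.get(tuple(prefix), {})

def greedy_decode(max_len=2):
    seq, logp = [], 0.0
    for _ in range(max_len):
        probs = next_probs(seq)
        if not probs:
            break
        # Always take the single most likely next token.
        token, p = max(probs.items(), key=lambda kv: kv[1])
        seq.append(token)
        logp += math.log(p)
    return seq, logp

def beam_search(num_beams=2, max_len=2):
    beams = [([], 0.0)]  # list of (sequence, log-probability)
    for _ in range(max_len):
        candidates = []
        for seq, logp in beams:
            probs = next_probs(seq)
            if not probs:
                candidates.append((seq, logp))  # finished beam is kept as-is
                continue
            for token, p in probs.items():
                candidates.append((seq + [token], logp + math.log(p)))
        # Keep only the num_beams highest-scoring sequences.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams[0]

greedy_seq, greedy_lp = greedy_decode()
beam_seq, beam_lp = beam_search()
print(greedy_seq, round(math.exp(greedy_lp), 2))  # ['the', 'dog'] 0.2
print(beam_seq, round(math.exp(beam_lp), 2))      # ['a', 'ghost'] 0.36
```

Greedy commits to "the" because it is the best first token, but the two-token sequence "a ghost" is more likely overall, which is exactly what the wider search recovers.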

  • Let’s put this in a function
def generateStoryLine(text, seq_len, seq_num):
    '''
    args:
        text: input text e.g. "generate plot for genre: {genre}" or
              "generate plot for genre: {genre} and director: {director}"
        seq_len: max length (in tokens) of the generated text
        seq_num: number of sequences to generate
    '''
    outputDict = dict()
    outputDict["plots"] = {}
    input_ids = tokenizer.encode(text, return_tensors="pt")
    # do_sample=True with top_k/top_p means top-k + nucleus sampling, not beam search
    sampleOp = model.generate(
        input_ids,
        max_length=seq_len,
        do_sample=True,
        top_k=100,
        top_p=0.95,
        num_return_sequences=seq_num
    )
    for ix, sample_op in enumerate(sampleOp):
        # Use the module-level tokenizer: this is a plain function, so there is no `self`
        outputDict["plots"][ix] = tokenizer.decode(sample_op, skip_special_tokens=True)

    return outputDict
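The function above samples rather than searches: with `do_sample=True`, `top_k=100` and `top_p=0.95`, only the most likely tokens survive before one is drawn at random. The pure-Python sketch below illustrates that filtering on a hypothetical toy distribution. It is a simplified illustration of top-k and nucleus (top-p) filtering, not Hugging Face's implementation (which works on logits tensors); `top_k_top_p_filter` and `sample` are illustrative names.

```python
import random

def top_k_top_p_filter(probs, top_k, top_p):
    """Keep the top_k most likely tokens, then the smallest prefix of those
    whose cumulative probability reaches top_p, and renormalise."""
    # Sort tokens from most to least likely and apply the top-k cut.
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    for token, p in ranked:
        kept.append((token, p))
        cumulative += p
        if cumulative >= top_p:
            break  # nucleus reached; drop the remaining tail
    total = sum(p for _, p in kept)
    return {token: p / total for token, p in kept}

def sample(probs, rng):
    # Draw one token in proportion to its (renormalised) probability.
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]

toy = {"murder": 0.5, "love": 0.3, "heist": 0.15, "alien": 0.05}
filtered = top_k_top_p_filter(toy, top_k=3, top_p=0.75)
print({t: round(p, 3) for t, p in filtered.items()})  # {'murder': 0.625, 'love': 0.375}
rng = random.Random(3007)
print(sample(filtered, rng))  # either murder or love, never heist or alien
```

Because the low-probability tail is cut off before sampling, the generated plots stay varied without drifting into very unlikely tokens; this is why calling the function twice with the same prompt gives different plots.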

 

How to deploy this with Flask?

 
A user can provide the inputs in several combinations, and the model has to generate a plot for each of them: only the genre, the genre and cast, or even all four, i.e. genre, director, cast, and ethnicity. For this implementation, I have made the genre mandatory.
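Since the genre is the only required field, the prompt sent to the model depends on which of the optional fields are present. The sketch below shows one table-driven way to pick that prompt. `build_prompt` and `PROMPT_TEMPLATES` are hypothetical names; the template strings mirror the ones used by the predict.py methods later in this post (with "generate" spelled correctly for the cast prompt), and they only help if they match the prompts the model was fine-tuned on.

```python
# Hypothetical mapping: set of optional fields provided -> prompt template.
PROMPT_TEMPLATES = {
    frozenset(): "generate plot for genre: {genre}",
    frozenset({"director"}): "generate plot for genre: {genre} and director: {director}",
    frozenset({"cast"}): "generate plot for genre: {genre} and cast: {cast}",
    frozenset({"ethnicity"}): "generate plot for genre: {genre} and ethnicity: {ethnicity}",
    frozenset({"director", "cast"}):
        "generate plot for genre: {genre} director: {director} cast: {cast}",
    frozenset({"director", "cast", "ethnicity"}):
        "generate plot for genre: {genre} director: {director} cast: {cast} and ethnicity: {ethnicity}",
}

def build_prompt(genre, director=None, cast=None, ethnicity=None):
    if not genre:
        raise ValueError("Genre cannot be empty")
    optional = {"director": director, "cast": cast, "ethnicity": ethnicity}
    # Empty strings and None both count as "not provided".
    provided = frozenset(name for name, value in optional.items() if value)
    template = PROMPT_TEMPLATES.get(provided)
    if template is None:
        raise ValueError(f"Unsupported combination of fields: {sorted(provided)}")
    # str.format ignores keyword arguments the template does not use.
    return template.format(genre=genre, **optional)

print(build_prompt("horror"))                       # generate plot for genre: horror
print(build_prompt("drama", director="Jane Doe"))   # generate plot for genre: drama and director: Jane Doe
```

Keeping the prompts in one table makes it easy to see which combinations are supported (for example, director plus ethnicity without cast has no template) instead of spreading them across a chain of if/elif branches.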

You can check out the link below to see how the API will work.

Movie Plot Generator
I generate vague movie plots on the web (but sometimes they are good). But I can assure you that it will always be…

 

Let’s develop a backend that serves the API used by the demo linked above.

 

Install the requirements

 

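pip install flask flask_cors tqdm rich gunicorn requests

(This assumes torch, transformers and sentencepiece, which T5Tokenizer needs, are already installed from the training step; requests is only used by the test script later.)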

 

Create an app.py file

 

# app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
from predict import PredictionModelObject

app = Flask(__name__)
CORS(app)

print("Loading Model Object")
predictionObject = PredictionModelObject()
print("Loaded Model Object")


@app.route('/api/generatePlot', methods=['POST'])
def gen_plot():
    # silent=True returns None instead of raising on a missing or invalid JSON body
    req = request.get_json(silent=True) or {}
    genre = req.get('genre')
    director = req.get('director')
    cast = req.get('cast')
    ethnicity = req.get('ethnicity')
    num_plots = req.get('num_plots', 1)
    seq_len = req.get('seq_len', 200)

    if not isinstance(num_plots, int) or not isinstance(seq_len, int):
        return jsonify({
            "message": "Number of words in plot and Number of plots must be integers",
            "status": "Fail"
        }), 400

    try:
        plot, status = predictionObject.returnPlot(
            genre=genre,
            director=director,
            cast=cast,
            ethnicity=ethnicity,
            seq_len=seq_len,
            seq_num=num_plots
        )
        if status == 'Pass':
            plot["message"] = "Success!"
            plot["status"] = "Pass"
            return jsonify(plot)

        # On failure, returnPlot returns an error message instead of a plot dict
        return jsonify({"message": plot, "status": status}), 400

    except Exception:
        return jsonify({"message": "Error getting plot for the given input", "status": "Fail"}), 500
  • The main block to run the flask app
if __name__ == "__main__":
    app.run(debug=True, port = 5000)

This script won’t run yet: executing it at this point raises an ImportError (more precisely, a ModuleNotFoundError) because we haven’t created the predict.py script with the PredictionModelObject class.
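One subtlety in the validation inside gen_plot: in Python, bool is a subclass of int, so `isinstance(True, int)` is True and a payload such as `{"num_plots": true}` passes the integer check. The framework-free sketch below applies the same defaults, rejects booleans, and adds upper bounds. `parse_plot_request`, `MAX_PLOTS` and `MAX_SEQ_LEN` are hypothetical names and limits chosen for illustration, not part of the original app.

```python
MAX_PLOTS = 5       # hypothetical cap on the number of plots per request
MAX_SEQ_LEN = 512   # hypothetical cap on max_length (counted in tokens)

def _is_strict_int(value):
    # bool is a subclass of int, so exclude it explicitly.
    return isinstance(value, int) and not isinstance(value, bool)

def parse_plot_request(req):
    """Apply the same defaults as gen_plot and validate the fields.
    Returns (params, error_message); exactly one of the two is None."""
    if not isinstance(req, dict):
        return None, "Request body must be a JSON object"
    params = {
        "genre": req.get("genre"),
        "director": req.get("director"),
        "cast": req.get("cast"),
        "ethnicity": req.get("ethnicity"),
        "seq_num": req.get("num_plots", 1),
        "seq_len": req.get("seq_len", 200),
    }
    if not params["genre"]:
        return None, "Genre cannot be empty"
    if not (_is_strict_int(params["seq_num"]) and _is_strict_int(params["seq_len"])):
        return None, "Number of words in plot and Number of plots must be integers"
    if not (1 <= params["seq_num"] <= MAX_PLOTS and 1 <= params["seq_len"] <= MAX_SEQ_LEN):
        return None, f"num_plots must be 1-{MAX_PLOTS} and seq_len 1-{MAX_SEQ_LEN}"
    return params, None

print(parse_plot_request({"genre": "horror", "num_plots": True}))
```

Inside gen_plot, this could be called as `params, error = parse_plot_request(request.get_json(silent=True))`, returning the error with a 400 status when `error` is not None. Capping `num_plots` and `seq_len` also keeps a single request from tying up a worker with a very long generation.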

 

Create the PredictionModelObject

 

  • Create a predict.py file and import the following
# predict.py
import torch
from rich.console import Console
from transformers import T5Tokenizer, T5ForConditionalGeneration

console = Console(record=True)

# Seed the CPU and (if available) GPU random number generators used for sampling
torch.cuda.manual_seed(3007)
torch.manual_seed(3007)
  • Create the PredictionModelObject class
# predict.py
class PredictionModelObject(object):

    def __init__(self):
        console.log("Model Loading")
        self.model = T5ForConditionalGeneration.from_pretrained('./outputs/model_files')
        self.tokenizer = T5Tokenizer.from_pretrained('./outputs/model_files')
        console.log("Model Loaded")

    def beamSearch(self, text, seq_len, seq_num):
        # Despite the name, this uses top-k / top-p sampling (do_sample=True), not beam search
        outputDict = dict()
        outputDict["plots"] = {}
        input_ids = self.tokenizer.encode(text, return_tensors="pt")
        sampleOp = self.model.generate(
            input_ids,
            max_length=seq_len,
            do_sample=True,
            top_k=100,
            top_p=0.95,
            num_return_sequences=seq_num
        )
        for ix, sample_op in enumerate(sampleOp):
            outputDict["plots"][ix] = self.tokenizer.decode(sample_op, skip_special_tokens=True)
        return outputDict

    def genreToPlot(self, genre, seq_len, seq_num):
        text = f"generate plot for genre: {genre}"
        return self.beamSearch(text, seq_len, seq_num)

    def genreDirectorToPlot(self, genre, director, seq_len, seq_num):
        text = f"generate plot for genre: {genre} and director: {director}"
        return self.beamSearch(text, seq_len, seq_num)

    def genreDirectorCastToPlot(self, genre, director, cast, seq_len, seq_num):
        text = f"generate plot for genre: {genre} director: {director} cast: {cast}"
        return self.beamSearch(text, seq_len, seq_num)

    def genreDirectorCastEthnicityToPlot(self, genre, director, cast, ethnicity, seq_len, seq_num):
        text = f"generate plot for genre: {genre} director: {director} cast: {cast} and ethnicity: {ethnicity}"
        return self.beamSearch(text, seq_len, seq_num)

    def genreCastToPlot(self, genre, cast, seq_len, seq_num):
        text = f"generate plot for genre: {genre} and cast: {cast}"
        return self.beamSearch(text, seq_len, seq_num)

    def genreEthnicityToPlot(self, genre, ethnicity, seq_len, seq_num):
        text = f"generate plot for genre: {genre} and ethnicity: {ethnicity}"
        return self.beamSearch(text, seq_len, seq_num)

    def returnPlot(self, genre, director, cast, ethnicity, seq_len, seq_num):
        console.log('Got genre: ', genre, 'director: ', director, 'cast: ', cast,
                    'seq_len: ', seq_len, 'seq_num: ', seq_num, 'ethnicity: ', ethnicity)
        seq_len = 200 if not seq_len else int(seq_len)
        seq_num = 1 if not seq_num else int(seq_num)

        # Genre is mandatory for every prompt
        if not genre:
            return "Genre cannot be empty", "Fail"

        if not director and not cast and not ethnicity:
            return self.genreToPlot(genre, seq_len, seq_num), "Pass"
        elif director and not cast and not ethnicity:
            return self.genreDirectorToPlot(genre, director, seq_len, seq_num), "Pass"
        elif director and cast and not ethnicity:
            return self.genreDirectorCastToPlot(genre, director, cast, seq_len, seq_num), "Pass"
        elif director and cast and ethnicity:
            return self.genreDirectorCastEthnicityToPlot(genre, director, cast, ethnicity, seq_len, seq_num), "Pass"
        elif cast and not director and not ethnicity:
            return self.genreCastToPlot(genre, cast, seq_len, seq_num), "Pass"
        elif ethnicity and not director and not cast:
            return self.genreEthnicityToPlot(genre, ethnicity, seq_len, seq_num), "Pass"
        else:
            # e.g. director + ethnicity without cast: no prompt exists for this combination
            return "Unsupported combination of inputs", "Fail"

Save the predict.py file and then run app.py in debug mode (app.run(debug=True, port=5000) in the main block takes care of this):

python app.py

Test your API

 

  • Create a test_api.py file and execute
# test_api.py
import requests

url = "http://localhost:5000/api/generatePlot"
payload = {
    "genre": input("Genre: "),
    "director": input("Director: "),
    "cast": input("Cast: "),
    "ethnicity": input("Ethnicity: "),
    "num_plots": int(input("Num Plots: ")),
    "seq_len": int(input("Sequence Length: ")),
}
r = requests.post(url, json=payload)
print(r.json())
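On success, returnPlot returns a dict whose "plots" entry uses integer keys (0, 1, ...). JSON object keys are always strings, so the client receives "0", "1", and so on. The sketch below parses a response body with the standard json module and prints the plots in numeric order; `print_plots` is a hypothetical helper, and the sample body is a placeholder with the same shape as the API's response, not real model output. With test_api.py you could call it as `print_plots(r.text)`.

```python
import json

def print_plots(body):
    data = json.loads(body)
    if data.get("status") != "Pass":
        print("Request failed:", data.get("message"))
        return []
    plots = data["plots"]
    # Keys arrive as strings ("0", "1", ...); sort numerically so "10" comes after "2".
    ordered = [plots[key] for key in sorted(plots, key=int)]
    for i, plot in enumerate(ordered, start=1):
        print(f"Plot {i}: {plot}")
    return ordered

# Placeholder response body with the same structure as the API's JSON.
sample_body = json.dumps({
    "plots": {0: "<first generated plot>", 1: "<second generated plot>"},
    "message": "Success!",
    "status": "Pass",
})
print_plots(sample_body)
```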

 

How to run with gunicorn?

 
Using gunicorn with Flask is straightforward. We already installed gunicorn with the requirements at the start, so now we just need to open a terminal in the folder that contains app.py and run the following command:

gunicorn -k gthread -w 2 -t 40000 --threads 3 -b:5000 app:app

The flags used above mean the following:

  • -k: worker class (type of workers), e.g. gthread, gevent, etc.
  • -w: number of worker processes
  • -t: worker timeout, in seconds
  • --threads: number of threads per worker
  • -b: address to bind to (here :5000, i.e. port 5000)

The app:app argument has the form module:variable, i.e. the app object inside app.py. If your file is named server.py or flask_app.py, app:app becomes server:app or flask_app:app.
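The same settings can also live in a gunicorn configuration file, which is plain Python. The sketch below is an assumed equivalent of the command above using gunicorn's `bind`, `workers`, `worker_class`, `threads` and `timeout` settings. By default gunicorn reads ./gunicorn.conf.py, so with this file next to app.py you could start the server with just gunicorn app:app. Keep in mind that each worker is a separate process that imports app.py, so two workers load two copies of the T5 model into memory.

```python
# gunicorn.conf.py: an assumed equivalent of
#   gunicorn -k gthread -w 2 -t 40000 --threads 3 -b:5000 app:app
bind = "0.0.0.0:5000"     # -b: listen on port 5000 on all IPv4 interfaces
workers = 2               # -w: number of worker processes (each loads the model)
worker_class = "gthread"  # -k: threaded workers
threads = 3               # --threads: threads per worker process
timeout = 40000           # -t: seconds a silent worker may run before it is restarted
```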

 

In Summary

 
In this blog, we saw how to use our previously trained T5 Transformer to generate storylines and how to deploy it using Flask and gunicorn. The post is meant to be easy to follow, so you don’t have to waste time hunting across different platforms for fixes to common issues. I hope you have fun reading and implementing it.

 
Bio: Vatsal Saglani (@saglanivatsal) is a Machine Learning Engineer at Quinnox.

Original. Reposted with permission.
