2024年2月5日 14:46 by wst
算法实践：图生文（图像描述 / Image Captioning）代码示例：
import torch
import webdataset as wds
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
# Local directory that receives the processor files and model weights written
# by save_pretrained() after they are loaded from the Hugging Face Hub.
SAVE_PATH: str = "/home/wst/models"
def gen_caption2(img, processor, model, *, max_new_tokens=20):
    """Generate a caption for a single image with a BLIP-2 model.

    Args:
        img: The image to caption, in any form ``processor`` accepts
            (in this script, a PIL image decoded by WebDataset).
        processor: Processor that turns the image into model inputs and
            decodes generated token ids back to text.
        model: The captioning model whose ``generate`` method is called.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first decoded caption, with special tokens removed and
        surrounding whitespace stripped.
    """
    # Put the inputs wherever the model actually lives, in the model's dtype.
    # Guessing "cuda" independently of the model fails with a device mismatch
    # when the model was left on the CPU, and a hard-coded float16 cast would
    # break a model loaded in any other precision.
    inputs = processor(images=img, return_tensors="pt").to(model.device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # batch_decode returns one string per sequence; a single image yields one.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
# Hugging Face Hub checkpoint used for captioning (previously repeated inline).
MODEL_ID = "Salesforce/blip2-opt-2.7b"
# Device the model runs on: the GPU when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
# Keep a local copy of the processor files and weights under SAVE_PATH.
processor.save_pretrained(SAVE_PATH)
model.save_pretrained(SAVE_PATH)
# from_pretrained without a device_map leaves the weights on the CPU. Move the
# model explicitly so it sits on the same device as inputs placed on CUDA;
# otherwise generate() fails with a device mismatch on GPU machines.
model.to(device)
# NOTE(review): float16 inference on CPU is not supported by every op in older
# PyTorch releases; confirm on the target environment if no GPU is present.

# Caption every sample of a single local WebDataset shard.
url = "00001.tar"
dataset = wds.WebDataset(url).decode('pil')
for sample in dataset:
    # Image stored under the "jpg" extension, decoded to PIL by .decode('pil').
    img = sample['jpg']
    txt = gen_caption2(img, processor, model)
    print("res:", txt)