import json
import pandas as pd

# region.json nests administrative areas two levels deep:
# region_data['districts'] are cities, each carrying its own 'districts' list.
region_data = json.load(open('region.json'))

# Flatten the second level (all districts of all cities) into one list.
districts_data = sum([x['districts'] for x in region_data['districts']], [])

# Build a [name, longitude, latitude] table covering both districts and cities.
city_data = [[x['name'], x['center']['longitude'], x['center']['latitude']] for x in districts_data]
city_data += [[x['name'], x['center']['longitude'], x['center']['latitude']] for x in region_data['districts']]
city_data = pd.DataFrame(city_data)
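For orientation, the lookups above imply region.json has roughly the shape sketched below. The field names are read off the code; the place names and coordinates are illustrative only, and the real file may carry extra keys.

# Hypothetical excerpt of region.json, shown as a Python literal for illustration.
region_json_example = {
    "districts": [  # first level: cities
        {
            "name": "杭州市",
            "center": {"longitude": 120.15, "latitude": 30.28},
            "districts": [  # second level: districts within the city
                {"name": "西湖区", "center": {"longitude": 120.13, "latitude": 30.26}},
            ],
        },
    ],
}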
from sklearn.neighbors import NearestNeighbors

# Index the GPS coordinates of the test images so that each question's city
# can be mapped to its 40 geographically nearest candidate images.
nbrs = NearestNeighbors(n_neighbors=40, algorithm='ball_tree').fit(test_img_locations)
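Note that ball_tree with its default Euclidean metric treats longitude/latitude degrees as planar coordinates, which is only a rough proxy for geographic distance (a haversine metric on radians would be more faithful). As a quick shape check (the query point below is made up):

import numpy as np

# Hypothetical query: one (longitude, latitude) pair.
query = np.array([[120.15, 30.28]])
distances, indices = nbrs.kneighbors(query)
print(distances.shape, indices.shape)  # both (1, 40): 40 neighbours per query row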
Load the CLIP model
import os

# Route Hugging Face Hub traffic through the hf-mirror.com mirror;
# set this before importing transformers so downloads pick it up.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import torch
import jieba
import requests
from PIL import Image
from tqdm import tqdm
from transformers import ChineseCLIPProcessor, ChineseCLIPModel
model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
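A quick smoke test (not part of the pipeline; the caption is illustrative) confirms the checkpoint loads and that text features come out with the expected dimensionality:

# Hypothetical sanity check for the loaded model.
with torch.no_grad():
    text_inputs = processor(text=["一座古老的寺庙"], padding=True, return_tensors="pt")
    text_features = model.get_text_features(**text_inputs)
print(text_features.shape)  # torch.Size([1, 512]) for the ViT-B/16 checkpoint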
Image-text matching process
questions = open('./数据集更新/初赛测试集/问题.txt').readlines()
results = []
for question in tqdm(questions):
    # Tokenise the question; keep multi-character, non-numeric words.
    # The first surviving word is taken to be the city name.
    words = jieba.lcut(question)
    words = [x for x in words if len(x) > 1 and not x.isdigit()]
    city = words[0]

    # Look up the city's (longitude, latitude) and fetch its 40 nearest test images.
    city_pic_dis, city_pic_index = nbrs.kneighbors(
        [city_data[city_data[0] == city].values[0][1:]]
    )
    city_pic_dis = city_pic_dis[0]
    city_pic_index = city_pic_index[0]
    with torch.no_grad():
        # Encode the candidate images and L2-normalise the features.
        inputs = processor(images=[Image.open(x) for x in test_imgs[city_pic_index]], return_tensors="pt")
        image_features = model.get_image_features(**inputs)
        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
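    # --- Assumed continuation (a sketch, not the original code): embed the question
    # --- text the same way and rank the candidate images by cosine similarity.
    with torch.no_grad():
        text_inputs = processor(text=[question.strip()], padding=True, return_tensors="pt")
        text_features = model.get_text_features(**text_inputs)
        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
    similarity = (text_features @ image_features.T).squeeze(0)  # cosine similarity, since both sides are unit-norm
    best_img = test_imgs[city_pic_index[similarity.argmax().item()]]
    results.append(best_img)  # hypothetical results format: best-matching image path per question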