文搜图/图搜图
- 1.环境
- 2.建集合
- 3.图入集合
- 4.输入向量化
- 5.文搜图
- 6.图搜图
- 7.参考博文
- 8.仓库代码
1.环境
在 Linux 上安装 Docker、修改镜像源,并安装 docker-compose
# 1. Install docker
sudo apt update
sudo apt install -y docker.io
sudo systemctl start docker
sudo docker --version

# 2. Configure a docker registry mirror.
# Write the config with `sudo tee` so the file is created with root
# privileges (`sudo su vi <file>` would try to switch to a user named "vi").
sudo mkdir -p /etc/docker
sudo tee /etc/docker/daemon.json > /dev/null <<'EOF'
{
  "registry-mirrors": ["https://rsk59qvc.mirror.aliyuncs.com"]
}
EOF
# Restart docker so the new daemon.json is picked up.
sudo systemctl restart docker
# Verify: the "Registry Mirrors" section of the output should list the mirror above.
sudo docker info

# 3. Install docker-compose (standalone v1.29.2 binary)
sudo curl -L "https://github.com/docker/compose/releases/download/1.29.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
docker-compose --version
安装并启动 Milvus(2.5.0)容器和可视化工具 Attu 容器
# Download the Milvus v2.5.0 standalone compose file and start Milvus
# (gRPC endpoint on port 19530).
wget https://github.com/milvus-io/milvus/releases/download/v2.5.0/milvus-standalone-docker-compose.yml -O docker-compose.yml
sudo docker-compose up -d

# Start the Attu web UI: container port 3000 is published on host port 8000.
# On Linux, host.docker.internal is not defined inside containers by default;
# --add-host maps it to the host gateway (Docker 20.10+) so Attu can reach
# Milvus on the host's port 19530.
sudo docker run -d --name attu -p 8000:3000 \
  --add-host=host.docker.internal:host-gateway \
  -e MILVUS_URL=host.docker.internal:19530 \
  zilliz/attu:v2.5

# Windows: open PowerShell as Administrator and run
#   Invoke-WebRequest https://raw.githubusercontent.com/milvus-io/milvus/refs/heads/master/scripts/standalone_embed.bat -OutFile standalone.bat
#   .\standalone.bat start
# then check that the containers are up:
#   docker ps -a
在本地浏览器访问 http://localhost:8000/#/connect,即可通过 Attu 可视化管理 Milvus 库
python环境
pip install cn-clip ipython
pip install pymilvus==2.5.0
# Needed by the later sections:
# - fastapi + uvicorn for the HTTP service
# - python-multipart for the UploadFile/File form endpoint
# - requests for the test clients
# - tqdm and pillow for the ingestion script
pip install fastapi uvicorn python-multipart requests tqdm pillow
2.建集合
# 1. Create the Milvus collection used by the following sections.
import time

import torch
from pymilvus import DataType, MilvusClient


def create_schema():
    """Build the collection schema.

    Fields:
        id:       INT64 primary key, auto-generated by Milvus (auto_id=True).
        vectors:  512-d FLOAT_VECTOR; must match the embedding size of the
                  CLIP model used later (ViT-B-16).
        filepath: VARCHAR (max 200) path of the source image.
    Dynamic fields are enabled, so extra keys in inserted rows are kept.
    """
    schema = milvus_client.create_schema(
        auto_id=True,
        enable_dynamic_field=True,
        description="",
    )
    # The keyword is `description`; a misspelled keyword is not applied.
    schema.add_field(field_name="id", datatype=DataType.INT64,
                     description="ids", is_primary=True)
    schema.add_field(field_name="vectors", datatype=DataType.FLOAT_VECTOR,
                     description="embedding vectors", dim=512)
    schema.add_field(field_name="filepath", datatype=DataType.VARCHAR,
                     description="file path", max_length=200)
    return schema


def create_collection(collection_name, schema, timeout=3):
    """Create *collection_name* with *schema* and wait until Milvus reports it.

    Returns:
        True once ``has_collection`` sees the collection; False if the create
        call raised or the collection did not appear within *timeout* seconds.
    """
    try:
        milvus_client.create_collection(
            collection_name=collection_name,
            schema=schema,
            shards_num=2,
        )
        print(f"开始创建集合:{collection_name}")
    except Exception as e:
        print(f"创建集合的过程中出现了错误: {e}")
        return False

    # Poll once per second until the collection is visible or we time out.
    start_time = time.time()
    while True:
        if milvus_client.has_collection(collection_name):
            print(f"集合 {collection_name} 创建成功")
            return True
        if time.time() - start_time > timeout:
            print(f"创建集合 {collection_name} 超时")
            return False
        time.sleep(1)


class CollectionDeletionError(Exception):
    """删除集合失败"""


def check_and_drop_collection(collection_name):
    """Drop *collection_name* if it exists.

    Returns:
        True if the collection is absent afterwards (it never existed or was
        dropped); False if dropping raised.
    """
    if not milvus_client.has_collection(collection_name):
        return True
    print(f"集合 {collection_name} 已经存在")
    try:
        milvus_client.drop_collection(collection_name)
    except Exception as e:
        print(f"删除集合时出现错误: {e}")
        return False
    print(f"删除集合:{collection_name}")
    return True


collection_name = "w_cc"
uri = "http://localhost:19530"
milvus_client = MilvusClient(uri=uri)

# Start from a clean collection; abort if the old one cannot be removed.
if not check_and_drop_collection(collection_name):
    raise CollectionDeletionError("删除集合失败")

schema = create_schema()
# Fail loudly here instead of letting the insert step fail later with a
# less obvious "collection not found" error.
if not create_collection(collection_name, schema):
    raise RuntimeError(f"创建集合 {collection_name} 失败")
3.图入集合
# 2. Embed images with Chinese-CLIP, insert them into Milvus, build an
#    IVF_FLAT index (inverted-file index: fast retrieval with good accuracy)
#    using COSINE similarity, then load the collection for search.
import os
import time
from glob import glob

import cn_clip.clip as clip
import torch
from cn_clip.clip import available_models
from PIL import Image
from pymilvus import MilvusClient
from tqdm import tqdm

# List the model names Chinese-CLIP can load.
print("Available models:", available_models())

# Use the GPU when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "ViT-B-16"
# model provides encode_image / encode_text; preprocess turns a PIL image into
# the model's input tensor (resizing + normalisation). Weights are stored
# under ./chinese_clip_model.
model, preprocess = clip.load_from_name(
    model_name, device=device, download_root="./chinese_clip_model"
)
# Inference only: disable dropout and other training-time behaviour.
model.eval()

# Must be the collection created in the previous step and queried by the
# search code ("w_cc"); inserting into a collection that does not exist fails.
collection_name = "w_cc"
uri = "http://localhost:19530"
milvus_client = MilvusClient(uri=uri)


def encode_image(image_path):
    """Return the L2-normalised CLIP embedding of the image at *image_path* as a list."""
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        with Image.open(image_path) as img:
            raw_image = img.convert("RGB")  # handles grayscale / RGBA / palette inputs
        processed_image = preprocess(raw_image).unsqueeze(0).to(device)
        image_features = model.encode_image(processed_image)
        # Unit-normalise so COSINE and inner-product rankings agree.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.squeeze().tolist()


def encode_text(text_list):
    """Return one L2-normalised CLIP embedding (as a list) per string in *text_list*."""
    with torch.no_grad():
        text_tokens = clip.tokenize(text_list).to(device)
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return [f.squeeze().tolist() for f in text_features]


def _find_images(input_dir_path, ext_list):
    """Recursively list files under *input_dir_path* whose extension is in *ext_list*.

    Matching is case-insensitive (".jpg" also finds ".JPG"), and each file is
    returned once even if *ext_list* contains case variants of the same suffix.
    """
    wanted = {ext.lower() for ext in ext_list}
    pattern = os.path.join(input_dir_path, "**", "*")
    print(f"搜索模式: {pattern}")
    return sorted(
        path
        for path in glob(pattern, recursive=True)
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in wanted
    )


def process_images_and_insert(input_dir_path, ext_list, batch_size=100):
    """Embed every matching image under *input_dir_path* and insert it into Milvus.

    Rows are inserted in batches of *batch_size* as
    ``{"vectors": <embedding>, "filepath": <path>}``; the primary key is
    auto-generated. Images that fail to encode, and batches that fail to
    insert, are reported and skipped.

    Raises:
        ValueError: if *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    image_paths = _find_images(input_dir_path, ext_list)
    total_images = len(image_paths)
    print(f"总计需要处理 {total_images} 张图片")

    total_start_time = time.time()
    inserted = 0
    with tqdm(total=total_images, desc="处理图片并插入数据") as progress_bar:
        for batch_start in range(0, total_images, batch_size):
            batch_paths = image_paths[batch_start: batch_start + batch_size]
            batch_start_time = time.time()

            batch_data = []
            for image_path in batch_paths:
                try:
                    batch_data.append({
                        "vectors": encode_image(image_path),
                        "filepath": image_path,
                    })
                except Exception as e:
                    print(f"处理图片 {image_path} 时出错: {str(e)}")

            if batch_data:
                try:
                    milvus_client.insert(collection_name=collection_name, data=batch_data)
                    inserted += len(batch_data)
                except Exception as e:
                    print(f"插入批次 {batch_start} 时失败: {str(e)}")

            # Advance by the whole batch so the bar reaches 100% even when
            # some images or inserts failed (failures are reported above).
            progress_bar.update(len(batch_paths))
            progress_bar.set_postfix({"批次耗时": time.time() - batch_start_time})

    total_duration = time.time() - total_start_time
    print(f"\n所有图片处理完成!总耗时: {total_duration}")
    print(f"成功插入 {inserted}/{total_images} 张图片")
    if total_duration > 0:  # guard against a zero-length timer interval
        print(f"平均处理速度: {total_images/total_duration:.1f}张/秒")


input_dir_path = "lhq_1024_jpg_5000"
batch_size = 300
ext_list = [".JPEG", ".jpg", ".png"]  # matched case-insensitively
process_images_and_insert(input_dir_path, ext_list, batch_size)


def create_index(collection_name):
    """Create an IVF_FLAT index (COSINE metric, nlist=512) on the "vectors" field."""
    index_params = milvus_client.prepare_index_params()
    index_params.add_index(
        index_name="IVF_FLAT",
        field_name="vectors",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 512},
    )
    milvus_client.create_index(collection_name=collection_name, index_params=index_params)


create_index(collection_name)

# Load the collection into memory so it can be searched.
print(f"正在加载集合 {collection_name}")
milvus_client.load_collection(collection_name=collection_name)
print(f"集合 {collection_name} 加载完成")

# Verify the load state (compared by its string name).
state = str(milvus_client.get_load_state(collection_name=collection_name)["state"])
if state == "Loaded":
    print("集合加载完成")
else:
    print("集合加载失败")

# Row count of the collection.
print(milvus_client.query(collection_name=collection_name, output_fields=["count(*)"]))
4.输入向量化
# 4. Load the CLIP model and define the helpers used by the search sections.
import os
import time
from glob import glob

import cn_clip.clip as clip
import torch
from cn_clip.clip import available_models
from PIL import Image
from pymilvus import MilvusClient
from tqdm import tqdm

collection_name = "w_cc"
uri = "http://localhost:19530"
milvus_client = MilvusClient(uri=uri)

print("Available models:", available_models())
# Use the GPU when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "ViT-B-16"
model, preprocess = clip.load_from_name(
    model_name, device=device, download_root="./chinese_clip_model"
)
# Inference only: disable dropout and other training-time behaviour.
model.eval()


def encode_image(image_path):
    """Return the L2-normalised CLIP embedding of the image at *image_path* as a list."""
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        with Image.open(image_path) as img:
            raw_image = img.convert("RGB")  # handles grayscale / RGBA / palette inputs
        processed_image = preprocess(raw_image).unsqueeze(0).to(device)
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.squeeze().tolist()


def encode_text(text_list):
    """Return one L2-normalised CLIP embedding (as a list) per string in *text_list*."""
    with torch.no_grad():
        text_tokens = clip.tokenize(text_list).to(device)
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return [f.squeeze().tolist() for f in text_features]


def vector_search(vector, field_name, limit, output_fields):
    """Run an ANN search against ``collection_name``.

    Args:
        vector: list of query embeddings; one result list is returned per entry.
        field_name: vector field to search (``"vectors"`` in this collection).
        limit: maximum number of hits per query.
        output_fields: scalar fields to return in each hit's ``"entity"``.

    Returns:
        The ``MilvusClient.search`` result.
    """
    return milvus_client.search(
        collection_name=collection_name,
        data=vector,
        anns_field=field_name,
        limit=limit,
        output_fields=output_fields,
    )


def create_concatenated_image(res, images_per_row=2, images_per_column=2, image_size=(400, 400)):
    """Paste the images referenced by search hits into a single grid image.

    Args:
        res: result of ``vector_search`` — an iterable of hit lists whose hits
            carry ``hit["entity"]["filepath"]``.
        images_per_row: number of grid columns.
        images_per_column: number of grid rows. Hits beyond
            ``images_per_row * images_per_column`` are not drawn (or loaded).
        image_size: (width, height) each image is resized to; the aspect
            ratio is NOT preserved.

    Returns:
        A new RGB ``PIL.Image`` with a white background.
    """
    cell_w, cell_h = image_size
    capacity = images_per_row * images_per_column
    canvas = Image.new(
        "RGB", (cell_w * images_per_row, cell_h * images_per_column), color="white"
    )

    tiles = []
    for result in res:
        for hit in result:
            if len(tiles) >= capacity:  # grid is full; skip loading the rest
                break
            filename = hit["entity"]["filepath"]
            try:
                with Image.open(filename) as img:
                    tiles.append(img.convert("RGB").resize(image_size))
            except Exception as e:
                print(f"无法加载图片 {filename}: {e}")

    # Fill the grid left-to-right, top-to-bottom.
    for idx, tile in enumerate(tiles):
        col, row = idx % images_per_row, idx // images_per_row
        canvas.paste(tile, (col * cell_w, row * cell_h))
    return canvas
5.文搜图
# 5. Text-to-image search: embed the query text, search Milvus, and save the
#    top hits as a grid image. Uses encode_text / vector_search /
#    create_concatenated_image from the previous section.
import os

query_text = ["小桥流水人家"]
query_embedding = encode_text(query_text)[0]
field_name = "vectors"
limit = 10
output_fields = ["filepath"]
res = vector_search([query_embedding], field_name, limit, output_fields)

print(f"查询文本: {query_text}")
print("检索结果:")
for hits in res:
    for hit in hits:
        print(hit["entity"]["filepath"], hit["distance"])

# The 2x2 grid shows the first 4 of the hits printed above.
result_image = create_concatenated_image(res, 2, 2, (400, 400))

output_path = "./output/concatenated_image.png"
# Create ./output if needed; save() does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
result_image.save(output_path)
print(f"拼接图像已保存到: {output_path}")
封装成 FastAPI 接口
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class QueryRequest(BaseModel):
    """Request body for text-to-image search."""

    query_text: str


@app.post("/text-search-images/")
def search_images(query_request: QueryRequest):
    """Return the file paths of the 10 images most similar to ``query_text``.

    Declared with plain ``def`` (not ``async def``) so FastAPI runs it in its
    worker thread pool: CLIP encoding and the Milvus search are blocking calls
    that would otherwise stall the event loop for every other request.
    """
    query_embedding = encode_text([query_request.query_text])[0]
    res = vector_search([query_embedding], "vectors", 10, ["filepath"])
    image_paths = [hit["entity"]["filepath"] for hits in res for hit in hits]
    return {"images": image_paths}
POST 接口测试
import requests


def test_text_search_images(query_text, timeout=30):
    """POST *query_text* to the text-search endpoint and print the returned image paths.

    Args:
        query_text: text to search images for.
        timeout: seconds to wait for the server before ``requests`` raises,
            so the client does not hang forever if the server is unresponsive.
    """
    url = "http://127.0.0.1:8001/text-search-images/"
    response = requests.post(url, json={"query_text": query_text}, timeout=timeout)
    if response.status_code == 200:
        print("查询相似图片成功:")
        print(response.json())  # {"images": [...file paths...]}
    else:
        print("查询相似图片失败:", response.status_code, response.text)


if __name__ == "__main__":
    query_text = "小桥流水人家"  # example query
    test_text_search_images(query_text)
效果
6.图搜图
# 6. Image-to-image search: embed a query image, search Milvus, and save both
#    the query image and a 3x3 grid of the hits. Uses encode_image /
#    vector_search / create_concatenated_image from section 4.
import os

from PIL import Image

query_image = "目标/上海/屏幕截图.png"
query_embedding = encode_image(query_image)
field_name = "vectors"
limit = 5
output_fields = ["filepath"]
res = vector_search([query_embedding], field_name, limit, output_fields)

image_paths = [hit["entity"]["filepath"] for hits in res for hit in hits]
print(image_paths)

# save() does not create missing directories.
os.makedirs("./output", exist_ok=True)

# Save a copy of the query image so it can be compared with the results.
query_image_save_path = "./output/query_image.png"
with Image.open(query_image) as img:
    img.convert("RGB").save(query_image_save_path)
print(f"查询图片已保存到: {query_image_save_path}")

print("图片检索结果:")
concatenated_image = create_concatenated_image(
    res, images_per_row=3, images_per_column=3, image_size=(300, 300)
)
concatenated_image_save_path = "./output/retrieved_images.png"
concatenated_image.save(concatenated_image_save_path)
print(f"检索结果图像已保存到: {concatenated_image_save_path}")
封装成 FastAPI 接口
import os

from fastapi import File, HTTPException, UploadFile
from PIL import Image
from pydantic import BaseModel


class ImageQueryRequest(BaseModel):
    """Request body for image-to-image search: a path readable by the server."""

    image_path: str


@app.post("/search-similar-images/")
def search_similar_images(request: ImageQueryRequest):
    """Return the file paths of the 10 images most similar to ``image_path``.

    ``image_path`` is opened on the server, so relative paths resolve against
    the server's working directory. Plain ``def`` lets FastAPI run the blocking
    CLIP/Milvus work in its thread pool instead of on the event loop.

    NOTE(security): any readable server path is accepted. Restrict it to an
    allowed directory (or accept an upload instead) before exposing this
    endpoint beyond localhost.

    Raises:
        HTTPException(404): if ``image_path`` is not an existing file.
    """
    image_path = request.image_path
    if not os.path.isfile(image_path):
        raise HTTPException(status_code=404, detail=f"image not found: {image_path}")
    query_embedding = encode_image(image_path)
    res = vector_search([query_embedding], "vectors", 10, ["filepath"])
    image_paths = [hit["entity"]["filepath"] for hits in res for hit in hits]
    return {"similar_images": image_paths}


@app.post("/show-image/")
async def show_image(image: UploadFile = File(...)):
    """Save an uploaded image under ./temp_images and open it with PIL's viewer.

    The viewer opens on the machine running the server.
    """
    # Keep only the final path component so a crafted filename such as
    # "../../x" cannot write outside ./temp_images.
    safe_name = os.path.basename(image.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"
    temp_dir = "./temp_images"
    os.makedirs(temp_dir, exist_ok=True)
    temp_image_path = f"{temp_dir}/{safe_name}"
    with open(temp_image_path, "wb") as f:
        f.write(await image.read())

    with Image.open(temp_image_path) as img:
        img.show()
    return {"message": f"图片已显示,路径: {temp_image_path}"}
POST 接口测试
import requests


def test_search_similar_images(image_path, timeout=30):
    """Ask the server for images similar to *image_path* and print the result.

    The path is sent as a string and opened by the server, so it must be valid
    on the server's filesystem (relative to the server's working directory).
    *timeout* bounds the wait for the server's response.
    """
    url = "http://127.0.0.1:8001/search-similar-images/"
    response = requests.post(url, json={"image_path": image_path}, timeout=timeout)
    if response.status_code == 200:
        print("查询相似图片成功:")
        print(response.json())
    else:
        print("查询相似图片失败:", response.status_code, response.text)


def test_show_image(image_path, timeout=30):
    """Upload the local file *image_path* to /show-image/ as multipart form field "image"."""
    url = "http://127.0.0.1:8001/show-image/"
    # The file must stay open while requests streams it in the POST body.
    with open(image_path, "rb") as img_file:
        response = requests.post(url, files={"image": img_file}, timeout=timeout)
    if response.status_code == 200:
        print("图片显示成功:")
        print(response.json())
    else:
        print("图片显示失败:", response.status_code, response.text)


if __name__ == "__main__":
    image_path = "query_image.jpg"
    test_search_similar_images(image_path)
    test_show_image(image_path)
7.参考博文
[1]https://mp.weixin.qq.com/s/wW_3X7CquqeuEdu4-zn3qg
8.仓库代码
https://github.com/Turing-dz/text_img_search_img