万圣节终于到了! 🎃
那是南瓜雕刻、性感的服装和围着闪烁烛光低声细语讲述恐怖故事的时节。
但如果你和我一样,总是在最需要的时候记不起任何恐怖故事。所以我想到,为什么不做一个工具,可以从大量故事中挑选出真正能让我们感到毛骨悚然的那种故事呢。
这就是我们今天在建的东西。
这个计划很简单。
我们将使用Reddit上的恐怖故事数据集,对这些数据进行嵌入处理,并建立一个Qdrant集合,以便根据主题和氛围等进行搜索。简单来说,就是捕捉到这种“氛围”,比如“闹鬼的房子”或“诡异的森林”。
我会展示你需要的所有步骤来构建一个这样的应用:设置向量数据库的步骤,嵌入和索引数据,以及召唤最恐怖的万圣节故事。
咱们就开始吧。
1. 安装库文件首先,我们先来安装我们将要用到的工具:
在命令行中输入以下命令:
pip install qdrant-client sentence_transformers datasets
切换到全屏模式,退出全屏
2.: 下载数据集文件我们将使用这个Reddit上的恐怖故事数据集。我们用datasets库来下载一下这个数据集吧。
从datasets库导入load_dataset函数
ds = load_dataset("intone/horror_stories_reddit")
# ds变量被赋值为load_dataset函数加载的数据集,其来源为'intone/horror_stories_reddit'。
进入全屏模式,退出全屏模式
3. 加载嵌入式模型.我们将使用sentence_transformers
库来帮助我们将数据嵌入到all-MiniLM-L6-v2
模型中。这里是如何进行设置的:
从'sentence_transformers'导入SentenceTransformer
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device='cpu')
进入全屏 退出全屏
如果你手头上有一个可用的GPU并且想要加快处理速度,只需改成 device='cuda:0'
。
generate_embeddings_direct
函数处理数据集的某部分(例如“train”)时,将其拆分成更小的组,称为批次(batch),根据指定的batch_size
。这有助于更高效地管理内存。
对于每批,该函数提取一组句子(例如,每次32个),并使用加载的嵌入模型来处理这些句子。
导入 tqdm 模块
def generate_embeddings(split, batch_size=32):
embeddings = []
split_name = [name for name, data_split in ds.items() if data_split is split][0]
with tqdm(total=len(split), desc=f"正在为 {split_name} 分割生成嵌入") as pbar:
for i in range(0, len(split), batch_size):
batch_sentences = split['text'][i:i+batch_size]
batch_embeddings = model.encode(batch_sentences)
embeddings.extend(batch_embeddings)
pbar.update(len(batch_sentences))
return embeddings
全屏 退出全屏
它会立即在数据集中添加一个新的列,把它们放进去。这样,这个功能能高效更新数据,同时节省内存。
# 生成训练数据集的嵌入表示
train_embeddings = generate_embeddings(ds['train'])
# 将嵌入表示添加到训练数据集中
ds["train"] = ds["train"].add_column("embeddings", train_embeddings)
点击全屏 点击退出全屏
5. 设置客户端现在我们可以开始使用Qdrant客户端了。如果你是在本地操作,只需连接到默认端点即可。
from qdrant_client import QdrantClient
# 连接到本地的Qdrant实例
qdrant_client = QdrantClient(url="http://localhost:6333")
可以进入全屏模式,也可以退出全屏
很简单吧?但在实际工作中,你很可能是在云端工作。这意味着你需要设置并验证Qdrant Cloud实例。
云配置
要连接到您的云实例,您需要实例网址和一个API密钥(API key)。具体操作步骤如下。
导入 QdrantClient 从 qdrant_client
# 每处需用您的 Qdrant 云实例 URL 和 API Key 替换
qdrant_client = QdrantClient(
url="https://YOUR_CLOUD_INSTANCE_ID.aws.qdrant.tech", # 用您的云实例 URL 替换这里的 URL
api_key="YOUR_API_KEY" # 用您的 API Key 替换这里的 api_key
)
全屏模式 退出全屏
确保将 YOUR_CLOUD_INSTANCE_ID
替换成你的实际实例 ID,并将 YOUR_API_KEY
替换成你创建的那个 API 密钥。你可以在 Qdrant Cloud 控制台里找到这些信息。
在 Qdrant 中,一个集合就像是一个迷你数据库,优化存储和查询向量。当我们定义一个集合时,我们需要设定向量的尺寸和用于衡量相似性的度量。其设置可能如下所示:
导入qdrant_client中的models
collection_name="halloween"
# 下面的代码创建了一个集合来存储产品特性的向量
qdrant_client.create_collection(
collection_name=collection_name,
vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE) # 向量参数配置
)
进入全屏,退出全屏
我们定义了一个名为 halloween
的集合,包含384维向量,这是 all-MiniLM-L6-v2
嵌入的大小。这里我们使用余弦距离作为相似度的衡量标准。根据您的数据和应用场景,您可能希望使用不同的距离度量方式,例如 Distance.EUCLID
或 Distance.DOT
。
:)
没有数据,集合就什么都不是。现在是时候把我们之前创建的嵌入插入其中了。这里有一个策略,可以分批加载嵌入:
def batched(iterable, n):
iterator = iter(iterable)
while 批次 := list(islice(iterator, n)):
yield 批次
批大小 = 100
当前ID = 0 # 初始化计数器
切换到全屏模式 退出全屏
batched
函数将一个可迭代对象分割成大小为 n
的多个小块。它使用 islice
提取连续的元素组,并为每一块生成结果,直到数据集被完全处理完。
from itertools 导入 islice
for batch in batched(ds["train"], batch_size):
# 生成一个ID列表,使用计数器
ids = list(range(current_id, current_id + len(batch)))
# 更新计数器,使其从批次后的下一个ID开始
current_id += len(batch)
vectors = [point.pop("embeddings") for point in batch]
qdrant_client.upsert(
collection_name=collection_name,
points=models.Batch(
ids=ids,
vectors=vectors,
payloads=batch,
),
)
切换到全屏模式,切换回正常模式
每个批次都会通过upsert
方法发送至Qdrant,该方法用来插入数据批次。upsert
方法接收一组ID、向量以及其余项目数据(负载数据),以存储或更新在Qdrant集合里。
终于,到了这个时候。
一切准备好了,现在来看看我们的恐怖故事搜索工具能否真的吓到人。我们试着搜一下“恐怖小丑”,看看结果。
import json
import textwrap
# 用于包装并打印长文本的函数
def print_wrapped(text, width=80):
wrapped_text = textwrap.fill(text, width=width)
print(wrapped_text)
# 搜索结果查询
search_result = qdrant_client.query_points(
collection_name=collection_name,
query=model.encode("creepy clown").tolist(),
limit=1,
)
# 获取第一个结果
if search_result.points:
tale = search_result.points[0]
# 打印负载信息
print("ID:", tale.id)
print("Score:", tale.score)
print("Original:", tale.payload.get('isOriginal', 'N/A'))
# 打印特定的负载字段
print("Title:", tale.payload.get('title', 'N/A'))
print("Author:", tale.payload.get('author', 'N/A'))
print("Subreddit:", tale.payload.get('subreddit', 'N/A'))
print("URL:", tale.payload.get('url', 'N/A'))
# 单独打印故事文本,带有单词包装以提高可读性
print("\nStory Text:\n")
print_wrapped(tale.payload.get('text', 'No text available'), width=80)
else:
print("未找到结果.")
全屏 退出全屏
结果就这么出现了,标题叫做:“偷瞄小贼”。
说实话,这真的挺诡异的啊。
它到底是基于真实故事还是虚构的?说实话,我真的不清楚。它会让人感到一种若有若无的不安,仿佛有什么东西在注视着你。它挺长的,所以我就不贴在这里了,但如果你想亲自看看,尽管去运行程序试试吧。
你可以探索任何其他气氛:“闹鬼的屋子”,“阴森的森林”,“被诅咒的玩偶”,或者你想体验的任何其他气氛。谁知道呢,你可能会发现更恐怖的东西。
如果你发现了,不妨在评论里发出来。我很想看看它还能发现什么新东西。
下一步感谢你陪我一起完成这次万圣节的实验!如果你一直跟着的话,你已经迈出了进入向量搜索世界的第一步,并学会了如何找到那些感觉恐怖但又不仅仅是包含恐怖词汇的。
如果你准备好了,就可以进入向量搜索的神秘领域,你可以探索许多更高级的话题,比如multitenancy、payload结构体和批量上传。
所以,放手去做,看看你能深入到什么地步。
祝你好运! 👻
共同学习,写下你的评论
评论加载中...
作者其他优质文章