为了账号安全,请及时绑定邮箱和手机立即绑定

Python Skip-Gram代码实战:从理论到实践,构建Skip-Gram模型与Word2vec代码详解

标签:
杂七杂八

定义与目的

Skip-gram模型作为Word2vec的一种训练方法,致力于通过目标词预测上下文词,旨在通过神经网络结构学习单词的向量表示,从而增强计算机对语言的理解与处理能力。

简单步骤概览

  1. 构建句子列表与词汇表
  2. 生成Skip-Gram训练数据
  3. 定义One-Hot编码函数
  4. 实现Skip-Gram类
  5. 训练模型
  6. 输出词嵌入
  7. 向量可视化

Python代码实战概览

通过分步骤的代码实现,我们将演示Skip-Gram模型从构建到应用的全过程,包括词汇表构建、生成训练数据、One-Hot编码、模型定义、训练过程、词嵌入输出及向量可视化。

步骤详解与代码

1. 句子列表与词汇表构建

sentences = ["The cat sat on the mat", "The dog chased the cat", "Under the mat"]
words = ' '.join(sentences).split()
word_list = list(set(words))
word_to_idx = {word: idx for idx, word in enumerate(word_list)}
idx_to_word = {idx: word for idx, word in enumerate(word_list)}
voc_size = len(word_list)
print("词汇表:", word_list)
print("词汇到索引的字典:", word_to_idx)
print("索引到词汇的字典:", idx_to_word)
print("词汇表大小:", voc_size)

2. 生成Skip-Gram训练数据

def create_skipgram_dataset(sentences, window_size=2):
    data = []
    for sentence in sentences:
        sentence_words = sentence.split()
        for idx, word in enumerate(sentence_words):
            for neighbor in sentence_words[max(idx - window_size, 0): min(idx + window_size + 1, len(sentence_words))]:
                if neighbor != word:
                    data.append((neighbor, word))
    return data

skipgram_data = create_skipgram_dataset(sentences)
print("Skip-Gram数据样例:", skipgram_data[:5])

3. 定义One-Hot编码函数

import torch

def one_hot_encoding(word, word_to_idx):
    tensor = torch.zeros(len(word_to_idx))
    tensor[word_to_idx[word]] = 1
    return tensor

word_example = "cat"
print("编码前的单词:", word_example)
print("编码后的向量:", one_hot_encoding(word_example, word_to_idx))
encoded_data = [(one_hot_encoding(context, word_to_idx), word_to_idx[target]) for context, target in skipgram_data[:5]]
print("编码后的Skip-Gram数据样例:", encoded_data[:3])

4. 实现Skip-Gram类

import torch.nn as nn

class SkipGram(nn.Module):
    def __init__(self, voc_size, embedding_size):
        super(SkipGram, self).__init__()
        self.input_to_hidden = nn.Linear(voc_size, embedding_size)
        self.hidden_to_output = nn.Linear(embedding_size, voc_size)

    def forward(self, X):
        hidden = self.input_to_hidden(X)
        output = self.hidden_to_output(hidden)
        return output

embedding_size = 2
skipgram_model = SkipGram(voc_size, embedding_size)
print("Skip-Gram模型:", skipgram_model)

5. 训练Skip-Gram类

learning_rate = 0.01
epochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(skipgram_model.parameters(), lr=learning_rate)

loss_values = []
for epoch in range(epochs):
    loss_sum = 0
    for context, target in encoded_data:
        X = context.float().unsqueeze(0)
        y_true = torch.tensor([target], dtype=torch.long)
        y_pred = skipgram_model(X)
        loss = criterion(y_pred, y_true)
        loss_sum += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if (epoch+1) % 10 == 0:
        print(f"Epoch: {epoch+1}, Loss: {loss_sum/len(encoded_data)}")
        loss_values.append(loss_sum / len(encoded_data))

import matplotlib.pyplot as plt

plt.plot([i+1 for i in range(epochs)], loss_values)
plt.title('训练损失曲线')
plt.xlabel('轮次')
plt.ylabel('损失')
plt.show()

6. 输出词嵌入

for word, idx in word_to_idx.items():
    print(f"{word}: {skipgram_model.input_to_hidden.weight[:,idx].detach().numpy()}")

7. 向量可视化

import matplotlib.pyplot as plt

plt.rcParams["font.family"] = 'SimHei'

# 假设进行了降维处理(例如PCA或t-SNE),以适应二维可视化
# 这里仅展示降维后的数据示例
# 示例数据:降维后的词向量表示
reduced_embedding = [...]

# 绘制词嵌入在二维空间内的散点图
plt.scatter(reduced_embedding[:, 0], reduced_embedding[:, 1])
for word, idx in word_to_idx.items():
    plt.annotate(word, (reduced_embedding[idx, 0], reduced_embedding[idx, 1]), fontsize=12)
plt.title('二维词嵌入')
plt.xlabel('向量维度1')
plt.ylabel('向量维度2')
plt.show()

总结

Skip-gram模型通过训练大量文本数据,成功学习到单词之间的语义关系,生成词向量,为自然语言处理任务提供了关键的技术支撑。通过以上代码实践,我们深入理解了Skip-gram模型的构建与应用,为进一步探索自然语言处理领域打下了坚实的基础。

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消