在这篇文章中,我们演示了如何有效地微调最先进的蛋白质语言模型(pLM)来预测蛋白质亚细胞定位 亚马逊SageMaker.
蛋白质是身体的分子机器,负责从移动肌肉到应对感染的一切。尽管有多种蛋白质,但所有蛋白质都是由重复的氨基酸分子链组成。人类基因组编码 20 种标准氨基酸,每种氨基酸的化学结构略有不同。这些可以用字母表中的字母表示,然后我们可以将蛋白质作为文本字符串进行分析和探索。蛋白质序列和结构的巨大可能数量赋予了蛋白质广泛的用途。
蛋白质在药物开发中也发挥着关键作用,既可以作为潜在靶点,也可以作为治疗药物。如下表所示,2022 年许多最畅销的药物要么是蛋白质(尤其是抗体),要么是其他分子,如 mRNA 在体内翻译成蛋白质。正因为如此,许多生命科学研究人员需要更快、更便宜、更准确地回答有关蛋白质的问题。
名字 | 生产厂家 | 2022 年全球销售额(十亿美元) | 适应症 |
合宜性 | 辉瑞/ BioNTech | $40.8 | Covid-19 |
斯派克瓦克斯 | 现代 | $21.8 | Covid-19 |
阿达木单抗 | 艾伯维 | $21.6 | 关节炎、克罗恩病等 |
凯特鲁达 | 默克公司 | $21.0 | 各种癌症 |
数据来源:厄克特,L. 2022 年销售额最高的公司和药品。 《自然评论药物发现》22, 260–260 (2023)。
因为我们可以将蛋白质表示为字符序列,所以我们可以使用最初为书面语言开发的技术来分析它们。这包括在大型数据集上预训练的大型语言模型 (LLM),然后可以针对特定任务进行调整,例如文本摘要或聊天机器人。同样,pLM 是使用未标记的自我监督学习在大型蛋白质序列数据库上进行预训练的。我们可以调整它们来预测蛋白质的 3D 结构或它如何与其他分子相互作用等。研究人员甚至使用 pLM 从头开始设计新型蛋白质。这些工具不会取代人类的科学专业知识,但它们有可能加快临床前开发和试验设计。
这些模型面临的一项挑战是它们的尺寸。 LLM 和 pLM 在过去几年中都实现了数量级的增长,如下图所示。这意味着可能需要很长时间才能将它们训练到足够的精度。这也意味着您需要使用具有大量内存的硬件,尤其是 GPU 来存储模型参数。
较长的训练时间加上大量的实例,意味着高昂的成本,这使得这项工作对于许多研究人员来说是遥不可及的。例如,在 2023 年, 研究团队 描述了在 100 个 A768 GPU 上训练 100 亿参数的 pLM 164 天!幸运的是,在许多情况下,我们可以通过使现有的 pLM 适应我们的特定任务来节省时间和资源。这种技术称为 微调,并且还允许我们借用其他类型语言建模的高级工具。
解决方案概述
我们在这篇文章中解决的具体问题是 亚细胞定位:给定一个蛋白质序列,我们可以建立一个模型来预测它是生活在细胞外部(细胞膜)还是内部?这是一条重要的信息,可以帮助我们了解其功能以及它是否会成为良好的药物靶点。
我们首先使用下载公共数据集 亚马逊SageMaker Studio。然后我们使用 SageMaker 通过高效的训练方法对 ESM-2 蛋白质语言模型进行微调。最后,我们将该模型部署为实时推理端点,并用它来测试一些已知的蛋白质。下图说明了此工作流程。
在以下部分中,我们将逐步介绍准备训练数据、创建训练脚本和运行 SageMaker 训练作业的步骤。这篇文章中的所有代码都可以在 GitHub上.
准备训练数据
我们使用部分 DeepLoc-2 数据集,其中包含数千个 SwissProt 蛋白质,其位置已通过实验确定。我们筛选 100-512 个氨基酸之间的高质量序列:
df = pd.read_csv(
"https://services.healthtech.dtu.dk/services/DeepLoc-2.0/data/Swissprot_Train_Validation_dataset.csv"
).drop(["Unnamed: 0", "Partition"], axis=1)
df["Membrane"] = df["Membrane"].astype("int32")
# filter for sequences between 100 and 512 amino acides
df = df[df["Sequence"].apply(lambda x: len(x)).between(100, 512)]
# Remove unnecessary features
df = df[["Sequence", "Kingdom", "Membrane"]]
接下来,我们对序列进行标记并将它们分成训练集和评估集:
dataset = Dataset.from_pandas(df).train_test_split(test_size=0.2, shuffle=True)
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
def preprocess_data(examples, max_length=512):
text = examples["Sequence"]
encoding = tokenizer(text, truncation=True, max_length=max_length)
encoding["labels"] = examples["Membrane"]
return encoding
encoded_dataset = dataset.map(
preprocess_data,
batched=True,
num_proc=os.cpu_count(),
remove_columns=dataset["train"].column_names,
)
encoded_dataset.set_format("torch")
最后,我们将处理后的训练和评估数据上传到 亚马逊简单存储服务 (Amazon S3):
train_s3_uri = S3_PATH + "/data/train"
test_s3_uri = S3_PATH + "/data/test"
encoded_dataset["train"].save_to_disk(train_s3_uri)
encoded_dataset["test"].save_to_disk(test_s3_uri)
创建训练脚本
SageMaker 脚本模式 允许您在 AWS 管理的优化机器学习 (ML) 框架容器中运行自定义训练代码。对于这个例子,我们采用 现有的文本分类脚本 来自《拥抱的脸》。这使我们能够尝试多种方法来提高训练工作的效率。
方法一:加权训练班
与许多生物数据集一样,DeepLoc 数据分布不均匀,这意味着膜蛋白和非膜蛋白的数量不相等。我们可以重新采样数据并丢弃大多数类别的记录。然而,这会减少总训练数据并可能损害我们的准确性。相反,我们在训练过程中计算类别权重并使用它们来调整损失。
在我们的训练脚本中,我们将 Trainer
班级来自 transformers
用 WeightedTrainer
在计算交叉熵损失时考虑类权重的类。这有助于防止我们的模型出现偏差:
class WeightedTrainer(Trainer):
def __init__(self, class_weights, *args, **kwargs):
self.class_weights = class_weights
super().__init__(*args, **kwargs)
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.get("logits")
loss_fct = torch.nn.CrossEntropyLoss(
weight=torch.tensor(self.class_weights, device=model.device)
)
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
方法二:梯度累积
梯度累积是一种训练技术,允许模型模拟较大批量的训练。通常,批量大小(在一个训练步骤中用于计算梯度的样本数量)受到 GPU 内存容量的限制。通过梯度累积,模型首先计算较小批次的梯度。然后,梯度不是立即更新模型权重,而是在多个小批量中累积。当累积梯度等于目标较大批量大小时,执行优化步骤以更新模型。这使得模型可以有效地进行更大批量的训练,而不会超出 GPU 内存限制。
然而,较小批量的前向和后向传递需要额外的计算。通过梯度累积增加批量大小可能会减慢训练速度,特别是在使用太多累积步骤的情况下。目的是最大限度地提高 GPU 使用率,同时避免过多额外的梯度计算步骤导致速度过慢。
方法 3:梯度检查点
梯度检查点是一种减少训练期间所需内存的技术,同时保持合理的计算时间。大型神经网络占用大量内存,因为它们必须存储前向传递的所有中间值,以便计算后向传递期间的梯度。这可能会导致内存问题。一种解决方案是不存储这些中间值,但是在向后传递期间必须重新计算它们,这需要大量时间。
梯度检查点提供了一种平衡的方法。它只保存一些中间值,称为 检查站,并根据需要重新计算其他。因此,它比存储所有内容使用的内存更少,而且比重新计算所有内容使用的计算量也更少。通过策略性地选择检查点的激活,梯度检查点使大型神经网络能够以可管理的内存使用和计算时间进行训练。这项重要的技术使得训练非常大的模型成为可能,否则这些模型会遇到内存限制。
在我们的训练脚本中,我们通过向 TrainingArguments
宾语:
from transformers import TrainingArguments
training_args = TrainingArguments(
gradient_accumulation_steps=4,
gradient_checkpointing=True
)
方法 4:LLM 的低阶适应
像 ESM-2 这样的大型语言模型可以包含数十亿个参数,这些参数的训练和运行成本很高。 研究人员 开发了一种称为低秩适应(LoRA)的训练方法,可以更有效地微调这些庞大的模型。
LoRA 背后的关键思想是,在针对特定任务微调模型时,不需要更新所有原始参数。相反,LoRA 向模型中添加了新的较小矩阵来转换输入和输出。微调期间仅更新这些较小的矩阵,这样速度更快并且使用更少的内存。原始模型参数保持冻结。
使用 LoRA 进行微调后,您可以将小型适应矩阵合并回原始模型。或者,如果您想快速调整其他任务的模型而不忘记以前的任务,则可以将它们分开。总体而言,LoRA 允许法学硕士以平常成本的一小部分有效地适应新任务。
在我们的训练脚本中,我们使用以下命令配置 LoRA PEFT
来自拥抱脸的库:
from peft import get_peft_model, LoraConfig, TaskType
import torch
from transformers import EsmForSequenceClassification
model = EsmForSequenceClassification.from_pretrained(
“facebook/esm2_t33_650M_UR50D”,
Torch_dtype=torch.bfloat16,
Num_labels=2,
)
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
bias="none",
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules=[
"query",
"key",
"value",
"EsmSelfOutput.dense",
"EsmIntermediate.dense",
"EsmOutput.dense",
"EsmContactPredictionHead.regression",
"EsmClassificationHead.dense",
"EsmClassificationHead.out_proj",
]
)
model = get_peft_model(model, peft_config)
提交 SageMaker 培训作业
定义训练脚本后,您可以配置并提交 SageMaker 训练作业。首先,指定超参数:
hyperparameters = {
"model_id": "facebook/esm2_t33_650M_UR50D",
"epochs": 1,
"per_device_train_batch_size": 8,
"gradient_accumulation_steps": 4,
"use_gradient_checkpointing": True,
"lora": True,
}
接下来,定义要从训练日志中捕获哪些指标:
metric_definitions = [
{"Name": "epoch", "Regex": "'epoch': ([0-9.]*)"},
{
"Name": "max_gpu_mem",
"Regex": "Max GPU memory use during training: ([0-9.e-]*) MB",
},
{"Name": "train_loss", "Regex": "'loss': ([0-9.e-]*)"},
{
"Name": "train_samples_per_second",
"Regex": "'train_samples_per_second': ([0-9.e-]*)",
},
{"Name": "eval_loss", "Regex": "'eval_loss': ([0-9.e-]*)"},
{"Name": "eval_accuracy", "Regex": "'eval_accuracy': ([0-9.e-]*)"},
]
最后,定义一个 Hugging Face 估计器并将其提交以在 ml.g5.2xlarge 实例类型上进行训练。这是一种经济高效的实例类型,在许多 AWS 区域中广泛使用:
from sagemaker.experiments.run import Run
from sagemaker.huggingface import HuggingFace
from sagemaker.inputs import TrainingInput
hf_estimator = HuggingFace(
base_job_name="esm-2-membrane-ft",
entry_point="lora-train.py",
source_dir="scripts",
instance_type="ml.g5.2xlarge",
instance_count=1,
transformers_version="4.28",
pytorch_version="2.0",
py_version="py310",
output_path=f"{S3_PATH}/output",
role=sagemaker_execution_role,
hyperparameters=hyperparameters,
metric_definitions=metric_definitions,
checkpoint_local_path="/opt/ml/checkpoints",
sagemaker_session=sagemaker_session,
keep_alive_period_in_seconds=3600,
tags=[{"Key": "project", "Value": "esm-fine-tuning"}],
)
with Run(
experiment_name=EXPERIMENT_NAME,
sagemaker_session=sagemaker_session,
) as run:
hf_estimator.fit(
{
"train": TrainingInput(s3_data=train_s3_uri),
"test": TrainingInput(s3_data=test_s3_uri),
}
)
下表比较了我们讨论的不同训练方法及其对我们作业的运行时间、准确性和 GPU 内存要求的影响。
配置 | 计费时间(分钟) | 评估准确度 | 最大 GPU 内存使用量 (GB) |
基本型号 | 28 | 0.91 | 22.6 |
基础+GA | 21 | 0.90 | 17.8 |
碱+气相色谱 | 29 | 0.91 | 10.2 |
基础+LoRA | 23 | 0.90 | 18.6 |
所有方法都产生了具有高评估精度的模型。使用 LoRA 和梯度激活分别将运行时间(和成本)减少了 18% 和 25%。使用梯度检查点将最大 GPU 内存使用量降低了 55%。根据您的限制(成本、时间、硬件),其中一种方法可能比另一种更有意义。
这些方法单独使用时效果都很好,但是当我们组合使用它们时会发生什么呢?下表总结了结果。
配置 | 计费时间(分钟) | 评估准确度 | 最大 GPU 内存使用量 (GB) |
所有方法 | 12 | 0.80 | 3.3 |
在这种情况下,我们发现准确度降低了 12%。然而,我们将运行时间减少了 57%,GPU 内存使用量减少了 85%!这是一个巨大的减少,使我们能够在各种经济高效的实例类型上进行训练。
清理
如果您使用自己的 AWS 账户进行操作,请删除您创建的所有实时推理终端节点和数据,以避免产生进一步费用。
predictor.delete_endpoint()
bucket = boto_session.resource("s3").Bucket(S3_BUCKET)
bucket.objects.filter(Prefix=S3_PREFIX).delete()
结论
在这篇文章中,我们演示了如何有效地微调 ESM-2 等蛋白质语言模型来完成科学相关的任务。有关使用 Transformers 和 PEFT 库训练 pLMS 的更多信息,请查看帖子 蛋白质深度学习 和 ESMBind (ESMB):ESM-2 的低阶适应,用于蛋白质结合位点预测 在拥抱脸博客上。您还可以在以下位置找到更多使用机器学习来预测蛋白质特性的示例: AWS 上出色的蛋白质分析 GitHub 存储库。
关于作者
布赖恩·洛亚尔 是 Amazon Web Services 全球医疗保健和生命科学团队的高级 AI/ML 解决方案架构师。 他在生物技术和机器学习领域拥有超过 17 年的经验,热衷于帮助客户解决基因组和蛋白质组学挑战。 在业余时间,他喜欢与朋友和家人一起烹饪和用餐。
- :具有
- :是
- :不是
- $UP
- 07
- 1
- 100
- 17
- 20
- 2022
- 2023
- 22
- 28
- 3d
- 425
- 600
- 7
- 750
- 8
- a
- 关于
- 账号管理
- 积累
- 积累
- 准确
- 活化
- 激活
- 适应
- 适应
- 调整
- 添加
- 地址
- 添加
- 高级
- AI / ML
- 瞄准
- 所有类型
- 允许
- 沿
- 字母
- 还
- Amazon
- 亚马逊SageMaker
- 亚马逊网络服务
- 量
- an
- 分析
- 分析
- 和
- 另一个
- 回答
- 任何
- 的途径
- 方法
- 架构
- 保健
- AS
- At
- 可使用
- 避免
- 远离
- AWS
- 背部
- 均衡
- BE
- 因为
- 背后
- 之间
- 偏见
- 大
- 十亿美元
- 捆绑
- 生物技术
- 博客
- 身体
- 借
- 都
- 布赖恩
- 建立
- 但是
- by
- 计算
- 计算
- 计算
- 被称为
- CAN
- 容量
- 捕获
- 案件
- 例
- 原因
- 细胞
- 链
- 链
- 挑战
- 挑战
- 字符
- 收费
- 聊天机器人
- 便宜
- 查
- 化学
- 程
- 码
- 组合
- 公司
- 计算
- 计算
- 约束
- 包含
- 集装箱
- 包含
- 价格
- 经济有效
- 可以
- 创建信息图
- 创建
- 习俗
- 合作伙伴
- data
- 数据库
- 数据集
- 减少
- 下降
- 定义
- 定义
- 演示
- 证明
- 根据
- 部署
- 描述
- 设计
- 尽管
- 决心
- 发达
- 研发支持
- 设备
- 图表
- 不同
- 发现
- 讨论
- 疾病
- 分布
- 别
- 向下
- 下载
- 药物
- 毒品
- ,我们将参加
- 每
- 效果
- 只
- 效率
- 高效
- 有效
- 或
- 其他
- 使
- 编码
- 端点
- 巨大
- 时代
- 时代
- 等于
- 等于
- 特别
- 评估
- 甚至
- 一切
- 例子
- 例子
- 超额
- 过多
- 现有
- 昂贵
- 体验
- 实验
- 专门知识
- 探索
- 额外
- 面部彩妆
- 家庭
- 快
- 可行
- 精选
- 特征
- 少数
- 数字
- 过滤
- 终于
- 找到最适合您的地方
- 结束
- 姓氏:
- 以下
- 针对
- 幸好
- 向前
- 分数
- 骨架
- 朋友
- 止
- 冻结
- 功能
- 进一步
- 得到
- GitHub上
- 特定
- 给
- 全球
- Go
- 非常好
- GPU
- 图形处理器
- 渐变
- 长大的
- 发生
- 硬件
- 有
- he
- 医疗保健
- HealthTech
- 帮助
- 帮助
- 帮助
- 高
- 高品质
- 他的
- 创新中心
- How To
- 但是
- HTML
- HTTP
- HTTPS
- 巨大
- 拥抱脸
- 人
- 伤害
- 主意
- if
- 说明
- 进口
- 重要
- 改善
- in
- 包括
- 增加
- 感染
- 信息
- 输入
- 内
- 例
- 代替
- 相互作用
- 成
- 问题
- IT
- 工作
- 保持
- 保持
- 键
- 神的国
- 已知
- 标签
- 语言
- 大
- 大
- 学习
- 减
- 让
- 库
- 自学资料库
- 生活
- 生命科学
- 生命科学
- 喜欢
- 极限
- 限制
- 有限
- 生活
- 本地化
- 地点
- 长
- 长时间
- 离
- 占地
- 低
- 忠诚
- 机
- 机器学习
- 机
- 制成
- 多数
- 使
- 制作
- 管理
- 管理
- 许多
- 大规模
- 最大
- 生产力
- 最多
- 可能..
- 意
- 手段
- 内存
- 合并
- 方法
- 方法
- 指标
- 分钟
- ML
- 模型
- 造型
- 模型
- 分子
- 更多
- 更高效
- 移动
- 基因
- 许多
- 多
- 姓名
- 自然
- 必要
- 需求
- 打印车票
- 网络
- 神经
- 神经网络
- 全新
- 不包含
- 小说
- 数
- 对象
- 对象
- of
- on
- 一
- 那些
- 仅由
- 优化
- 优化
- or
- 秩序
- 订单
- 原版的
- 本来
- 其他名称
- 其它
- 除此以外
- 我们的
- 输出
- 输出
- 学校以外
- 超过
- 最划算
- 己
- 参数
- 部分
- 通过
- 通行证
- 多情
- 过去
- 演出
- 执行
- 片
- 柏拉图
- 柏拉图数据智能
- 柏拉图数据
- 播放
- 加
- 可能
- 帖子
- 帖子
- 潜力
- 可能
- 预测
- Prepare
- 防止
- 以前
- 市场问题
- 处理
- 生成
- 项目
- 蛋白质
- 蛋白质
- 提供
- 国家
- 放
- 询问
- 有疑问吗?
- 很快
- 范围
- 排名
- 达到
- 实时的
- 合理
- 记录
- 减少
- 减少
- 减少
- 减少
- 正则表达式
- 地区
- 相应
- 去掉
- 更换
- 知识库
- 代表
- 代表
- 岗位要求
- 研究人员
- 资源
- 分别
- 回应
- 提供品牌战略规划
- 成果
- 回报
- 评论
- 右
- 角色
- 运行
- 运行
- sagemaker
- 销售
- 保存
- 科学
- 科学
- .
- 划伤
- 脚本
- 脚本
- 部分
- 看到
- 选择
- 自
- 前辈
- 感
- 分开
- 序列
- 特色服务
- 套数
- 几个
- 如图
- 同样
- 简易
- 模拟
- 网站
- 尺寸
- 尺寸
- 略有不同
- 放慢
- 减速
- 小
- 小
- 方案,
- 解决方案
- 解决
- 一些
- 来源
- 具体的
- 速度
- 分裂
- 标准
- 开始
- 国家的最先进的
- 留
- 稳步
- 步
- 步骤
- 存储
- 商店
- 存储
- 从战略
- 串
- 结构体
- 结构
- 提交
- 足够
- 表
- 采取
- 需要
- 目标
- 目标
- 任务
- 任务
- 团队
- 技术
- 技术
- test
- 文本
- 比
- 这
- 其
- 他们
- 他们自己
- 然后
- 疗法
- 那里。
- 因此
- 博曼
- 他们
- 事
- Free Introduction
- 千
- 通过
- 次
- 时
- 至
- 令牌化
- 也有
- 工具
- 火炬
- 合计
- 培训
- 熟练
- 产品培训
- 改造
- 变形金刚
- 试用
- true
- 尝试
- 调音
- 转
- 类型
- 类型
- 一般
- 理解
- 无名
- 不必要
- 更新
- 更新
- 更新
- us
- 用法
- USD
- 使用
- 用过的
- 使用
- 运用
- 通常
- 折扣值
- 价值观
- 各种
- 非常
- 通过
- 想
- we
- 卷筒纸
- Web服务
- 井
- 为
- 什么是
- ,尤其是
- 是否
- 这
- 而
- 宽
- 大范围
- 广泛
- 也完全不需要
- 工作
- 锻炼
- 工作流程
- 将
- 书面
- X
- 年
- 完全
- 您一站式解决方案
- 和风网