Jax 反编译器:重新定义基于梯度的软件设计
我们通过向 JAX 增加自动区分高阶函数(functional 和 operators)的能力,将函数表示为数组的一种泛化,采用 JAX 现有的原始系统来实现高阶函数。我们提供了一组原始操作符,作为构建几种关键类型的 functional 的基础构建块。对于每个引入的原始操作符,我们推导和实现了线性化和转置规则,与 JAX 内部的前向和反向模式自动区分协议保持一致。这种增强功能允许在传统用于函数的语法中进行 functional 区分。得到的 functional 梯度本身就是可以在 Python 中调用的函数。我们通过一些应用展示了这个工具的效能和简洁性,这些应用中 functional 导数是不可或缺的。本研究的源代码可在此 https URL 中获取。
Nov, 2023
该研究论文通过利用向量化、即时编译和静态图优化等语言基元,显著减少了执行 DPSGD 的运行时间开销,实现了与最佳非私有运行时间几乎相当的结果,从而实现多达 50 倍的加速和重要的内存和运行时改进。
Oct, 2020
本论文介绍了一种基于 JAX 的裁剪和稀疏训练库 JaxPruner,旨在通过提供流行的裁剪和稀疏训练算法的简洁实现,减少内存和延迟开销,从而加速稀疏神经网络的研究,此外,JaxPruner 的算法使用公共 API,并与流行的优化库 Optax 无缝配合,使其易于与现有的基于 JAX 的库集成。本论文提供了四个不同代码库的示例,并在流行的基准测试上提供了基准实验。
Apr, 2023
JAXbind 旨在通过提供易于使用的 Python 接口,将用其他编程语言实现的自定义函数绑定到 JAX,从而大大减少了绑定自定义函数到 JAX 所需的工作量。
Mar, 2024
介绍了 Equinox,它是一种小型神经网络库,采用类似 PyTorch 的基于类的方法,同时使用 JAX-like 的函数编程,它通过 PyTree 和函数变换解决了 OO/functional 的差异而不需要引入新的程序抽象。
Oct, 2021
BlackJAX 是一个库,用于实现在贝叶斯计算中常用的采样和变分推断算法。它采用了功能性的方法来实现算法,以便于使用、速度快以及具备模块化。
Feb, 2024
QDax 是一个开源库,具有简化且模块化的 API,用于 Quality-Diversity (QD) 优化算法在 Jax 中。该库可用于各种优化目的,从黑盒优化到连续控制。QDax 提供了流行的 QD、神经进化和增强学习算法的实现,支持各种示例。所有这些实现都可以使用 Jax 进行即时编译,以便在多个加速器(包括 GPU 和 TPU)上进行高效执行。这些实现有效地展示了该框架的灵活性和用户友好性,为研究目的简化了实验。此外,该库有详细的文档,并通过测试覆盖率达到 95%。
Aug, 2023
本文介绍了一个硬件优化的数据流架构,用于将计算图形的高阶梯度转化为硬件优化;该架构通过设计一个使用 FIFO 流和优化计算内核库的数据流架构,并提出一个编译器来提取和优化计算图形,以实现最大吞吐量,同时确保无死锁操作,并输出 FPGA 实现的高级综合(HLS)代码,从而实现了 1.8-4.8 倍和 1.5-3.6 倍的加速比,以及较低的内存使用率和能耗延迟乘积。
Aug, 2023
本文介绍了一种新的学习方法,称为 NeurDP,以处理被编译优化的二进制文件的反编译问题,并以图神经网络模型将低级程序语言转换为中间表示,从而实现更好的翻译性能,经评估结果表明,NeurDP 可以比最先进的神经反编译框架的准确率高出 45.21%。
Jan, 2023
通过 GPU 加速计算及发布基于 JAX 的 evosax 库支持的 30 种演化式算法及硬件优化,本文探究了深度学习和演化式优化间的结合,以进一步推动黑箱优化算法的发展。
Dec, 2022