Rieoptax: Riemannian Optimization in JAX
Publication: 6413476
arXiv: 2210.04840
MaRDI QID: Q6413476
Author name not available
Publication date: 10 October 2022
Abstract: We present Rieoptax, an open-source Python library for Riemannian optimization in JAX. We show that many differential geometric primitives, such as the Riemannian exponential and logarithm maps, are usually faster in Rieoptax than in existing Python frameworks, on both CPU and GPU. We support a wide range of basic and advanced stochastic optimization solvers, including Riemannian stochastic gradient, stochastic variance reduction, and adaptive gradient methods. A distinguishing feature of the proposed toolbox is that it also supports differentially private optimization on Riemannian manifolds.
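To make the primitives named in the abstract concrete, the following is a minimal sketch in plain JAX (it does not use Rieoptax) of the Riemannian exponential and logarithm maps on the unit sphere, plus one Riemannian gradient step: project the Euclidean gradient onto the tangent space, then move along the geodesic with the exponential map. The helper names (`sphere_exp`, `sphere_log`, `rsgd_step`) and the test objective are hypothetical illustrations, not Rieoptax's API, and the gradient here is deterministic rather than stochastic.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration only; NOT the Rieoptax API.

def sphere_exp(x, v):
    """Exponential map on the unit sphere: follow the geodesic from x
    in tangent direction v for arc length ||v||."""
    norm_v = jnp.linalg.norm(v)
    # Avoid 0/0 when v is zero; then sin(0) = 0 and the result is x.
    safe = jnp.where(norm_v > 1e-12, norm_v, 1.0)
    return jnp.cos(norm_v) * x + jnp.sin(norm_v) * v / safe

def sphere_log(x, y):
    """Logarithm map on the unit sphere: the tangent vector at x whose
    exponential reaches y (valid for y != -x)."""
    cos_t = jnp.clip(jnp.dot(x, y), -1.0, 1.0)
    theta = jnp.arccos(cos_t)       # geodesic distance between x and y
    u = y - cos_t * x               # component of y orthogonal to x
    norm_u = jnp.linalg.norm(u)
    safe = jnp.where(norm_u > 1e-12, norm_u, 1.0)
    return theta * u / safe

# Hypothetical objective: minimizing -x^T A x on the sphere recovers the
# eigenvector of A with the largest eigenvalue (here e3, value -3).
A = jnp.diag(jnp.array([1.0, 2.0, 3.0]))

def f(x):
    return -x @ A @ x

@jax.jit
def rsgd_step(x, lr):
    egrad = jax.grad(f)(x)                    # Euclidean gradient
    rgrad = egrad - jnp.dot(x, egrad) * x     # project onto tangent space at x
    return sphere_exp(x, -lr * rgrad)         # geodesic step via exp map

x = jnp.ones(3) / jnp.sqrt(3.0)
for _ in range(200):
    x = rsgd_step(x, 0.1)
```

Using the exponential map keeps every iterate on the sphere without an explicit renormalization step; a library may instead offer cheaper retractions, which this sketch does not attempt to model.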
Has companion code repository: https://github.com/saitejautpala/rieoptax