spmd

flax.nnx.get_partition_spec(tree)

Extracts a PartitionSpec tree from a PyTree containing Variable values.

flax.nnx.get_named_sharding(tree, mesh)
flax.nnx.with_partitioning(initializer, sharding, mesh=None, **metadata)
flax.nnx.with_sharding_constraint(x, axis_resources, mesh=None)