Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code examples for ix_ function with better doc string. #20327

Merged
merged 1 commit into from May 6, 2024

Conversation

selamw1
Copy link
Contributor

@selamw1 selamw1 commented Mar 20, 2024

No description provided.

Copy link

google-cla bot commented Mar 20, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@selamw1
Copy link
Contributor Author

selamw1 commented Mar 20, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

It's already signed.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, but please note that the implements decorator will overwrite any docstring you define in the function. If you want to be able to add a customized docstring, we'll have to modify that decorator (or remove it, but I believe we have tests to check that all public APIs have this decorator)

Also, it looks like this content was copied from the text within numpy repository. We cannot have verbatim copies of other open source projects as part of the JAX source code, because it violates the OSS license requirements. If you want to write a docstring for this function, I'd suggest doing it from scratch, without referencing the original NumPy docstring. Alternatively, we can put the text in jax/third_party and reference it here, although that's a bit awkward in the case of docstrings.

This licensing issue is one of the reasons we have not done this kind of docstring specialization in the past, instead relying on runtime rewrites via arguments to implements.

@selamw1
Copy link
Contributor Author

selamw1 commented Apr 25, 2024

Hi @jakevdp

Is the following modification a doable solution for inserting embedded code examples (customized to JAX) for NumPy and SciPy?

We would use lax_description for inserting examples and leave the implements decorator as it is.

@util.implements(np.ix_, lax_description="""Example:
    >>> import jax.numpy as jnp
    >>> a = jnp.arange(10).reshape(2, 5)
    >>> a
    Array([[0, 1, 2, 3, 4],
           [5, 6, 7, 8, 9]], dtype=int32)
    >>> ixgrid = jnp.ix_(jnp.array([0, 1]), jnp.array([2, 4]))
    >>> ixgrid
    (Array([[0],
           [1]], dtype=int32), Array([[2, 4]], dtype=int32))
    >>> ixgrid[0].shape, ixgrid[1].shape
    ((2, 1), (1, 2))
    >>> a[ixgrid]
    Array([[2, 4],
           [7, 9]], dtype=int32)""")
def ix_(*args: ArrayLike) -> tuple[Array, ...]:
  util.check_arraylike("ix", *args)
Screenshot 2024-04-25 at 2 09 51 PM

If not, I will do it from scratch as you suggested earlier, without referencing the original NumPy docstring.

Since the CLA has failed here, I will close and raise another PR.

@selamw1 selamw1 requested a review from jakevdp April 25, 2024 22:23
@jakevdp
Copy link
Collaborator

jakevdp commented Apr 25, 2024

Typically the example comes below the description of the function parameters. I think in this case it would be best to rewrite the docstring from scratch. I'm doing some similar work in #20941 – you can do something similar for this case.

@jakevdp
Copy link
Collaborator

jakevdp commented Apr 25, 2024

Oh, and please don't open a new PR if possible – it's useful to keep all the conversation in one place. The failing CLA is probably due to whatever email you have attached to the commits on your branch: you can undo previous commits and redo new commits with the correct email in order to update this branch

@jakevdp jakevdp self-assigned this Apr 25, 2024
@selamw1 selamw1 force-pushed the add_examples branch 3 times, most recently from 2c2893b to c6b8554 Compare April 26, 2024 01:49
@selamw1
Copy link
Contributor Author

selamw1 commented Apr 26, 2024

It make sene, docstring is rewritten from scratch and tried to squash the commits.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, looks good! A few formatting pieces below (you might try building the docs locally to make sure the formatting is acceptable, rather than waiting for the github CI).

Also, you might add a See also section that links to meshgrid, mgrid, and ogrid

jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
@selamw1
Copy link
Contributor Author

selamw1 commented Apr 27, 2024

Thanks @jakevdp

'See also' section added and indentations corrected. Locally, the formatting appears as follows:

image

jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
@selamw1
Copy link
Contributor Author

selamw1 commented Apr 30, 2024

Thanks @jakevdp

  • Description is adjusted:
    • ndarray is modified to "Tuple of JAX arrays".
    • integer sequences is modified to integer arrays.
  • New example is added.

jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
@selamw1
Copy link
Contributor Author

selamw1 commented Apr 30, 2024

unnecessary texts and example outputs are modified/removed.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One small change requested below. Also

Please run the linter on your change (see https://jax.readthedocs.io/en/latest/contributing.html#linting-and-type-checking)

I think you'll have to add ix_ to this list to satisfy tests:

known_exceptions = {

Finally, once you've made all changes, please squash everything into a single commit; see https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests for details.

jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
@jakevdp
Copy link
Collaborator

jakevdp commented May 1, 2024

Looks like there are some lint issues with whitespace. You can fix these locally by running pre-commit: https://jax.readthedocs.io/en/latest/contributing.html#linting-and-type-checking

@selamw1
Copy link
Contributor Author

selamw1 commented May 1, 2024

Looks like there are some lint issues with whitespace. You can fix these locally by running pre-commit: https://jax.readthedocs.io/en/latest/contributing.html#linting-and-type-checking

It seems the lint failure was below:
trim trailing whitespace.................................................Failed

removing whitespace before the following lines:
def ix_(*args: ArrayLike) -> tuple[Array, ...]:
and
util.check_arraylike("ix", *args)

resolve the issue?

running pre-commit locally prints the following error:

      Preparing metadata (pyproject.toml): finished with status 'error'
stderr:
      error: subprocess-exited-with-error
      
      × Preparing metadata (pyproject.toml) did not run successfully.
      │ exit code: 1
      ╰─> [39 lines of output]
          INFO:hatch_jupyter_builder.utils:Running jupyter-builder

@jakevdp
Copy link
Collaborator

jakevdp commented May 1, 2024

Yeah, removing all trailing whitespace should fix the error. Regarding the pre-commit installation error, it's probably some kind of version incompatibility. Hard to say without seeing more of the error or logs.

@jakevdp
Copy link
Collaborator

jakevdp commented May 3, 2024

It seems like some other commits snuck into your PR! It looks like you both rebased and merged, but targeted different snapshots of the main branch. I'd try fixing it like this:

$ git remote -v  # just to make sure it's clear what upstream and origin are here
origin	git@github.com:selamw1/jax.git
upstream	git@github.com:google/jax.git
$ git checkout main
$ git pull upstream main  # this should fast-forward cleanly if you haven't added commits to your local main branch
$ git checkout add_examples
$ git rebase main  # this should rebase cleanly if you haven't done anything strange
$ git log  # make sure you only have a single commit

In most cases this should work, though admittedly I don't really know what you've done to your local branch, so if you've done something really strange (like merged against branches with conflicts) then you may have to do more to fix it.

@jakevdp
Copy link
Collaborator

jakevdp commented May 3, 2024

It seems like some other commits snuck into your PR! It looks like you both rebased and merged, but targeted different snapshots of the main branch in each operation. I'd try fixing it like this:

$ git remote -v  # just to make sure it's clear what upstream and origin are here
origin	git@github.com:selamw1/jax.git
upstream	git@github.com:google/jax.git
$ git checkout main
$ git pull upstream main  # this should fast-forward cleanly if you haven't added commits to your local main branch
$ git checkout add_examples
$ git rebase main  # this should rebase cleanly if you haven't done anything strange
$ git log  # make sure you only have a single commit
$ git push origin +add_examples

In most cases this should work, though admittedly I don't really know what you've done to your local branch, so if you've done something really strange (like merged against branches with conflicts) then you may have to do more to fix it.

@google-ml-butler google-ml-butler bot added the pull ready Ready for copybara import and testing label May 6, 2024
@copybara-service copybara-service bot merged commit f6d8852 into google:main May 6, 2024
9 checks passed
@selamw1 selamw1 deleted the add_examples branch May 6, 2024 22:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants