Skip to content

Commit

Permalink
mlx - merge master into mlx (#19657)
Browse files Browse the repository at this point in the history
* Introduce float8 training (#19488)

* Add float8 training support

* Add tests for fp8 training

* Add `quantize_and_dequantize` test

* Fix bugs and add float8 correctness tests

* Cleanup

* Address comments and cleanup

* Add docstrings and some minor refactoring

* Add `QuantizedFloat8DTypePolicy`

* Add dtype policy setter

* Fix torch dynamo issue by using `self._dtype_policy`

* Improve test coverage

* Add LoRA to ConvND layers (#19516)

* Add LoRA to `BaseConv`

* Add tests

* Fix typo

* Fix tests

* Fix tests

* Add path to run keras on dm-tree when optree is not available.

* feat(losses): add Tversky loss implementation (#19511)

* feat(losses): add Tversky loss implementation

* adjusted documentation

* Update KLD docs

* Models and layers now return owned metrics recursively. (#19522)

- added `Layer.metrics` to return all metrics owned by the layer and its sub-layers recursively.
- `Layer.metrics_variables` now returns variables from all metrics recursively, not just the layer and its direct sub-layers.
- `Model.metrics` now returns all metrics recursively, not just the model level metrics.
- `Model.metrics_variables` now returns variables from all metrics recursively, not just the model level metrics.
- added test coverage to test metrics and variables 2 levels deep.

This is consistent with the Keras 2 behavior and how `Model/Layer.variables` and `Model/Layer.weights` work.

* Update IoU ignore_class handling

* Fix `RandomBrightness`, Enhance `IndexLookup` Initialization and Expand Test Coverage for `Preprocessing Layers` (#19513)

* Add tests for CategoryEncoding class in category_encoding_test.py

* fix

* Fix IndexLookup class initialization and add test cases

* Add test case for IndexLookupLayerTest without vocabulary

* Fix IndexLookup class initialization

* Add normalization test cases

* Add test cases for Hashing class

* Fix value range validation error in RandomBrightness class

* Refactor IndexLookup class initialization and add test cases

* Reffix ndexLookup class initialization and afix est cases

* Add test for spectral norm

* Add missing test decorator

* Fix torch test

* Fix code format

* Generate API (#19530)

* API Generator for Keras

* API Generator for Keras

* Generates API Gen via api_gen.sh

* Remove recursive import of _tf_keras

* Generate API Files via api_gen.sh

* Update APIs

* Added metrics from custom `train_step`/`test_step` are now returned. (#19529)

This works the same way as in Keras 2, whereby the metrics are returned directly from the logs if the set of keys doesn't match the model metrics.

* Use temp dir and abs path in `api_gen.py` (#19533)

* Use temp dir and abs path

* Use temp dir and abs path

* Update Readme

* Update API

* Fix gradient accumulation when using `overwrite_with_gradient` during float8 training (#19534)

* Fix gradient accumulation with `overwrite_with_gradient` in float8 training

* Add comments

* Fix annotation

* Update code path in ignore path (#19537)

* Add operations per run (#19538)

* Include input shapes in model visualization.

* Add pad_to_aspect_ratio feature in ops.image.resize

* Add pad_to_aspect_ratio feature in Resizing layer.

* Fix incorrect usage of `quantize` (#19541)

* Add logic to prevent double quantization

* Add detailed info for double quantization error

* Update error msg

* Add eigh op.

* Add keepdim in argmax/argmin.

* Fix small bug in model.save_weights (#19545)

* Update public APIs.

* eigh should work on JAX GPU

* Copy init to keras/__init__.py (#19551)

* Revert "Copy init to keras/__init__.py (#19551)" (#19552)

This reverts commit da9af61.

* sum-reduce inlined losses

* Remove the dependency on `tensorflow.experimental.numpy` and support negative indices for `take` and `take_along_axis` (#19556)

* Remove `tfnp`

* Update numpy api

* Improve test coverage

* Improve test coverage

* Fix `Tri` and `Eye` and increase test converage

* Update `round` test

* Fix `jnp.round`

* Fix `diag` bug for iou_metrics

* Add op.select.

* Add new API for select

* Make `ops.abs` and `ops.absolute` consistent between backends. (#19563)

- The TensorFlow implementation was missing `convert_to_tensor`
- The sparse annotation was unnecessarily applied twice
- Now `abs` calls `absolute` in all backends

Also fixed TensorFlow `ops.select`.

* Add pickle support for Keras model (#19555)

* Implement unit tests for pickling

* Reformat model_test

* Reformat model_test

* Rename depickle to unpickle

* Rename depickle to unpickle

* Reformat

* remove a comment

* Ellipsis Serialization and tests (#19564)

* Serialization and tests

* Serialization and tests

* Serialization and tests

* Make TF one_hot input dtype less strict.

* Fix einsum `_int8_call` (#19570)

* CTC Decoding for JAX and Tensorflow (#19366)

* Tensorflow OP for CTC decoding

* JAX op for CTC greedy decoding

* Update CTC decoding documentation

* Fix linting issues

* Fix trailing whitespace

* Simplify returns in tensorflow CTC wrapper

* Fix CTC decoding error messages

* Fix line too long

* Bug fixes to JAX CTC greedy decoder

* Force int typecast in TF CTC decoder

* Unit tests for CTC greedy decoding

* Add unit test for CTC beam search decoding

* Fix mask index set location in JAX CTC decoding

* CTC beam search decoding for JAX

* Fix unhandled token repetitions in ctc_beam_search_decode

* Fix merge_repeated bug in CTC beam search decode

* Fix beam storage and repetition bugs in JAX ctc_decode

* Remove trailing whitespace

* Fix ordering bug for ties in JAX CTC beam search

* Cast sequence lengths to integers in JAX ctc_decode

* Remove line break in docstring

* CTC beam search decoding for JAX

* Fix unhandled token repetitions in ctc_beam_search_decode

* Fix merge_repeated bug in CTC beam search decode

* Fix beam storage and repetition bugs in JAX ctc_decode

* Fix ordering bug for ties in JAX CTC beam search

* Generate public api directory

* Add not implemented errors for NumPy and Torch CTC decoding

* Remove unused redefinition of JAX ctc_beam_search_decode

* Docstring edits

* Expand nan_to_num args.

* Add vectorize op.

* list insert requires index (#19575)

* Add signature and exclude args to knp.vectorize.

* Fix the apis of `dtype_polices` (#19580)

* Fix api of `dtype_polices`

* Update docstring

* Increase test coverage

* Fix format

* Fix keys of `save_own_variables` and `load_own_variables` (#19581)

* Fix JAX CTC test.

* Fix loss_weights handling in single output case

* Fix JAX vectorize.

* Move _tf_keras directory to the root of the pip package.

* One time fix to _tf_keras API.

* Convert return type imdb.load_data to nparray (#19598)

Convert return type imdb.load_data to Numpy array. Currently X_train and X-test returned as list.

* Fix typo

* fix api_gen.py for legacy (#19590)

* fix api_gen.py for legacy

* merge api and legacy for _tf_keras

* Improve int8 for `Embedding` (#19595)

* pin torch < 2.3.0 (#19603)

* Clean up duplicated `inputs_quantizer` (#19604)

* Cleanup duplicated `inputs_quantizer` and add type check for `input_spec` and `supports_masking`

* Revert setter

* output format changes and errors in github (#19608)

* Provide write permission to action for cache management. (#19606)

* Pickle support for all saveables (#19592)

* Pickle support

* Add keras pickleable mixin

* Reformat

* Implement pickle all over

* reformat

* Reformat

* Keras saveable

* Keras saveable

* Keras saveable

* Keras saveable

* Keras saveable

* obj_type

* Update pickleable

* Saveable logic touchups

* Add slogdet op.

* Update APIs

* Remove unused import

* Refactor CTC APIs (#19611)

* Add `ctc_loss` and `ctc_decode` for numpy backend, improve imports and tests

* Support "beam_search" strategy for torch's `ctc_decode`

* Improve `ctc_loss`

* Cleanup

* Refactor `ctc_decode`

* Update docstring

* Update docstring

* Add `CTCDecode` operation and ensure dtype inference of `ctc_decode`

* Fix `name` of `losses.CTC`

* update the namex version requirements (#19617)

* Add `PSNR` API (#19616)

* PSNR

* Fix

* Docstring format

* Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps (#19618)

* Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps

* Formatting

* Implement custom layer insertion in clone_model. (#19610)

* Implement custom layer insertion in clone_model.

* Add recursive arg and tests.

* Add nested sequential cloning test

* Fix bidir lstm saving issue.

* Fix CI

* Fix cholesky tracing with jax

* made extract_patches dtype agnostic (#19621)

* Simplify Bidirectional implementation

* Add support for infinite `PyDataset`s. (#19624)

`PyDataset` now uses the `num_batches` property instead of `__len__` to support `None`, which is how one indicates the dataset is infinite. Note that infinite datasets are not shuffled.

Fixes #19528

Also added exception reporting when using multithreading / multiprocessing. Previously, the program would just hang with no error reported.

* Fix dataset shuffling issue.

* Update version string.

* Minor fix

* Restore version string resolution in pip_build.

* Speed up `DataAdapter` tests by testing only the current backend. (#19625)

There is no use case for using an iterator for a different backend than the current backend.

Also:
- limit the number of tests using multiprocessing, the threading tests give us good coverage.
- fixed the `test_exception_reported` test, which was not actually exercising the multiprocessing / multithreading cases.
- removed unused `init_pool` method.

* feat(ops): support np.argpartition (#19588)

* feat(ops): support np.argpartition

* updated documentation, type-casting, and tf implementation

* fixed tf implementation

* added torch cast to int32

* updated torch type and API generated files

* added torch output type cast

* test(trainers): add test_errors implementation for ArrayDataAdapter class (#19626)

* Fix torch GPU CI

* Fix argmax/argmin keepdims with defined axis in TF

* Misc fixes in TF backend ops.

* Fix `argpartition` cuda bug in torch (#19634)

* fix(ops): specify NonZero output dtype and add test coverage (#19635)

* Fix `ops.ctc_decode` (#19633)

* Fix greedy ctc decode

* Remove print

* Fix `tf.nn.ctc_beam_search_decoder`

* Change default `mask_index` to `0`

* Fix losses test

* Update

* Ensure the same rule applies for np arrays in autocasting (#19636)

* Ensure the same rule applies for np arrays in autocasting

* Trigger CI by adding docstring

* Update

* Update docstring

* Fix `istft` and add class `TestMathErrors` in `ops/math_test.py` (#19594)

* Fix and test math functions for jax backend

* run /workspaces/keras/shell/format.sh

* refix

* fix

* fix _get_complex_tensor_from_tuple

* fix

* refix

* Fix istft function to handle inputs with less than 2 dimensions

* fix

* Fix ValueError in istft function for inputs with less than 2 dimensions

* Return a tuple from `ops.shape` with the Torch backend. (#19640)

With Torch, `x.shape` returns a `torch.Size`, which is a subclass of `tuple` but can cause different behaviors. In particular `convert_to_tensor` does not work on `torch.Size`.

This fixes #18900

* support conv3d on cpu for TF (#19641)

* Enable cudnn rnns when dropout is set (#19645)

* Enable cudnn rnns when dropout is set

* Fix

* Fix plot_model for input dicts.

* Fix deprecation warning in torch

* Bump the github-actions group with 2 updates (#19653)

Bumps the github-actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [github/codeql-action](https://github.com/github/codeql-action).


Updates `actions/upload-artifact` from 4.3.1 to 4.3.3
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](actions/upload-artifact@5d5d22a...6546280)

Updates `github/codeql-action` from 3.24.9 to 3.25.3
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](github/codeql-action@1b1aada...d39d31e)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: github-actions
- dependency-name: github/codeql-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump the python group with 2 updates (#19654)

Bumps the python group with 2 updates: torch and torchvision.


Updates `torch` from 2.2.1+cu121 to 2.3.0+cu121

Updates `torchvision` from 0.17.1+cu121 to 0.18.0+cu121

---
updated-dependencies:
- dependency-name: torch
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
- dependency-name: torchvision
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Revert "Bump the python group with 2 updates (#19654)" (#19655)

This reverts commit 09133f4.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: james77777778 <20734616+james77777778@users.noreply.github.com>
Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Luca Pizzini <lpizzini7@gmail.com>
Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Sachin Prasad <sachinprasad@google.com>
Co-authored-by: Uwe Schmidt <uschmidt83@users.noreply.github.com>
Co-authored-by: Luke Wood <LukeWood@users.noreply.github.com>
Co-authored-by: Maanas Arora <maanasarora23@gmail.com>
Co-authored-by: AlexanderLavelle <73360008+AlexanderLavelle@users.noreply.github.com>
Co-authored-by: Surya <116063290+SuryanarayanaY@users.noreply.github.com>
Co-authored-by: Shivam Mishra <124146945+shmishra99@users.noreply.github.com>
Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Co-authored-by: IMvision12 <88665786+IMvision12@users.noreply.github.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
Co-authored-by: Vachan V Y <109357590+VachanVY@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
19 people committed May 3, 2024
1 parent 4c90dfb commit 5a3542b
Show file tree
Hide file tree
Showing 107 changed files with 4,396 additions and 1,034 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/actions.yml
Expand Up @@ -126,3 +126,12 @@ jobs:
fi
- name: Lint
run: bash shell/lint.sh
- name: Check for API changes
run: |
bash shell/api_gen.sh
git status
clean=$(git status | grep "nothing to commit")
if [ -z "$clean" ]; then
echo "Please run shell/api_gen.sh to generate API."
exit 1
fi
10 changes: 10 additions & 0 deletions .github/workflows/nightly.yml
Expand Up @@ -92,6 +92,16 @@ jobs:
fi
- name: Lint
run: bash shell/lint.sh
- name: Check for API changes
run: |
bash shell/api_gen.sh
git status
clean=$(git status | grep "nothing to commit")
if [ -z "$clean" ]; then
echo "Please run shell/api_gen.sh to generate API."
exit 1
fi
nightly:
name: Build Wheel file and upload
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/scorecard.yml
Expand Up @@ -48,14 +48,14 @@ jobs:
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
- name: "Upload artifact"
uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1
uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3
with:
name: SARIF file
path: results.sarif
retention-days: 5

# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9
uses: github/codeql-action/upload-sarif@d39d31e687223d841ef683f52467bd88e9b21c14 # v3.25.3
with:
sarif_file: results.sarif
1 change: 1 addition & 0 deletions .github/workflows/stale-issue-pr.yaml
Expand Up @@ -10,6 +10,7 @@ jobs:
permissions:
issues: write
pull-requests: write
actions: write
steps:
- name: Awaiting response issues
uses: actions/stale@v9
Expand Down
2 changes: 1 addition & 1 deletion SECURITY.md
Expand Up @@ -59,7 +59,7 @@ Besides the virtual environment, the hardware (GPUs or TPUs) can also be attacke

## Reporting a Vulnerability

Beware that none of the topics under [Using Keras Securely](#using-Keras-securely) are considered vulnerabilities of Keras.
Beware that none of the topics under [Using Keras Securely](#using-keras-securely) are considered vulnerabilities of Keras.

If you have discovered a security vulnerability in this project, please report it
privately. **Do not disclose it as a public issue.** This gives us time to work with you
Expand Down
18 changes: 16 additions & 2 deletions api_gen.py
Expand Up @@ -7,6 +7,7 @@
"""

import os
import re
import shutil

import namex
Expand Down Expand Up @@ -78,8 +79,7 @@ def create_legacy_directory(package_dir):
for path in os.listdir(os.path.join(src_dir, "legacy"))
if os.path.isdir(os.path.join(src_dir, "legacy", path))
]

for root, _, fnames in os.walk(os.path.join(package_dir, "_legacy")):
for root, _, fnames in os.walk(os.path.join(api_dir, "_legacy")):
for fname in fnames:
if fname.endswith(".py"):
legacy_fpath = os.path.join(root, fname)
Expand Down Expand Up @@ -110,6 +110,20 @@ def create_legacy_directory(package_dir):
f"keras.api.{legacy_submodule}",
f"keras.api._tf_keras.keras.{legacy_submodule}",
)
# Remove duplicate generated comments string.
legacy_contents = re.sub(r"\n", r"\\n", legacy_contents)
legacy_contents = re.sub('""".*"""', "", legacy_contents)
legacy_contents = re.sub(r"\\n", r"\n", legacy_contents)
# If the same module is in legacy and core_api, use legacy
legacy_imports = re.findall(
r"import (\w+)", legacy_contents
)
for import_name in legacy_imports:
core_api_contents = re.sub(
f"\n.* import {import_name}\n",
r"\n",
core_api_contents,
)
legacy_contents = core_api_contents + "\n" + legacy_contents
with open(tf_keras_fpath, "w") as f:
f.write(legacy_contents)
Expand Down
11 changes: 5 additions & 6 deletions keras/api/_tf_keras/keras/__init__.py
Expand Up @@ -6,7 +6,6 @@

from keras.api import activations
from keras.api import applications
from keras.api import backend
from keras.api import callbacks
from keras.api import config
from keras.api import constraints
Expand All @@ -15,21 +14,21 @@
from keras.api import dtype_policies
from keras.api import export
from keras.api import initializers
from keras.api import layers
from keras.api import legacy
from keras.api import losses
from keras.api import metrics
from keras.api import mixed_precision
from keras.api import models
from keras.api import ops
from keras.api import optimizers
from keras.api import preprocessing
from keras.api import quantizers
from keras.api import random
from keras.api import regularizers
from keras.api import saving
from keras.api import tree
from keras.api import utils
from keras.api._tf_keras.keras import backend
from keras.api._tf_keras.keras import layers
from keras.api._tf_keras.keras import losses
from keras.api._tf_keras.keras import metrics
from keras.api._tf_keras.keras import preprocessing
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.backend.exports import Variable
Expand Down
123 changes: 123 additions & 0 deletions keras/api/_tf_keras/keras/backend/__init__.py
Expand Up @@ -17,4 +17,127 @@
from keras.src.backend.config import set_epsilon
from keras.src.backend.config import set_floatx
from keras.src.backend.config import set_image_data_format
from keras.src.legacy.backend import abs
from keras.src.legacy.backend import all
from keras.src.legacy.backend import any
from keras.src.legacy.backend import arange
from keras.src.legacy.backend import argmax
from keras.src.legacy.backend import argmin
from keras.src.legacy.backend import batch_dot
from keras.src.legacy.backend import batch_flatten
from keras.src.legacy.backend import batch_get_value
from keras.src.legacy.backend import batch_normalization
from keras.src.legacy.backend import batch_set_value
from keras.src.legacy.backend import bias_add
from keras.src.legacy.backend import binary_crossentropy
from keras.src.legacy.backend import binary_focal_crossentropy
from keras.src.legacy.backend import cast
from keras.src.legacy.backend import cast_to_floatx
from keras.src.legacy.backend import categorical_crossentropy
from keras.src.legacy.backend import categorical_focal_crossentropy
from keras.src.legacy.backend import clip
from keras.src.legacy.backend import concatenate
from keras.src.legacy.backend import constant
from keras.src.legacy.backend import conv1d
from keras.src.legacy.backend import conv2d
from keras.src.legacy.backend import conv2d_transpose
from keras.src.legacy.backend import conv3d
from keras.src.legacy.backend import cos
from keras.src.legacy.backend import count_params
from keras.src.legacy.backend import ctc_batch_cost
from keras.src.legacy.backend import ctc_decode
from keras.src.legacy.backend import ctc_label_dense_to_sparse
from keras.src.legacy.backend import cumprod
from keras.src.legacy.backend import cumsum
from keras.src.legacy.backend import depthwise_conv2d
from keras.src.legacy.backend import dot
from keras.src.legacy.backend import dropout
from keras.src.legacy.backend import dtype
from keras.src.legacy.backend import elu
from keras.src.legacy.backend import equal
from keras.src.legacy.backend import eval
from keras.src.legacy.backend import exp
from keras.src.legacy.backend import expand_dims
from keras.src.legacy.backend import eye
from keras.src.legacy.backend import flatten
from keras.src.legacy.backend import foldl
from keras.src.legacy.backend import foldr
from keras.src.legacy.backend import gather
from keras.src.legacy.backend import get_value
from keras.src.legacy.backend import gradients
from keras.src.legacy.backend import greater
from keras.src.legacy.backend import greater_equal
from keras.src.legacy.backend import hard_sigmoid
from keras.src.legacy.backend import in_top_k
from keras.src.legacy.backend import int_shape
from keras.src.legacy.backend import is_sparse
from keras.src.legacy.backend import l2_normalize
from keras.src.legacy.backend import less
from keras.src.legacy.backend import less_equal
from keras.src.legacy.backend import log
from keras.src.legacy.backend import map_fn
from keras.src.legacy.backend import max
from keras.src.legacy.backend import maximum
from keras.src.legacy.backend import mean
from keras.src.legacy.backend import min
from keras.src.legacy.backend import minimum
from keras.src.legacy.backend import moving_average_update
from keras.src.legacy.backend import name_scope
from keras.src.legacy.backend import ndim
from keras.src.legacy.backend import not_equal
from keras.src.legacy.backend import one_hot
from keras.src.legacy.backend import ones
from keras.src.legacy.backend import ones_like
from keras.src.legacy.backend import permute_dimensions
from keras.src.legacy.backend import pool2d
from keras.src.legacy.backend import pool3d
from keras.src.legacy.backend import pow
from keras.src.legacy.backend import prod
from keras.src.legacy.backend import random_bernoulli
from keras.src.legacy.backend import random_normal
from keras.src.legacy.backend import random_normal_variable
from keras.src.legacy.backend import random_uniform
from keras.src.legacy.backend import random_uniform_variable
from keras.src.legacy.backend import relu
from keras.src.legacy.backend import repeat
from keras.src.legacy.backend import repeat_elements
from keras.src.legacy.backend import reshape
from keras.src.legacy.backend import resize_images
from keras.src.legacy.backend import resize_volumes
from keras.src.legacy.backend import reverse
from keras.src.legacy.backend import rnn
from keras.src.legacy.backend import round
from keras.src.legacy.backend import separable_conv2d
from keras.src.legacy.backend import set_value
from keras.src.legacy.backend import shape
from keras.src.legacy.backend import sigmoid
from keras.src.legacy.backend import sign
from keras.src.legacy.backend import sin
from keras.src.legacy.backend import softmax
from keras.src.legacy.backend import softplus
from keras.src.legacy.backend import softsign
from keras.src.legacy.backend import sparse_categorical_crossentropy
from keras.src.legacy.backend import spatial_2d_padding
from keras.src.legacy.backend import spatial_3d_padding
from keras.src.legacy.backend import sqrt
from keras.src.legacy.backend import square
from keras.src.legacy.backend import squeeze
from keras.src.legacy.backend import stack
from keras.src.legacy.backend import std
from keras.src.legacy.backend import stop_gradient
from keras.src.legacy.backend import sum
from keras.src.legacy.backend import switch
from keras.src.legacy.backend import tanh
from keras.src.legacy.backend import temporal_padding
from keras.src.legacy.backend import tile
from keras.src.legacy.backend import to_dense
from keras.src.legacy.backend import transpose
from keras.src.legacy.backend import truncated_normal
from keras.src.legacy.backend import update
from keras.src.legacy.backend import update_add
from keras.src.legacy.backend import update_sub
from keras.src.legacy.backend import var
from keras.src.legacy.backend import variable
from keras.src.legacy.backend import zeros
from keras.src.legacy.backend import zeros_like
from keras.src.utils.naming import get_uid
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/dtype_policies/__init__.py
Expand Up @@ -4,6 +4,9 @@
since your modifications would be overwritten.
"""

from keras.src.dtype_policies import deserialize
from keras.src.dtype_policies import get
from keras.src.dtype_policies import serialize
from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
Expand Down
5 changes: 4 additions & 1 deletion keras/api/_tf_keras/keras/layers/__init__.py
Expand Up @@ -157,7 +157,6 @@
from keras.src.layers.regularization.activity_regularization import (
ActivityRegularization,
)
from keras.src.layers.regularization.alpha_dropout import AlphaDropout
from keras.src.layers.regularization.dropout import Dropout
from keras.src.layers.regularization.gaussian_dropout import GaussianDropout
from keras.src.layers.regularization.gaussian_noise import GaussianNoise
Expand Down Expand Up @@ -190,6 +189,10 @@
from keras.src.layers.rnn.simple_rnn import SimpleRNNCell
from keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells
from keras.src.layers.rnn.time_distributed import TimeDistributed
from keras.src.legacy.layers import AlphaDropout
from keras.src.legacy.layers import RandomHeight
from keras.src.legacy.layers import RandomWidth
from keras.src.legacy.layers import ThresholdedReLU
from keras.src.utils.jax_layer import FlaxLayer
from keras.src.utils.jax_layer import JaxLayer
from keras.src.utils.torch_utils import TorchModuleWrapper
19 changes: 13 additions & 6 deletions keras/api/_tf_keras/keras/losses/__init__.py
Expand Up @@ -4,6 +4,7 @@
since your modifications would be overwritten.
"""

from keras.src.legacy.losses import Reduction
from keras.src.losses import deserialize
from keras.src.losses import get
from keras.src.losses import serialize
Expand Down Expand Up @@ -38,12 +39,18 @@
from keras.src.losses.losses import dice
from keras.src.losses.losses import hinge
from keras.src.losses.losses import huber
from keras.src.losses.losses import kl_divergence
from keras.src.losses.losses import log_cosh
from keras.src.losses.losses import mean_absolute_error
from keras.src.losses.losses import mean_absolute_percentage_error
from keras.src.losses.losses import mean_squared_error
from keras.src.losses.losses import mean_squared_logarithmic_error
from keras.src.losses.losses import kl_divergence as KLD
from keras.src.losses.losses import kl_divergence as kld
from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence
from keras.src.losses.losses import log_cosh as logcosh
from keras.src.losses.losses import mean_absolute_error as MAE
from keras.src.losses.losses import mean_absolute_error as mae
from keras.src.losses.losses import mean_absolute_percentage_error as MAPE
from keras.src.losses.losses import mean_absolute_percentage_error as mape
from keras.src.losses.losses import mean_squared_error as MSE
from keras.src.losses.losses import mean_squared_error as mse
from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE
from keras.src.losses.losses import mean_squared_logarithmic_error as msle
from keras.src.losses.losses import poisson
from keras.src.losses.losses import sparse_categorical_crossentropy
from keras.src.losses.losses import squared_hinge
Expand Down
18 changes: 12 additions & 6 deletions keras/api/_tf_keras/keras/metrics/__init__.py
Expand Up @@ -11,12 +11,18 @@
from keras.src.losses.losses import categorical_hinge
from keras.src.losses.losses import hinge
from keras.src.losses.losses import huber
from keras.src.losses.losses import kl_divergence
from keras.src.losses.losses import log_cosh
from keras.src.losses.losses import mean_absolute_error
from keras.src.losses.losses import mean_absolute_percentage_error
from keras.src.losses.losses import mean_squared_error
from keras.src.losses.losses import mean_squared_logarithmic_error
from keras.src.losses.losses import kl_divergence as KLD
from keras.src.losses.losses import kl_divergence as kld
from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence
from keras.src.losses.losses import log_cosh as logcosh
from keras.src.losses.losses import mean_absolute_error as MAE
from keras.src.losses.losses import mean_absolute_error as mae
from keras.src.losses.losses import mean_absolute_percentage_error as MAPE
from keras.src.losses.losses import mean_absolute_percentage_error as mape
from keras.src.losses.losses import mean_squared_error as MSE
from keras.src.losses.losses import mean_squared_error as mse
from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE
from keras.src.losses.losses import mean_squared_logarithmic_error as msle
from keras.src.losses.losses import poisson
from keras.src.losses.losses import sparse_categorical_crossentropy
from keras.src.losses.losses import squared_hinge
Expand Down
6 changes: 6 additions & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Expand Up @@ -56,6 +56,7 @@
from keras.src.ops.nn import categorical_crossentropy
from keras.src.ops.nn import conv
from keras.src.ops.nn import conv_transpose
from keras.src.ops.nn import ctc_decode
from keras.src.ops.nn import ctc_loss
from keras.src.ops.nn import depthwise_conv
from keras.src.ops.nn import elu
Expand All @@ -71,6 +72,7 @@
from keras.src.ops.nn import multi_hot
from keras.src.ops.nn import normalize
from keras.src.ops.nn import one_hot
from keras.src.ops.nn import psnr
from keras.src.ops.nn import relu
from keras.src.ops.nn import relu6
from keras.src.ops.nn import selu
Expand Down Expand Up @@ -100,6 +102,7 @@
from keras.src.ops.numpy import arctanh
from keras.src.ops.numpy import argmax
from keras.src.ops.numpy import argmin
from keras.src.ops.numpy import argpartition
from keras.src.ops.numpy import argsort
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
Expand Down Expand Up @@ -190,10 +193,12 @@
from keras.src.ops.numpy import reshape
from keras.src.ops.numpy import roll
from keras.src.ops.numpy import round
from keras.src.ops.numpy import select
from keras.src.ops.numpy import sign
from keras.src.ops.numpy import sin
from keras.src.ops.numpy import sinh
from keras.src.ops.numpy import size
from keras.src.ops.numpy import slogdet
from keras.src.ops.numpy import sort
from keras.src.ops.numpy import split
from keras.src.ops.numpy import sqrt
Expand All @@ -218,6 +223,7 @@
from keras.src.ops.numpy import true_divide
from keras.src.ops.numpy import var
from keras.src.ops.numpy import vdot
from keras.src.ops.numpy import vectorize
from keras.src.ops.numpy import vstack
from keras.src.ops.numpy import where
from keras.src.ops.numpy import zeros
Expand Down

0 comments on commit 5a3542b

Please sign in to comment.