
Sourcery refactored main branch #1

Open · sourcery-ai[bot] wants to merge 1 commit into main from sourcery/main

Conversation

sourcery-ai[bot] commented Dec 19, 2022

Branch main refactored by Sourcery.

If you're happy with these changes, merge this Pull Request using the Squash and merge strategy.

See our documentation here.

Run Sourcery locally

Shorten the feedback loop during development by using the Sourcery editor plugin.

Review changes via command line

To manually merge these changes, make sure you're on the main branch, then run:

```sh
# Fetch the refactored branch and fast-forward onto it
git fetch origin sourcery/main
git merge --ff-only FETCH_HEAD
# Drop the fetched commit from the branch but keep its changes unstaged for review
git reset HEAD^
```

Help us improve this pull request!

Comment on lines +32 to -34

```diff
+        *excludes,
+    ]
-    cmd.extend(excludes)
```

Function run_apidoc refactored with the following changes:
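
For illustration, a minimal standalone sketch of this pattern (the command and paths are hypothetical, not taken from the project):

```python
excludes = ["setup.py", "conftest.py"]

# Before: build the list, then extend it in a second statement.
cmd = ["sphinx-apidoc", "-o", "doc/api", "src"]
cmd.extend(excludes)

# After: iterable unpacking folds the extend into the list literal.
cmd = ["sphinx-apidoc", "-o", "doc/api", "src", *excludes]
```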

Comment on lines -112 to +111

```diff
-htmlhelp_basename = "{}doc".format(project)
+htmlhelp_basename = f"{project}doc"
```

Line 112 refactored with the following changes:

Comment on lines -32 to +36

```diff
-        model = MyMNISTMLP(num_hidden=100)
-    else:
-        # If a model state is passed, we reload the model using PyTorch's load_state_dict.
-        # In this case, model hyperparameters are restored from the saved state.
-        state_dict = torch.load(str(model_state_url))
-        model = MyMNISTMLP.from_state_dict(state_dict)
-    return model
+        return MyMNISTMLP(num_hidden=100)
+    # If a model state is passed, we reload the model using PyTorch's load_state_dict.
+    # In this case, model hyperparameters are restored from the saved state.
+    state_dict = torch.load(str(model_state_url))
+    return MyMNISTMLP.from_state_dict(state_dict)
```

Function model_fn refactored with the following changes:
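
The underlying pattern is the early-return (guard clause) style: once a branch returns, the `else` and the trailing return variable are redundant. A toy sketch:

```python
def describe(x: int) -> str:
    if x < 0:
        return "negative"
    # No else needed: this point is only reached when x >= 0.
    return "non-negative"
```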

Comment on lines -19 to +21

```diff
-        model = ResNet18CIFAR()
-    else:
-        state_dict = torch.load(str(model_state_url))
-        model = ResNet18CIFAR.from_state_dict(state_dict)
-    return model
+        return ResNet18CIFAR()
+    state_dict = torch.load(str(model_state_url))
+    return ResNet18CIFAR.from_state_dict(state_dict)
```

Function model_fn refactored with the following changes:

Comment on lines -40 to -45

```diff
-    class_incremental_scenario = ClassIncrementalScenario(
+    return ClassIncrementalScenario(
         data_module=data_module,
         class_groupings=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
         chunk_id=chunk_id,
     )
-    return class_incremental_scenario
```

Function data_module_fn refactored with the following changes:

Comment on lines -228 to +225

```diff
-        assert self._dataset_name in ["clear10", "clear100"]
+        assert self._dataset_name in {"clear10", "clear100"}
```

Function CLEARDataModule.__init__ refactored with the following changes:
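
A set literal is the idiomatic container for membership tests: it states intent, and lookups are O(1) on average. For literals of constants, CPython's compiler also folds the set into a cached frozenset, so it is not rebuilt on every call. Standalone sketch:

```python
dataset_name = "clear10"
assert dataset_name in {"clear10", "clear100"}  # compiled to a frozenset constant
```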

Comment on lines -84 to +90

```diff
-        learner_args = learner_args + ["loss_weight_new_data", "memory_size", "memory_batch_size"]
+        learner_args += ["loss_weight_new_data", "memory_size", "memory_batch_size"]
         updater_class = OfflineExperienceReplayModelUpdater
     elif args.updater == "RD":
-        learner_args = learner_args + ["memory_size"]
+        learner_args += ["memory_size"]
         updater_class = RepeatedDistillationModelUpdater
     elif args.updater == "GDumb":
-        learner_args = learner_args + ["memory_size"]
+        learner_args += ["memory_size"]
```

Function get_updater_and_learner_kwargs refactored with the following changes:

  • Replace assignment with augmented assignment [×3] (aug-assign)
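
One subtlety worth noting (not raised by Sourcery here): for lists, `x += y` mutates `x` in place, while `x = x + y` rebinds `x` to a new list. The two differ if the list is aliased elsewhere, as this sketch with hypothetical names shows:

```python
base = ["learning_rate"]
learner_args = base
learner_args += ["memory_size"]        # in-place: base is now ["learning_rate", "memory_size"]

base = ["learning_rate"]
learner_args = base + ["memory_size"]  # rebinding: base is unchanged
```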

Comment on lines -103 to +110

```diff
-    updater: Optional[str] = None
-    for i, arg in enumerate(sys.argv):
-        if arg == "--updater" and len(sys.argv) > i:
-            updater = sys.argv[i + 1]
-            break
+    updater: Optional[str] = next(
+        (
+            sys.argv[i + 1]
+            for i, arg in enumerate(sys.argv)
+            if arg == "--updater" and len(sys.argv) > i
+        ),
+        None,
+    )
```

Function parse_hyperparameters refactored with the following changes:

  • Use the built-in function next instead of a for-loop (use-next)
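
The `next()` call takes a generator plus a default, returning the first match or `None`. A hedged observation on both versions: since `i` indexes `sys.argv`, the guard `len(sys.argv) > i` is always true; `len(sys.argv) > i + 1` was presumably intended, to protect `sys.argv[i + 1]` when `--updater` is the last argument. A standalone sketch with that fix (the helper name is illustrative):

```python
import sys
from typing import Optional

def flag_value(flag: str) -> Optional[str]:
    """Return the argument following `flag` on the command line, or None."""
    return next(
        (
            sys.argv[i + 1]
            for i, arg in enumerate(sys.argv)
            if arg == flag and len(sys.argv) > i + 1  # ensure a value actually follows
        ),
        None,
    )

updater = flag_value("--updater")
```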

Comment on lines -485 to +494

```diff
-    transforms = {}
-    for transform_fn_name in transform_fn_names:
-        if transform_fn_name in vars(config_module):
-            transforms[transform_fn_name] = getattr(config_module, transform_fn_name)(
-                **get_transform_args(args)
-            )
-    return transforms
+    return {
+        transform_fn_name: getattr(config_module, transform_fn_name)(
+            **get_transform_args(args)
+        )
+        for transform_fn_name in transform_fn_names
+        if transform_fn_name in vars(config_module)
+    }
```

Function get_transforms_kwargs refactored with the following changes:
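
The refactoring is the standard build-and-filter dict comprehension. A self-contained sketch with illustrative names:

```python
available = {"train_transform": lambda: "crop", "buffer_transform": lambda: "flip"}
wanted = ["train_transform", "test_transform", "buffer_transform"]

# Keep only the names that exist, calling each factory once.
transforms = {name: available[name]() for name in wanted if name in available}
assert transforms == {"train_transform": "crop", "buffer_transform": "flip"}
```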

Comment on lines -172 to +180

```diff
-        if not folder_downloaded_from_s3:
-            if local_dir != current_state_folder:
-                shutil.rmtree(current_state_folder, ignore_errors=True)
-                if local_dir is not None:
-                    shutil.copytree(
-                        local_dir,
-                        current_state_folder,
-                        ignore=shutil.ignore_patterns("*.sagemaker-uploading"),
-                        dirs_exist_ok=True,
-                    )
+        if not folder_downloaded_from_s3 and local_dir != current_state_folder:
+            shutil.rmtree(current_state_folder, ignore_errors=True)
+            if local_dir is not None:
+                shutil.copytree(
+                    local_dir,
+                    current_state_folder,
+                    ignore=shutil.ignore_patterns("*.sagemaker-uploading"),
+                    dirs_exist_ok=True,
+                )
```

Function ModelUpdaterCLI._copy_state_to_working_directory refactored with the following changes:

Comment on lines -238 to +239

```diff
-            early_stopping_enabled=bool(args.early_stopping),
+            early_stopping_enabled=args.early_stopping,
             **learner_kwargs,
-            **get_transforms_kwargs(config_module, args),
+            **get_transforms_kwargs(config_module, args)
```

Function ModelUpdaterCLI.run refactored with the following changes:

```diff
-            max_accuracy_ki = max([results[k][i] for k in range(j)])
+            max_accuracy_ki = max(results[k][i] for k in range(j))
```

Function forgetting refactored with the following changes:
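
Passing a bare generator to `max()`, and likewise to the `sum()` and `any()` calls in the hunks below, avoids materializing an intermediate list; `any()` can additionally short-circuit on the first truthy element. Sketch:

```python
values = range(1_000_000)

total = sum(v * v for v in values)   # streams elements, no temporary list
found = any(v > 10 for v in values)  # stops at the first match
best = max(v % 7 for v in values)    # same streaming behavior for max
```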

Comment on lines -81 to +84

```diff
-        sum([results["accuracy"][task_id][i] - results["accuracy"][i][i] for i in range(task_id)])
+        sum(
+            results["accuracy"][task_id][i] - results["accuracy"][i][i]
+            for i in range(task_id)
+        )
```

Function backward_transfer refactored with the following changes:

Comment on lines -105 to +114

```diff
-    return sum(
-        [
+    return (
+        sum(
             results["accuracy"][i - 1][i] - results["accuracy_init"][0][i]
             for i in range(1, task_id + 1)
-        ]
-    ) / (task_id)
+        )
+        / task_id
+    )
```

Function forward_transfer refactored with the following changes:

Comment on lines -314 to +315

```diff
-            [len(self._class_to_index_map[key]) for key in self._class_to_index_map.keys()]
+            len(self._class_to_index_map[key])
+            for key in self._class_to_index_map.keys()
```

Function GreedyClassBalancingBuffer._update refactored with the following changes:

```diff
-        self._distillation_type = state_dict.pop(prefix + "distillation_type")
+        self._distillation_type = state_dict.pop(f"{prefix}distillation_type")
```

Function WeightedPooledOutputDistillationLossComponent._load_from_state_dict refactored with the following changes:

```diff
-    s3_client.delete_object(Bucket=bucket, Key=str(object_name))
+    s3_client.delete_object(Bucket=bucket, Key=object_name)
```

Function delete_file_from_s3 refactored with the following changes:

```diff
-        uri = "/tmp" + uri[7:]
+        uri = f"/tmp{uri[7:]}"
```

Function redirect_to_tmp refactored with the following changes:

Comment on lines -106 to +107

```diff
-        [
-            isinstance(hyperparameter_instance, Domain)
-            for hyperparameter_instance in config_space.values()
-        ]
+        isinstance(hyperparameter_instance, Domain)
+        for hyperparameter_instance in config_space.values()
```

Function is_syne_tune_config_space refactored with the following changes:

Comment on lines -97 to +104

```diff
-        self._test_data = []
-        for i in range(self.num_chunks):
-            self._test_data.append(
-                DummyDataset(
-                    torch.split(self.X_test, 100 // self.num_chunks)[i],
-                    torch.split(self.y_test, 100 // self.num_chunks)[i],
-                    self._transform,
-                )
-            )
+        self._test_data = [
+            DummyDataset(
+                torch.split(self.X_test, 100 // self.num_chunks)[i],
+                torch.split(self.y_test, 100 // self.num_chunks)[i],
+                self._transform,
+            )
+            for i in range(self.num_chunks)
+        ]
```

Function DummyTorchVisionDataModuleWithChunks.setup refactored with the following changes:
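
A possible follow-up tweak, not part of this PR: `torch.split` is recomputed for every index in the comprehension above; splitting once and zipping the chunks avoids the repeated work. Standalone sketch with toy tensors:

```python
import torch

X = torch.arange(200).reshape(100, 2)
y = torch.arange(100)
num_chunks = 5

# Split each tensor once, then pair the chunks.
X_chunks = torch.split(X, 100 // num_chunks)
y_chunks = torch.split(y, 100 // num_chunks)
test_data = [(x_chunk, y_chunk) for x_chunk, y_chunk in zip(X_chunks, y_chunks)]
assert len(test_data) == num_chunks
```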

```diff
-@pytest.mark.parametrize(
-    "model_name,expected_model_class",
-    [(model_name, model_class) for model_name, model_class in zip(models.keys(), models.values())],
-)
+@pytest.mark.parametrize("model_name,expected_model_class", list(zip(models.keys(), models.values())))
```

Function test_model_fn refactored with the following changes:
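
An aside on this one: `zip(d.keys(), d.values())` yields exactly the same pairs as `d.items()`, so the parametrization could be written even more directly as `list(models.items())`:

```python
d = {"mlp": "MyMNISTMLP", "resnet": "ResNet18CIFAR"}
assert list(zip(d.keys(), d.values())) == list(d.items())
```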

Comment on lines -62 to +71

```diff
-        assert len(train_data) == sum([train_data_class_counts[c] for c in class_groupings[i]])
-        assert len(val_data) == sum([val_data_class_counts[c] for c in class_groupings[i]])
+        assert len(train_data) == sum(
+            train_data_class_counts[c] for c in class_groupings[i]
+        )
+        assert len(val_data) == sum(
+            val_data_class_counts[c] for c in class_groupings[i]
+        )
     for j, test_data in enumerate(scenario.test_data()):
-        assert len(test_data) == sum([test_data_class_counts[c] for c in class_groupings[j]])
+        assert len(test_data) == sum(
+            test_data_class_counts[c] for c in class_groupings[j]
+        )
```

Function test_class_incremental_scenario refactored with the following changes:

Comment on lines -134 to +140

```diff
-    for j, test_data in enumerate(scenario.test_data()):
+    for test_data in scenario.test_data():
```

Function test_permutation_scenario refactored with the following changes:

Comment on lines -27 to +33

```diff
-@pytest.mark.parametrize(
-    "task_id,result",
-    [
-        [0, 0.0],
-        [1, (1 / 1) * (0.9362000226974487 - 0.8284000158937)],
-        [
+@pytest.mark.parametrize("task_id,result", [[0, 0.0], [1, 1 * (0.9362000226974487 - 0.8284000158937)], [
     2,
     (1 / 2)
     * sum(
         [0.9362000226974487 - 0.4377000033855438, 0.9506999850273132 - 0.48260000348091125]
     ),
-        ],
-    ],
-)
+]])
```

Function test_forgetting refactored with the following changes:

Comment on lines -46 to +45

```diff
-@pytest.mark.parametrize(
-    "task_id,result",
-    [
-        [0, 0.0],
-        [1, (1 / 1) * (0.8284000158309937 - 0.9362000226974487)],
-        [
+@pytest.mark.parametrize("task_id,result", [[0, 0.0], [1, 1 * (0.8284000158309937 - 0.9362000226974487)], [
     2,
     (1 / 2)
     * sum(
         [0.4377000033855438 - 0.9362000226974487, 0.48260000348091125 - 0.9506999850273132]
     ),
-        ],
-    ],
-)
+]])
```

Function test_backward_transfer refactored with the following changes:

Comment on lines -189 to +190

```diff
-            if not isinstance(getattr(buffer, "_" + key), list):
-                assert getattr(buffer, "_" + key) == state_dict[key]
+            if not isinstance(getattr(buffer, f"_{key}"), list):
+                assert getattr(buffer, f"_{key}") == state_dict[key]
```

Function test_buffer_load_state_dict refactored with the following changes:

```diff
-    state_url = None
-    if use_dir:
-        state_url = tmpdir
+    state_url = tmpdir if use_dir else None
```

Function test_load_tuning_history_when_no_previous_history_exists refactored with the following changes:

```diff
-        learner_kwargs.update(LEARNER_HYPERPARAMETER_UPDATES[learner_class])
+        learner_kwargs |= LEARNER_HYPERPARAMETER_UPDATES[learner_class]
```

Function test_update_hyperparameters refactored with the following changes:
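
Worth checking against the project's minimum supported Python: the in-place dict union operator `|=` was added in Python 3.9 (PEP 584), while `dict.update()` works on all versions and behaves the same here. Sketch:

```python
defaults = {"lr": 0.1, "epochs": 10}
overrides = {"lr": 0.01}

defaults |= overrides  # Python 3.9+; equivalent to defaults.update(overrides)
assert defaults == {"lr": 0.01, "epochs": 10}
```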

Comment on lines -83 to +92

```diff
-            [
-                str(w.message).startswith("Early stopping is enabled but no metric is specified")
-                for w in warning_init
-            ]
+            str(w.message).startswith(
+                "Early stopping is enabled but no metric is specified"
+            )
+            for w in warning_init
         )
         is_warning_early_stopping_without_val_set_sent = any(
-            [
-                str(w.message).startswith(
-                    "Early stopping is currently not supported without a validation set"
-                )
-                for w in warning_update
-            ]
+            str(w.message).startswith(
+                "Early stopping is currently not supported without a validation set"
+            )
+            for w in warning_update
```

Function test_model_updater_with_early_stopping refactored with the following changes:

```diff
-    for i in range(3):
+    for _ in range(3):
```

Function test_er_validation_buffer refactored with the following changes:

sourcery-ai[bot] commented Dec 19, 2022

Sourcery Code Quality Report

✅  Merging this PR will increase code quality in the affected files by 0.19%.

| Quality metrics | Before | After | Change |
| --- | --- | --- | --- |
| Complexity | 3.27 ⭐ | 3.15 ⭐ | -0.12 👍 |
| Method Length | 80.37 🙂 | 79.99 🙂 | -0.38 👍 |
| Working memory | 8.76 🙂 | 8.75 🙂 | -0.01 👍 |
| Quality | 69.22% 🙂 | 69.41% 🙂 | 0.19% 👍 |

| Other metrics | Before | After | Change |
| --- | --- | --- | --- |
| Lines | 6183 | 6176 | -7 |
| Changed files | Quality Before | Quality After | Quality Change |
| --- | --- | --- | --- |
| doc/conf.py | 75.97% ⭐ | 76.28% ⭐ | 0.31% 👍 |
| examples/getting_started/renate_config.py | 86.97% ⭐ | 88.11% ⭐ | 1.14% 👍 |
| examples/simple_classifier_cifar10/renate_config.py | 90.10% ⭐ | 93.35% ⭐ | 3.25% 👍 |
| examples/train_mlp_locally/renate_config.py | 86.52% ⭐ | 90.54% ⭐ | 4.02% 👍 |
| src/renate/benchmark/experiment_config.py | 68.90% 🙂 | 69.40% 🙂 | 0.50% 👍 |
| src/renate/benchmark/experimentation.py | 49.00% 😞 | 48.99% 😞 | -0.01% 👎 |
| src/renate/benchmark/scenarios.py | 82.92% ⭐ | 83.02% ⭐ | 0.10% 👍 |
| src/renate/benchmark/datasets/vision_datasets.py | 64.45% 🙂 | 65.67% 🙂 | 1.22% 👍 |
| src/renate/cli/parsing_functions.py | 71.66% 🙂 | 71.59% 🙂 | -0.07% 👎 |
| src/renate/cli/run_training.py | 63.98% 🙂 | 64.01% 🙂 | 0.03% 👍 |
| src/renate/evaluation/metrics/classification.py | 84.28% ⭐ | 84.28% ⭐ | 0.00% |
| src/renate/memory/buffer.py | 69.75% 🙂 | 69.75% 🙂 | 0.00% |
| src/renate/models/layers/cn.py | 82.55% ⭐ | 82.55% ⭐ | 0.00% |
| src/renate/tuning/tuning.py | 48.64% 😞 | 48.88% 😞 | 0.24% 👍 |
| src/renate/updaters/learner.py | 74.69% 🙂 | 74.70% 🙂 | 0.01% 👍 |
| src/renate/updaters/experimental/er.py | 66.85% 🙂 | 67.21% 🙂 | 0.36% 👍 |
| src/renate/updaters/learner_components/losses.py | 76.55% ⭐ | 76.54% ⭐ | -0.01% 👎 |
| src/renate/utils/file.py | 75.17% ⭐ | 75.18% ⭐ | 0.01% 👍 |
| src/renate/utils/syne_tune.py | 81.68% ⭐ | 81.65% ⭐ | -0.03% 👎 |
| test/datasets.py | 84.57% ⭐ | 85.10% ⭐ | 0.53% 👍 |
| test/renate/benchmark/test_experimentation_config.py | 83.82% ⭐ | 83.89% ⭐ | 0.07% 👍 |
| test/renate/benchmark/test_scenarios.py | 62.11% 🙂 | 62.18% 🙂 | 0.07% 👍 |
| test/renate/evaluation/metrics/test_classification.py | 83.13% ⭐ | 83.37% ⭐ | 0.24% 👍 |
| test/renate/memory/test_buffer.py | 67.48% 🙂 | 67.45% 🙂 | -0.03% 👎 |
| test/renate/tuning/test_tuning.py | 71.56% 🙂 | 71.20% 🙂 | -0.36% 👎 |
| test/renate/updaters/test_learner.py | 84.28% ⭐ | 84.40% ⭐ | 0.12% 👍 |
| test/renate/updaters/test_model_updater.py | 57.44% 🙂 | 57.44% 🙂 | 0.00% |
| test/renate/updaters/experimental/test_er.py | 57.58% 🙂 | 58.10% 🙂 | 0.52% 👍 |

Here are some functions in these files that still need a tune-up:

| File | Function | Complexity | Length | Working Memory | Quality | Recommendation |
| --- | --- | --- | --- | --- | --- | --- |
| src/renate/tuning/tuning.py | _execute_tuning_job_locally | 9 🙂 | 387 ⛔ | 24 ⛔ | 28.29% 😞 | Try splitting into smaller methods. Extract out complex expressions |
| src/renate/benchmark/experimentation.py | _execute_experiment_job_locally | 4 ⭐ | 539 ⛔ | 32 ⛔ | 30.32% 😞 | Try splitting into smaller methods. Extract out complex expressions |
| src/renate/tuning/tuning.py | execute_tuning_job | 1 ⭐ | 313 ⛔ | 46 ⛔ | 35.47% 😞 | Try splitting into smaller methods. Extract out complex expressions |
| src/renate/updaters/experimental/er.py | BaseExperienceReplayLearner.training_step | 13 🙂 | 235 ⛔ | 15 😞 | 35.84% 😞 | Try splitting into smaller methods. Extract out complex expressions |
| src/renate/updaters/experimental/er.py | SuperExperienceReplayModelUpdater.__init__ | 0 ⭐ | 311 ⛔ | 75 ⛔ | 36.47% 😞 | Try splitting into smaller methods. Extract out complex expressions |

Legend and Explanation

The emojis denote the absolute quality of the code:

  • ⭐ excellent
  • 🙂 good
  • 😞 poor
  • ⛔ very poor

The 👍 and 👎 indicate whether the quality has improved or gotten worse with this pull request.


Please see our documentation here for details on how these metrics are calculated.

We are actively working on this report - lots more documentation and extra metrics to come!

Help us improve this quality report!
