Sourcery refactored main branch #1
base: main
Conversation
+        *excludes,
     ]
-    cmd.extend(excludes)

Function run_apidoc refactored with the following changes:
- Merge extend into list declaration (merge-list-extend)
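The merge-list-extend pattern in isolation, as a minimal sketch: the `cmd` and `excludes` names echo the diff, but the values are invented for illustration.

```python
excludes = ["tests", "docs"]

# Before: build the list, then append more items in a second statement.
cmd = ["sphinx-apidoc", "--force"]
cmd.extend(excludes)

# After: unpack the extra items directly inside the list literal.
cmd_merged = ["sphinx-apidoc", "--force", *excludes]

assert cmd == cmd_merged
```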
-htmlhelp_basename = "{}doc".format(project)
+htmlhelp_basename = f"{project}doc"

Line 112 refactored with the following changes:
- Replace call to format with f-string (use-fstring-for-formatting)
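Both forms produce the same string; the f-string just inlines the variable. A quick sketch (the `project` value here is invented; the real one comes from the Sphinx conf.py):

```python
project = "renate"  # illustrative value

old_style = "{}doc".format(project)
new_style = f"{project}doc"

assert old_style == new_style == "renatedoc"
```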
-        model = MyMNISTMLP(num_hidden=100)
-    else:
-        # If a model state is passed, we reload the model using PyTorch's load_state_dict.
-        # In this case, model hyperparameters are restored from the saved state.
-        state_dict = torch.load(str(model_state_url))
-        model = MyMNISTMLP.from_state_dict(state_dict)
-    return model
+        return MyMNISTMLP(num_hidden=100)
+    # If a model state is passed, we reload the model using PyTorch's load_state_dict.
+    # In this case, model hyperparameters are restored from the saved state.
+    state_dict = torch.load(str(model_state_url))
+    return MyMNISTMLP.from_state_dict(state_dict)

Function model_fn refactored with the following changes:
- Lift return into if (lift-return-into-if)
- Remove unnecessary else after guard condition (remove-unnecessary-else)
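The shape of this refactoring, stripped of the PyTorch details — a sketch with invented names and return values, not Renate's API:

```python
def load_or_create(state=None):
    # Before: assign `model` in each branch, then return once at the end.
    #     if state is None:
    #         model = make_default()
    #     else:
    #         model = restore(state)
    #     return model
    # After: return directly from the guard; the else branch becomes unnecessary.
    if state is None:
        return {"kind": "default"}
    return {"kind": "restored", "state": state}
```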
-        model = ResNet18CIFAR()
-    else:
-        state_dict = torch.load(str(model_state_url))
-        model = ResNet18CIFAR.from_state_dict(state_dict)
-    return model
+        return ResNet18CIFAR()
+    state_dict = torch.load(str(model_state_url))
+    return ResNet18CIFAR.from_state_dict(state_dict)

Function model_fn refactored with the following changes:
- Lift return into if (lift-return-into-if)
- Remove unnecessary else after guard condition (remove-unnecessary-else)
-    class_incremental_scenario = ClassIncrementalScenario(
+    return ClassIncrementalScenario(
         data_module=data_module,
         class_groupings=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
         chunk_id=chunk_id,
     )
-    return class_incremental_scenario

Function data_module_fn refactored with the following changes:
- Inline variable that is immediately returned (inline-immediately-returned-variable)
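Inlining an immediately returned variable in miniature — the dictionary stands in for the scenario object, purely for illustration:

```python
def scenario_before():
    # Before: bind to a temporary name, then return it immediately.
    scenario = {"class_groupings": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]}
    return scenario

def scenario_after():
    # After: return the expression directly; the temporary adds nothing.
    return {"class_groupings": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]}

assert scenario_before() == scenario_after()
```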
-assert self._dataset_name in ["clear10", "clear100"]
+assert self._dataset_name in {"clear10", "clear100"}

Function CLEARDataModule.__init__ refactored with the following changes:
- Use set when checking membership of a collection of literals (collection-into-set)
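Both checks are equivalent for a two-element literal; the set form signals intent (a pure membership test) and lets CPython fold the literal into a constant frozenset:

```python
dataset_name = "clear10"  # illustrative value

assert dataset_name in ["clear10", "clear100"]
assert dataset_name in {"clear10", "clear100"}
assert "clear9" not in {"clear10", "clear100"}
```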
-        learner_args = learner_args + ["loss_weight_new_data", "memory_size", "memory_batch_size"]
+        learner_args += ["loss_weight_new_data", "memory_size", "memory_batch_size"]
         updater_class = OfflineExperienceReplayModelUpdater
     elif args.updater == "RD":
-        learner_args = learner_args + ["memory_size"]
+        learner_args += ["memory_size"]
         updater_class = RepeatedDistillationModelUpdater
     elif args.updater == "GDumb":
-        learner_args = learner_args + ["memory_size"]
+        learner_args += ["memory_size"]

Function get_updater_and_learner_kwargs refactored with the following changes:
- Replace assignment with augmented assignment [×3] (aug-assign)
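One subtlety worth knowing: `x = x + [...]` rebinds the name to a new list, while `x += [...]` extends the existing list in place. For a local name like `learner_args` the end value is the same, but any aliases observe the difference:

```python
a = [1]
alias_a = a
a = a + [2]          # new list; alias_a still points at the old one
assert alias_a == [1]

b = [1]
alias_b = b
b += [2]             # in-place extend; alias_b sees the appended item
assert alias_b == [1, 2]

assert a == b == [1, 2]
```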
-    updater: Optional[str] = None
-    for i, arg in enumerate(sys.argv):
-        if arg == "--updater" and len(sys.argv) > i:
-            updater = sys.argv[i + 1]
-            break
+    updater: Optional[str] = next(
+        (
+            sys.argv[i + 1]
+            for i, arg in enumerate(sys.argv)
+            if arg == "--updater" and len(sys.argv) > i
+        ),
+        None,
+    )

Function parse_hyperparameters refactored with the following changes:
- Use the built-in function next instead of a for-loop (use-next)
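A self-contained sketch of the `next()` form, with a stand-in for `sys.argv`. Note that in both the loop and the generator versions of the diff, the guard `len(sys.argv) > i` is always true inside `enumerate(sys.argv)`; a real bounds check would be `i + 1 < len(sys.argv)`, which this sketch uses instead:

```python
argv = ["train.py", "--epochs", "5", "--updater", "ER"]  # stand-in for sys.argv

# next() returns the first generated item, or the default (None) if nothing matches.
updater = next(
    (argv[i + 1] for i, arg in enumerate(argv)
     if arg == "--updater" and i + 1 < len(argv)),
    None,
)
missing = next(
    (argv[i + 1] for i, arg in enumerate(argv)
     if arg == "--seed" and i + 1 < len(argv)),
    None,
)

assert updater == "ER"
assert missing is None
```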
-    transforms = {}
-    for transform_fn_name in transform_fn_names:
-        if transform_fn_name in vars(config_module):
-            transforms[transform_fn_name] = getattr(config_module, transform_fn_name)(
-                **get_transform_args(args)
-            )
-    return transforms
+    return {
+        transform_fn_name: getattr(config_module, transform_fn_name)(
+            **get_transform_args(args)
+        )
+        for transform_fn_name in transform_fn_names
+        if transform_fn_name in vars(config_module)
+    }

Function get_transforms_kwargs refactored with the following changes:
- Convert for loop into dictionary comprehension (dict-comprehension)
- Inline variable that is immediately returned (inline-immediately-returned-variable)
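The comprehension keeps only the names the module actually defines. A runnable sketch, with a fake config class standing in for the user's config module (its names and behavior are invented):

```python
class FakeConfigModule:
    """Stand-in for a config module; names and behavior are illustrative."""
    train_transform = staticmethod(lambda size: f"train-{size}")
    test_transform = staticmethod(lambda size: f"test-{size}")

transform_fn_names = ["train_transform", "val_transform", "test_transform"]

# Build the dict in one expression; undefined names are filtered out.
transforms = {
    name: getattr(FakeConfigModule, name)(32)
    for name in transform_fn_names
    if name in vars(FakeConfigModule)
}

assert transforms == {"train_transform": "train-32", "test_transform": "test-32"}
```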
-        if not folder_downloaded_from_s3:
-            if local_dir != current_state_folder:
-                shutil.rmtree(current_state_folder, ignore_errors=True)
-                if local_dir is not None:
-                    shutil.copytree(
-                        local_dir,
-                        current_state_folder,
-                        ignore=shutil.ignore_patterns("*.sagemaker-uploading"),
-                        dirs_exist_ok=True,
-                    )
+        if not folder_downloaded_from_s3 and local_dir != current_state_folder:
+            shutil.rmtree(current_state_folder, ignore_errors=True)
+            if local_dir is not None:
+                shutil.copytree(
+                    local_dir,
+                    current_state_folder,
+                    ignore=shutil.ignore_patterns("*.sagemaker-uploading"),
+                    dirs_exist_ok=True,
+                )

Function ModelUpdaterCLI._copy_state_to_working_directory refactored with the following changes:
- Merge nested if conditions (merge-nested-ifs)
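The merged guard in isolation, without the file-system work — argument names mirror the diff but the function itself is an illustrative stand-in:

```python
def needs_reset(folder_downloaded_from_s3, local_dir, current_state_folder):
    # Before:
    #     if not folder_downloaded_from_s3:
    #         if local_dir != current_state_folder:
    #             ...
    # After: one combined condition with `and`.
    return not folder_downloaded_from_s3 and local_dir != current_state_folder

assert needs_reset(False, "/a", "/b") is True
assert needs_reset(True, "/a", "/b") is False
assert needs_reset(False, "/a", "/a") is False
```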
-            early_stopping_enabled=bool(args.early_stopping),
+            early_stopping_enabled=args.early_stopping,
             **learner_kwargs,
-            **get_transforms_kwargs(config_module, args),
+            **get_transforms_kwargs(config_module, args)

Function ModelUpdaterCLI.run refactored with the following changes:
- Remove unnecessary casts to int, str, float or bool (remove-unnecessary-cast)
-            max_accuracy_ki = max([results[k][i] for k in range(j)])
+            max_accuracy_ki = max(results[k][i] for k in range(j))

Function forgetting refactored with the following changes:
- Replace unneeded comprehension with generator (comprehension-to-generator)
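This pattern recurs throughout the PR: when a comprehension is consumed immediately by `max`, `sum`, `any`, or `all`, the surrounding brackets build a throwaway list. A generator feeds the reducer lazily. A sketch with made-up accuracy values:

```python
results = [[0.9, 0.8], [0.7, 0.95]]  # invented accuracy grid

# Before: builds a temporary list, then reduces it.
max_list = max([row[0] for row in results])
# After: the generator expression streams values into max().
max_gen = max(row[0] for row in results)

assert max_list == max_gen == 0.9
```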
-        sum([results["accuracy"][task_id][i] - results["accuracy"][i][i] for i in range(task_id)])
+        sum(
+            results["accuracy"][task_id][i] - results["accuracy"][i][i]
+            for i in range(task_id)
+        )

Function backward_transfer refactored with the following changes:
- Replace unneeded comprehension with generator (comprehension-to-generator)
-        return sum(
-            [
-                results["accuracy"][i - 1][i] - results["accuracy_init"][0][i]
-                for i in range(1, task_id + 1)
-            ]
-        ) / (task_id)
+        return (
+            sum(
+                results["accuracy"][i - 1][i] - results["accuracy_init"][0][i]
+                for i in range(1, task_id + 1)
+            )
+            / task_id
+        )

Function forward_transfer refactored with the following changes:
- Replace unneeded comprehension with generator (comprehension-to-generator)
-                [len(self._class_to_index_map[key]) for key in self._class_to_index_map.keys()]
+                len(self._class_to_index_map[key])
+                for key in self._class_to_index_map.keys()

Function GreedyClassBalancingBuffer._update refactored with the following changes:
- Replace unneeded comprehension with generator (comprehension-to-generator)
-        self._distillation_type = state_dict.pop(prefix + "distillation_type")
+        self._distillation_type = state_dict.pop(f"{prefix}distillation_type")

Function WeightedPooledOutputDistillationLossComponent._load_from_state_dict refactored with the following changes:
- Use f-string instead of string concatenation (use-fstring-for-concatenation)
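Concatenation and the f-string build identical keys; the f-string reads as one unit. A sketch with an invented state-dict prefix:

```python
prefix = "layer1."  # illustrative prefix

concatenated = prefix + "distillation_type"
formatted = f"{prefix}distillation_type"

assert concatenated == formatted == "layer1.distillation_type"
```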
-    s3_client.delete_object(Bucket=bucket, Key=str(object_name))
+    s3_client.delete_object(Bucket=bucket, Key=object_name)

Function delete_file_from_s3 refactored with the following changes:
- Remove unnecessary casts to int, str, float or bool (remove-unnecessary-cast)
-        uri = "/tmp" + uri[7:]
+        uri = f"/tmp{uri[7:]}"

Function redirect_to_tmp refactored with the following changes:
- Use f-string instead of string concatenation (use-fstring-for-concatenation)
-            [
-                isinstance(hyperparameter_instance, Domain)
-                for hyperparameter_instance in config_space.values()
-            ]
+            isinstance(hyperparameter_instance, Domain)
+            for hyperparameter_instance in config_space.values()

Function is_syne_tune_config_space refactored with the following changes:
- Replace unneeded comprehension with generator (comprehension-to-generator)
-        self._test_data = []
-        for i in range(self.num_chunks):
-            self._test_data.append(
-                DummyDataset(
-                    torch.split(self.X_test, 100 // self.num_chunks)[i],
-                    torch.split(self.y_test, 100 // self.num_chunks)[i],
-                    self._transform,
-                )
-            )
+        self._test_data = [
+            DummyDataset(
+                torch.split(self.X_test, 100 // self.num_chunks)[i],
+                torch.split(self.y_test, 100 // self.num_chunks)[i],
+                self._transform,
+            )
+            for i in range(self.num_chunks)
+        ]

Function DummyTorchVisionDataModuleWithChunks.setup refactored with the following changes:
- Convert for loop into list comprehension (list-comprehension)
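Loop-and-append versus list comprehension, reduced to plain lists (the chunk values are invented):

```python
# Before: create an empty list and append inside the loop.
chunks_loop = []
for i in range(3):
    chunks_loop.append(list(range(i, i + 2)))

# After: the loop collapses into a single comprehension.
chunks_comp = [list(range(i, i + 2)) for i in range(3)]

assert chunks_loop == chunks_comp == [[0, 1], [1, 2], [2, 3]]
```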
-@pytest.mark.parametrize(
-    "model_name,expected_model_class",
-    [(model_name, model_class) for model_name, model_class in zip(models.keys(), models.values())],
-)
+@pytest.mark.parametrize("model_name,expected_model_class", list(zip(models.keys(), models.values())))

Function test_model_fn refactored with the following changes:
- Replace identity comprehension with call to collection constructor (identity-comprehension)
-        assert len(train_data) == sum([train_data_class_counts[c] for c in class_groupings[i]])
-        assert len(val_data) == sum([val_data_class_counts[c] for c in class_groupings[i]])
+        assert len(train_data) == sum(
+            train_data_class_counts[c] for c in class_groupings[i]
+        )
+        assert len(val_data) == sum(
+            val_data_class_counts[c] for c in class_groupings[i]
+        )
     for j, test_data in enumerate(scenario.test_data()):
-        assert len(test_data) == sum([test_data_class_counts[c] for c in class_groupings[j]])
+        assert len(test_data) == sum(
+            test_data_class_counts[c] for c in class_groupings[j]
+        )

Function test_class_incremental_scenario refactored with the following changes:
- Replace unneeded comprehension with generator [×3] (comprehension-to-generator)
-    for j, test_data in enumerate(scenario.test_data()):
+    for test_data in scenario.test_data():

Function test_permutation_scenario refactored with the following changes:
- Remove unnecessary calls to enumerate when the index is not used (remove-unused-enumerate)
-@pytest.mark.parametrize(
-    "task_id,result",
-    [
-        [0, 0.0],
-        [1, (1 / 1) * (0.9362000226974487 - 0.8284000158937)],
-        [
+@pytest.mark.parametrize("task_id,result", [[0, 0.0], [1, 1 * (0.9362000226974487 - 0.8284000158937)], [
     2,
     (1 / 2)
     * sum(
         [0.9362000226974487 - 0.4377000033855438, 0.9506999850273132 - 0.48260000348091125]
     ),
-    ],
-    ],
-)
+]])

Function test_forgetting refactored with the following changes:
- Simplify binary operation (bin-op-identity)
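The bin-op-identity simplification just drops a multiplicative identity: `(1 / 1) * x` is `1 * x`, which is `x`. With the accuracy difference from the parametrize list:

```python
x = 0.9362000226974487 - 0.8284000158937

# Multiplying by 1 (or by 1 / 1) is the identity, so the factor can be dropped.
assert (1 / 1) * x == 1 * x == x
```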
-@pytest.mark.parametrize(
-    "task_id,result",
-    [
-        [0, 0.0],
-        [1, (1 / 1) * (0.8284000158309937 - 0.9362000226974487)],
-        [
+@pytest.mark.parametrize("task_id,result", [[0, 0.0], [1, 1 * (0.8284000158309937 - 0.9362000226974487)], [
     2,
     (1 / 2)
     * sum(
         [0.4377000033855438 - 0.9362000226974487, 0.48260000348091125 - 0.9506999850273132]
     ),
-    ],
-    ],
-)
+]])

Function test_backward_transfer refactored with the following changes:
- Simplify binary operation (bin-op-identity)
-        if not isinstance(getattr(buffer, "_" + key), list):
-            assert getattr(buffer, "_" + key) == state_dict[key]
+        if not isinstance(getattr(buffer, f"_{key}"), list):
+            assert getattr(buffer, f"_{key}") == state_dict[key]

Function test_buffer_load_state_dict refactored with the following changes:
- Use f-string instead of string concatenation [×2] (use-fstring-for-concatenation)
-    state_url = None
-    if use_dir:
-        state_url = tmpdir
+    state_url = tmpdir if use_dir else None

Function test_load_tuning_history_when_no_previous_history_exists refactored with the following changes:
- Move setting of default value for variable into else branch (introduce-default-else)
- Replace if statement with if expression (assign-if-exp)
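The assign-if-exp pattern as a one-line conditional expression; names mirror the test, but the function and values are illustrative:

```python
def pick_state_url(use_dir, tmpdir="/tmp/run"):  # invented wrapper for illustration
    # Before:
    #     state_url = None
    #     if use_dir:
    #         state_url = tmpdir
    # After: a single conditional expression.
    return tmpdir if use_dir else None

assert pick_state_url(True) == "/tmp/run"
assert pick_state_url(False) is None
```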
-        learner_kwargs.update(LEARNER_HYPERPARAMETER_UPDATES[learner_class])
+        learner_kwargs |= LEARNER_HYPERPARAMETER_UPDATES[learner_class]

Function test_update_hyperparameters refactored with the following changes:
- Merge dictionary updates via the union operator (dict-assign-update-to-union)
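`dict.update()` and the in-place union operator `|=` (Python 3.9+, PEP 584) perform the same merge, with right-hand keys winning on conflict — worth noting that `|=` raises a `TypeError` on older interpreters. A sketch with invented hyperparameters:

```python
learner_kwargs = {"lr": 0.1, "momentum": 0.9}
updates = {"lr": 0.05, "memory_size": 500}

# The classic in-place merge, on a copy for comparison.
via_update = dict(learner_kwargs)
via_update.update(updates)

# The union-operator form (Python 3.9+); right-hand keys win.
learner_kwargs |= updates

assert learner_kwargs == via_update == {"lr": 0.05, "momentum": 0.9, "memory_size": 500}
```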
-        [
-            str(w.message).startswith("Early stopping is enabled but no metric is specified")
-            for w in warning_init
-        ]
+        str(w.message).startswith(
+            "Early stopping is enabled but no metric is specified"
+        )
+        for w in warning_init
     )
     is_warning_early_stopping_without_val_set_sent = any(
-        [
-            str(w.message).startswith(
-                "Early stopping is currently not supported without a validation set"
-            )
-            for w in warning_update
-        ]
+        str(w.message).startswith(
+            "Early stopping is currently not supported without a validation set"
+        )
+        for w in warning_update
     )

Function test_model_updater_with_early_stopping refactored with the following changes:
- Replace unneeded comprehension with generator [×2] (comprehension-to-generator)
-    for i in range(3):
+    for _ in range(3):

Function test_er_validation_buffer refactored with the following changes:
- Replace unused for index with underscore (for-index-underscore)
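Renaming an unused loop variable to `_` is purely a readability convention: it tells the reader (and linters) the index is intentionally ignored. A trivial sketch:

```python
updates_run = []
for _ in range(3):  # underscore signals the loop variable is intentionally unused
    updates_run.append("update")

assert updates_run == ["update", "update", "update"]
```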
Sourcery Code Quality Report

✅ Merging this PR will increase code quality in the affected files by 0.19%.

Here are some functions in these files that still need a tune-up:

Legend and explanation: the emojis denote the absolute quality of the code, and the 👍 and 👎 indicate whether the quality has improved or gotten worse with this pull request. Please see our documentation for details on how these metrics are calculated. We are actively working on this report; more documentation and extra metrics are to come. Help us improve this quality report!
Branch main refactored by Sourcery. If you're happy with these changes, merge this pull request using the Squash and merge strategy. See our documentation here.

Run Sourcery locally
Reduce the feedback loop during development by using the Sourcery editor plugin.

Review changes via command line
To manually merge these changes, make sure you're on the main branch, then run:

Help us improve this pull request!