
Commit 5aaabf6
clean the codebase
ftshijt committed Jan 23, 2024
1 parent 3b2a4ed commit 5aaabf6
Showing 2 changed files with 4 additions and 25 deletions.
fairseq/models/multires_hubert/multires_hubert.py (0 additions, 25 deletions)

@@ -334,7 +334,6 @@ def __init__(
                 == len(cfg.label_rate_ratios) // 2
             ), "number of override encoder layers must match the label rate ratios information"
             self.len_encoder_modules = len(self.override_encoder_layers)
-            logger.info(self.override_encoder_layers)
         else:
             self.override_encoder_layers = None
             self.len_encoder_modules = None
@@ -477,7 +476,6 @@ def __init__(
         self.label_rates.append(self.base_rate)
 
         for label_rate_ratio in self.label_rate_ratios:
-            logger.info("label_Rate_ratio: {}".format(label_rate_ratio))
             upsample_rate, downsample_rate = label_rate_ratio
             if (base_ds_rate * upsample_rate) % downsample_rate != 0:
                 logger.warning(
@@ -554,7 +552,6 @@ def __init__(
 
         # Note(jiatong): different from hubert, we just set the final dim as encoder_embed_dim
         final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
-        # final_dim = cfg.encoder_embed_dim
 
         self.mask_emb = nn.Parameter(
             torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
@@ -584,10 +581,6 @@ def __init__(
         self.label_embs_concat = nn.ParameterList()
 
         for i in range(self.predictor_head_num):
-            # if dictionaries[i] is None:
-            #     self.label_embs_concat.append(None)
-            #     continue
-            # TODO(jiatong): for skipping case
             if self.use_single_target:
                 num_classes = len(dictionaries[0])
             else:
@@ -691,11 +684,9 @@ def forward_targets(
         targ_tsz = target.size(1)
         if feat2tar_ratio * feat_tsz > targ_tsz:
             feat_tsz = int(targ_tsz / feat2tar_ratio)
-            # logger.info("feat_tsz: {}, features: {}, target:{}".format(feat_tsz, features.size(), target.size()))
             features = features[:, :feat_tsz]
         target_inds = torch.arange(feat_tsz).float() * feat2tar_ratio
         target = target[:, target_inds.long()]
-        # logger.info("finalized_target: {}, feat2tar_ratio: {}, target_inds: {}".format(target.size(), feat2tar_ratio,target_inds.long()))
         return features, target
 
     def forward_padding_mask(
@@ -774,7 +765,6 @@ def align_size_sum(feat1, pad1, feat2):
             x, padding_mask, mask_indices = self.downsample_modules[i](
                 x, padding=padding_mask, mask_indices=mask_indices
             )
-            # logger.info("index: {}, x_input: {}, residuals: {}".format(i, x.size(), residuals[i].size()))
 
         residual = self.middle_encoder(x, padding_mask=padding_mask, layer=None)[0]
         x = x + residual
@@ -790,7 +780,6 @@ def align_size_sum(feat1, pad1, feat2):
                     self.label_nums - 2 - i
                 ](x, padding=padding_mask, mask_indices=mask_indices)
             x, _ = self.decoders[i](x, padding_mask=padding_mask, layer=None)
-            # logger.info("index: {}, x_input: {}, residuals: {}".format(i, x.size(), residuals[i].size()))
             x, padding_mask = align_size_sum(x, padding_mask, residuals[i])
             res_outputs.append(x)
             padding_masks.append(padding_mask)
@@ -803,7 +792,6 @@ def align_size_sum(feat1, pad1, feat2):
         if target_list is not None:
             new_target_list = []
             for i in range(self.label_nums):
-                # logger.info("i: {}, res_output: {}, target_list: {}".format(i, res_outputs[i].size(), target_list[0].size()))
                 if self.use_single_target:
                     res_outputs[i], reformat_target_list = self.forward_targets(
                         res_outputs[i], target_list[0], self.feat2tar_ratios[i]
@@ -827,7 +815,6 @@ def align_size_sum(feat1, pad1, feat2):
                     res_outputs[i], multi_mask_indices[i]
                 )
 
-        # logger.info("-" * 100)
 
         if features_only:
             # NOTE(jiatong): need to reverse back
@@ -850,13 +837,6 @@ def compute_pred(proj_x, target, label_embs):
             # negs: (Neg, S, D)
             return self.compute_nce(proj_x, y, negs)
 
-        multires_record = {
-            "logit_m_list": [],
-            "logit_u_list": [],
-            "padding_mask": [],
-            "features_pen": [],
-        }
-
         logit_m_list, logit_u_list = [], []
         for j in range(self.label_nums):
             if new_target_list[j] is None:
@@ -890,11 +870,6 @@ def compute_pred(proj_x, target, label_embs):
             else:
                 logit_u_list.append(None)
 
-            multires_record["logit_m_list"].append(logit_m_list)
-            multires_record["logit_u_list"].append(logit_u_list)
-            multires_record["padding_mask"].append(padding_mask)
-            multires_record["features_pen"].append(features_pen)
-
             # if we only want one prediction, we can exit now
             if self.predictor_head_num == 1:
                 break
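
A note on the `forward_targets` hunk above: the deleted debug logs traced how features are trimmed and labels subsampled so that both sequences end up frame-aligned at each resolution. A minimal standalone sketch of that alignment step, following the code visible in the diff (the free-function form, tensor shapes, and example values are illustrative assumptions, not from the repository):

    import torch

    def align_features_to_targets(features, target, feat2tar_ratio):
        # features: (B, T, D) frame features; target: (B, T') label ids.
        feat_tsz = features.size(1)
        targ_tsz = target.size(1)
        # Trim features so that feat2tar_ratio * feat_tsz never overruns the labels.
        if feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / feat2tar_ratio)
            features = features[:, :feat_tsz]
        # Pick one label per remaining feature frame at the given rate ratio.
        target_inds = torch.arange(feat_tsz).float() * feat2tar_ratio
        target = target[:, target_inds.long()]
        return features, target

    # Example: 50 feature frames against 100 labels (ratio 2.0) -> 50 aligned pairs.
    feats, labels = align_features_to_targets(
        torch.randn(2, 50, 768), torch.randint(0, 500, (2, 100)), 2.0
    )
    assert feats.size(1) == labels.size(1) == 50
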
fairseq/tasks/multires_hubert_pretraining.py (4 additions, 0 deletions)

@@ -103,6 +103,10 @@ class MultiresHubertPretrainingConfig(FairseqDataclass):
 
 @register_task("multires_hubert_pretraining", dataclass=MultiresHubertPretrainingConfig)
 class MultiresHubertPretrainingTask(FairseqTask):
+    """
+    Multiresolution HuBERT Pretraining Task.
+    The task is based on `HubertPretrainingTask` but extended to multiresolution.
+    """
 
     cfg: MultiresHubertPretrainingConfig
 
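
For readers unfamiliar with the pattern in the second file: fairseq tasks are registered by name through the `register_task` decorator, which is what makes `multires_hubert_pretraining` selectable as a task in training configs. A minimal sketch of that registration pattern (the `my_pretraining` task name and its single option are hypothetical placeholders):

    from dataclasses import dataclass, field

    from fairseq.dataclass import FairseqDataclass
    from fairseq.tasks import FairseqTask, register_task

    @dataclass
    class MyPretrainingConfig(FairseqDataclass):
        # Hypothetical option; real configs declare data paths, label dirs, rates, etc.
        data: str = field(default="", metadata={"help": "path to training data"})

    @register_task("my_pretraining", dataclass=MyPretrainingConfig)
    class MyPretrainingTask(FairseqTask):
        # Registered under "my_pretraining"; a real task would also implement
        # setup_task() and load_dataset() on top of FairseqTask.
        cfg: MyPretrainingConfig
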
