Skip to content

Commit

Permalink
Small update_with_oracle_cut_size fixes
Browse files Browse the repository at this point in the history
Fix an off-by-one in `TransitionModel.forward`, where we always did one
move more than the maximum number of moves.

This explosed another issue: when creating cut states, we skipped states
where the (maximum number of) moves from that state only applied
transitions that did not modify the buffer.

Replace uses of `random.uniform` by `random.randrange`.
  • Loading branch information
danieldk committed Feb 21, 2023
1 parent e27c60a commit 10f5e94
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion spacy/ml/tb_framework.pyx
Expand Up @@ -338,9 +338,9 @@ def _forward_fallback(
all_ids.append(ids)
all_statevecs.append(statevecs)
all_which.append(which)
n_moves += 1
if n_moves >= max_moves >= 1:
break
n_moves += 1

def backprop_parser(d_states_d_scores):
ids = ops.xp.vstack(all_ids)
Expand Down
9 changes: 4 additions & 5 deletions spacy/pipeline/transition_parser.pyx
Expand Up @@ -258,7 +258,7 @@ class Parser(TrainablePipe):
# batch uniform length. Since we do not have a gold standard
# sequence, we use the teacher's predictions as the gold
# standard.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2)
states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves)
else:
states = self.moves.init_batch(student_docs)
Expand Down Expand Up @@ -425,7 +425,7 @@ class Parser(TrainablePipe):
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2)
init_states, gold_states, _ = self._init_gold_batch(
examples,
max_length=max_moves
Expand Down Expand Up @@ -729,9 +729,8 @@ class Parser(TrainablePipe):
action.do(state.c, action.label)
if state.is_final():
break
if moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state)
golds.append(gold)
states.append(start_state)
golds.append(gold)
if state.is_final():
break
return states, golds, max_length
Expand Down

0 comments on commit 10f5e94

Please sign in to comment.