
Policy network penalized incorrectly on invalid moves #77

Open
Timeroot opened this issue Aug 13, 2018 · 15 comments

@Timeroot

Currently, the policy network is actively trained to output 0 on a move whenever it's invalid. The move is never taken, so the target pi is zero there, and that zero enters the loss function. As a result, the policy network ends up learning as much about "how often is this move legal" as about "is this a good move" -- which is almost certainly hurting performance.

The correct action would be to manually mask out the result in NNet.py, right after getting the output from OthelloNNet.py: OthelloNNet.py returns F.log_softmax(pis), and the entries for moves that aren't valid should then be manually set to zero. This will prevent gradients from propagating back through them.

@Timeroot
Author

Err, a correction: it should be masked with valid moves, and then re-normalized. Otherwise it's not changing anything.
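
For example, a minimal sketch of that mask-and-renormalize step (not code from the repo; it assumes the net outputs log-probabilities and valids is a 0/1 vector of the same length):

```python
import torch

def mask_policy(log_pi, valids):
    # drop probability mass on invalid moves, then renormalize over the legal ones
    pi = torch.exp(log_pi) * valids
    return pi / torch.sum(pi)
```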

@Timeroot
Author

For reference: I fixed this bug in my local copy, and it appears to be learning several times faster. Passing the valid moves in requires adapting a lot of the interfaces, though -- every game implementation has to be changed to do the masking. And my local copy is a huge mess, so I don't have anything I can easily offer as a PR.

But as a reference for what I did: in Coach.py, the trainExamples need to be extended to include the valid moves, so that part looks something like

```
pi = self.best_mcts.getActionProb(canonicalBoard, temp=temp)
valids = self.game.getValidMoves(canonicalBoard, 1)
bs, ps = zip(*self.game.getSymmetries(canonicalBoard, pi))
_, valids_sym = zip(*self.game.getSymmetries(canonicalBoard, valids))
sym = zip(bs, ps, valids_sym)

for b, p, valid in sym:
    trainExamples.append([b, self.curPlayer, p, valid])

action = np.random.choice(len(pi), p=pi)
board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)

r = self.game.getGameEnded(board, self.curPlayer)

if r != 0:
    return [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer)), x[3]) for x in trainExamples]
```

whereas obviously the ideal thing would be adapting getSymmetries to also permute valids, instead of calling it twice like I do.
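
For illustration, an adapted getSymmetries could look roughly like this (a sketch only, assuming the square-board rotate/flip logic of the Othello implementation; the extra valids argument is the new part, not the repo's current interface):

```python
import numpy as np

def getSymmetries(self, board, pi, valids):
    # rotate/flip the board, the policy and the valid-move mask together;
    # assumes an n x n board plus one trailing "pass" entry in pi and valids
    assert len(pi) == self.n ** 2 + 1
    pi_board = np.reshape(pi[:-1], (self.n, self.n))
    valids_board = np.reshape(valids[:-1], (self.n, self.n))
    symmetries = []
    for i in range(1, 5):
        for flip in [True, False]:
            newB = np.rot90(board, i)
            newPi = np.rot90(pi_board, i)
            newV = np.rot90(valids_board, i)
            if flip:
                newB = np.fliplr(newB)
                newPi = np.fliplr(newPi)
                newV = np.fliplr(newV)
            symmetries.append((newB,
                               list(newPi.ravel()) + [pi[-1]],
                               list(newV.ravel()) + [valids[-1]]))
    return symmetries
```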

Then MCTS.py needs to be changed to pass the valid moves in through nnet.predict, as in

valids = self.game.getValidMoves(canonicalBoard, 1)
self.Ps[s], v = self.nnet.predict((canonicalBoard, valids))

Then NNet.py needs to have valids added as a value to unpack, in both predict and train; once prepped, it should get passed in to the net -- so changing self.nnet(boards) to self.nnet((boards, valids)).
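
A rough sketch of the adapted predict in pytorch/NNet.py might look like this (assuming the existing board_x/board_y attributes, and that the net returns log-probabilities with invalid moves already masked, e.g. via the log_softmax variant described below):

```python
def predict(self, boardAndValids):
    board, valids = boardAndValids
    board = torch.FloatTensor(board.astype(np.float64)).view(1, self.board_x, self.board_y)
    valids = torch.FloatTensor(np.asarray(valids, dtype=np.float64)).view(1, -1)
    self.nnet.eval()
    with torch.no_grad():
        pi, v = self.nnet((board, valids))
    # pi comes back as log-probabilities with invalid moves masked out
    return torch.exp(pi).data.cpu().numpy()[0], v.data.cpu().numpy()[0]
```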

Finally, OthelloNNet.py needs to be changed to use the valids as a mask. So forward will start

def forward(self, boardsAndValids):
    s, valids = boardsAndValids

and include pi = valids * pi after computing pi. Obviously NNet.py and OthelloNNet.py need to be appropriately changed in all the libraries' implementations, for all games; and if you don't want to do the ugly getSymmetries hack like above, that also means changing the convention for getSymmetries in OthelloGame.py and all corresponding game definitions.

@Marcin1960

"For reference: I fixed this bug in my local copy, and it appears to be learning several times faster ."
@Timeroot

Wow!

@Timeroot
Author

I realized that my wording above was somewhat vague. pi = valids * pi works, but then you can no longer use log_softmax, since that has the normalization inside of it. So you'll likely need to sacrifice the numerical stability of log_softmax (which is probably fine), and replace

pi = log_softmax(pi)

with

pi = torch.exp(pi)
pi = valids * pi
pi = pi / torch.sum(pi)

The alternative is:

pi -= (1-valids)*1000
pi = log_softmax(pi)

which just pushes the logits of invalid moves to values so negative that their probability is effectively zero.
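
For concreteness, here is a small self-contained sketch of a policy/value head using that second variant (layer names like fc_pi and fc_v are placeholders, not the repo's actual attributes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedPolicyHead(nn.Module):
    """Toy policy/value head illustrating the valid-move masking trick."""

    def __init__(self, hidden_size, action_size):
        super().__init__()
        self.fc_pi = nn.Linear(hidden_size, action_size)  # policy logits
        self.fc_v = nn.Linear(hidden_size, 1)              # value head

    def forward(self, featuresAndValids):
        s, valids = featuresAndValids      # s: (batch, hidden), valids: (batch, action_size), 0/1
        pi = self.fc_pi(s)
        v = self.fc_v(s)
        pi = pi - (1 - valids) * 1000      # invalid moves get effectively zero probability
        return F.log_softmax(pi, dim=1), torch.tanh(v)
```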

@keesdejong

@Timeroot
Interesting. Although I am an absolute beginner, I already had suspicions in the same direction. I am going to try to implement your ideas in my own game, which is largely based on this repo. I am curious about the result. It would be useful if you could make the code you created accessible in one way or another.

@keesdejong

@Timeroot

I'm almost there, but I have a problem implementing your idea in TensorFlow. Maybe you can help me. I'm stuck on this piece of your description:

replace
pi = log_softmax(pi)
with
pi = torch.exp(pi)
pi = valids * pi
pi = pi / torch.sum(pi)
The alternative is:
pi -= (1-valids)*1000
pi = log_softmax(pi)

In tensorflow.OthelloNet this should be around line 36:
self.pi = Dense(s_fc2, self.action_size)
self.prob = tf.nn.softmax(self.pi)
self.v = Tanh(Dense(s_fc2, 1))
self.calculate_loss()

I am new to this. Can you tell me how to do this for TensorFlow?
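
For reference, one way the -1000 masking trick might translate to this TF1-style graph (a sketch only; self.valids is an assumed extra placeholder that predict/train would have to feed alongside the board):

```python
# assumed extra placeholder for the valid-move mask (0/1 per action)
self.valids = tf.placeholder(tf.float32, shape=[None, self.action_size])

self.pi = Dense(s_fc2, self.action_size)
self.pi = self.pi - (1.0 - self.valids) * 1000.0  # mask invalid moves before softmax / loss
self.prob = tf.nn.softmax(self.pi)
self.v = Tanh(Dense(s_fc2, 1))
self.calculate_loss()
```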

@evg-tyurin
Contributor

@keesdejong
I'm looking forward to confirmation that this trick can speed up learning. I have my own implementation of the game of checkers here: https://github.com/evg-tyurin/alpha-nagibator
It's based on this repo, so I could implement the changes / make a PR to both repos if it really helps.

@aletote

aletote commented Sep 19, 2018

Any news on when this is going to be updated for the main branch?
Also, how much faster is it?

@51616

51616 commented Nov 16, 2018

I followed your fix and got a small negative pi loss (around -0.00xx). Is this normal?
Edit: now I'm using the alternative code and it produces a positive pi loss.

@jl1990

jl1990 commented Nov 20, 2018

@51616 If you made a fix, please upload a PR.

@51616

51616 commented Nov 21, 2018

@jl1990 I use this equation instead:
pi -= (1-valids)*1000
pi = log_softmax(pi)
This should produce a positive pi loss.

@suragnair suragnair pinned this issue Dec 21, 2018
@rlronan
Contributor

rlronan commented Mar 8, 2020

I made a pull with the changes introduced for Othello's Tensorflow implementation.
Can someone confirm that these changes are correct? I can add them to the remainder of the games/implementations if that is the case.
As a warning, these changes are not backwards compatible with the saved models in the repository.

Also can someone provide some guidance as to how I would properly cite Timeroot for the changes, since I essentially just implemented what he suggested verbatim?

Here's the pull:

#163

@keesdejong

@rlronan Thank you. I'll see if I find the time to check if it works. And, most importantly, whether it has any effect.

@mha-py

mha-py commented Aug 17, 2020

This thread is a bit older, but since it's pinned, here is an idea:
What about putting the information about whether a move is invalid into the p array of the replay list, with a value like -0.001 (like it's done with draws)? Then the symmetry functions don't have to be changed (since they already work with p), and it's backwards compatible: the value is close to zero, so the un-updated files for the games will still work.

@cestpasphoto

cestpasphoto commented Feb 11, 2021

Adding my opinion on this (old) thread:
I implemented the formula proposed above, pi -= (1-valids)*1000.
I also implemented this paper (check also the related github repo) with pi = torch.where(valids, pi, -1e8).

The second one gives slightly better results in my game (which has about 50% invalid moves on average) and slightly faster training (~5%), but I haven't tried it with Othello.
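
For comparison, here is a sketch of the two masking variants side by side (valids is a 0/1 float tensor for the first and a boolean tensor for the second; both return log-probabilities):

```python
import torch
import torch.nn.functional as F

def mask_by_subtraction(pi_logits, valids):
    # valids: 0/1 float tensor of shape (batch, action_size)
    return F.log_softmax(pi_logits - (1 - valids) * 1000, dim=1)

def mask_by_where(pi_logits, valids_bool):
    # valids_bool: boolean tensor of shape (batch, action_size)
    masked = torch.where(valids_bool, pi_logits, torch.full_like(pi_logits, -1e8))
    return F.log_softmax(masked, dim=1)
```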
