[WIP] Field-aware factorization machines #604

Open · wants to merge 26 commits into master

Conversation

@mdymczyk commented May 11, 2018

Initial implementation of field-aware factorization machines.

Based on these 2 whitepapers:

And the following repositories:

Currently this contains only the initial GPU implementation, since the CPU version will most probably just be a copy of the original implementation (without the SSE alignments for now).

No benchmarks so far, as something is still wrong (I'm getting different results).

Things still to be done:

  • add a validation set option and early stopping (FFM seems to need this a lot, as it tends to overfit)
  • add multi-GPU support
  • review the data structures used - an object-oriented approach with a Dataset/Row/Node hierarchy is good for development, but it might add a lot of overhead when copying data to the device; refactoring this into 3 (or more) contiguous arrays might provide a lot of speedup (see the sketch after this list)
  • review the main method wTx (in trainer.cu) - it can probably be rewritten in a more GPU-friendly manner
  • probably something else I'm forgetting
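
For illustration, here is a minimal sketch of the contiguous layout mentioned above (the FlatDataset/rowStart names are hypothetical, not from this PR): a CSR-like struct of arrays that could be copied to the device with a handful of cudaMemcpy calls instead of walking a Dataset/Row/Node hierarchy.

#include <vector>

// Hypothetical sketch only - the field/feature/value triples of all rows are
// stored back to back, and rowStart marks where each row's nodes begin and end.
template <typename T>
struct FlatDataset {
  std::vector<int> fields;    // field index of every node, length = total node count
  std::vector<int> features;  // feature index of every node
  std::vector<T>   values;    // value of every node
  std::vector<int> rowStart;  // row r owns nodes [rowStart[r], rowStart[r + 1])
  std::vector<T>   labels;    // one label per row
};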

If anyone wants to take it for a spin:

>>> from h2o4gpu.solvers.ffm import FFMH2O
>>> import numpy as np
>>> X = [ [(1, 2, 1), (2, 3, 1), (3, 5, 1)],
...      [(1, 0, 1), (2, 3, 1), (3, 7, 1)],
...      [(1, 1, 1), (2, 3, 1), (3, 7, 1), (3, 9, 1)] ]
>>>
>>> y = [1, 1, 0]
>>> ffmh2o = FFMH2O(n_gpus=1)
>>> ffmh2o.fit(X,y)
<h2o4gpu.solvers.ffm.FFMH2O object at 0x7f2d30319fd0>
>>> ffmh2o.predict(X)
array([0.7611223 , 0.6475924 , 0.88890105], dtype=float32)

The input format is a list of lists containing fieldIdx:featureIdx:value tuples and a corresponding list of labels (0 or 1) for each row.
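
For context (this summarizes the model from the papers referenced above, not text from this PR description), each (fieldIdx, featureIdx, value) node contributes pairwise interaction terms, with a separate latent vector per (feature, field) combination:

$$\phi_{\mathrm{FFM}}(\mathbf{w}, \mathbf{x}) = \sum_{j_1 < j_2} \langle \mathbf{w}_{j_1, f_{j_2}}, \mathbf{w}_{j_2, f_{j_1}} \rangle\, x_{j_1} x_{j_2}$$

where $f_j$ is the field of feature $j$. The predicted probability is $\sigma(\phi_{\mathrm{FFM}})$, and training minimizes the logistic loss $\log\left(1 + e^{-y\,\phi_{\mathrm{FFM}}}\right)$ with $y \in \{-1, +1\}$, as in the libffm papers.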

@mdymczyk force-pushed the build/centos-rewrite branch 3 times, most recently from a25d455 to 50dc67b on May 31, 2018 04:26

if(update) {
expnyts[rowIdx % MAX_BLOCK_THREADS] = std::exp(-labels[rowIdx] * losses[rowIdx]);
kappas[rowIdx % MAX_BLOCK_THREADS] = -labels[rowIdx] * expnyts[rowIdx % MAX_BLOCK_THREADS] / (1 + expnyts[rowIdx % MAX_BLOCK_THREADS]);

@henrygouk commented:

Is this line correct? There is a slightly different equation in the paper, but this does match the C code in libffm.

@mdymczyk (author) replied:

@henrygouk true, and I'm not sure - I need to double check. For now I went with the original C implementation, but I need to experiment.
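
For reference, the expression in the diff does match the gradient of the logistic loss used by libffm (this is just the math, not something settled in this thread). Assuming losses[rowIdx] holds the raw model output t for the row and the label y is encoded as -1 or +1:

$$\kappa = \frac{\partial}{\partial t}\log\left(1 + e^{-y t}\right) = \frac{-y\, e^{-y t}}{1 + e^{-y t}} = \frac{-y}{1 + e^{y t}}$$

which is exactly -labels[rowIdx] * expnyt / (1 + expnyt) with expnyt = exp(-labels[rowIdx] * losses[rowIdx]).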

const T w1gdup = (weightsPtr + idx1)[d+1] + g1 * g1;
const T w2gdup = (weightsPtr + idx2)[d+1] + g2 * g2;

(weightsPtr + idx1)[d] -= cLearningRate[0] / std::sqrt(w1gdup) * g1;

@henrygouk commented:

This is the AdaGrad update rule. Could be useful to look at using other update rules in future. Something like AMSGrad could potentially converge faster.

@henrygouk also commented:

Also, this technique is using HogWild! optimisation, which does not guarantee repeatability across different runs. Is this a problem?

@mdymczyk (author) replied:

@henrygouk yes to both :-)

  1. definitely, different updaters would be great; abstracting this part and encapsulating it in an Updater class would be best (see the sketch below). I can make an issue with potential candidates.

  2. for the initial release I think it should be ok, but in the long run I'm pretty sure non-deterministic results will be a no-go, especially for our own use in DAI. Any suggestions on how this could be fixed? Serializing it would incur a huge performance hit. Stochastic gradient descent? Not sure how viable that method is. Anything else?
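
A minimal sketch of what such an Updater abstraction could look like (everything here - the AdaGradUpdater name and the functor interface - is hypothetical, not code from this PR):

#include <cmath>

// Hypothetical sketch: each update rule becomes a small device functor, and the
// wTx kernel is templated on the updater type instead of hard-coding AdaGrad.
template <typename T>
struct AdaGradUpdater {
  __device__ void operator()(T *w, T *gAccum, T grad, T learningRate) const {
    *gAccum += grad * grad;                          // accumulate squared gradients
    *w -= learningRate / std::sqrt(*gAccum) * grad;  // AdaGrad step
  }
};

// In the current kernel this would roughly replace the hand-written update, e.g.
// updater(&(weightsPtr + idx1)[d], &(weightsPtr + idx1)[d + 1], g1, cLearningRate[0]);
// Other rules (AMSGrad, plain SGD, ...) would just provide the same operator().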


const T v = vals[threadIdx.x] * (threadIdx.x + i < MAX_BLOCK_THREADS ? vals[threadIdx.x + i] : values[n1 + i]) * scales[rowIdx % MAX_BLOCK_THREADS];

if (update) {

@henrygouk commented:

In future, it might be a bit cleaner to separate the training/prediction code into different kernels.


__syncthreads();

T loss = 0.0;

@henrygouk commented:

Is it accurate to call this a loss? I think this is more like the prediction, but I may be misinterpreting things.

@mdymczyk (author) replied:

In the original I think they call it t, not sure why - I can rename it. From my understanding it is used as the loss and also to calculate predictions.


T loss = 0.0;

for(int i = 1; n1 + i < rowSizes[rowIdx + 1] - cBatchOffset[0]; i++) {

@henrygouk commented:

How difficult would it be to parallelise this loop and add a reduction step afterwards? I'm guessing nontrivial, due to the different row sizes, but this is the main way I can think of to get some more parallelism out of this.

@mdymczyk (author) replied:

Not sure. Instead of spinning up a thread for each field:feature:value tuple (let's call it a node) and running this loop for each consecutive node in that row, we could spin up a thread for each node pair. I didn't try that approach yet as I was afraid of thread overhead.

Currently the main slowdown I'm noticing is because of:

  1. the number of registers used by the wTx kernel, which limits the number of concurrent blocks being run

  2. reads/writes from global memory when using the weights array, since that access isn't very coalesced within blocks. An interesting experiment is to move lines 157/158:

(weightsPtr + idx1)[d+1] = w1gdup;
(weightsPtr + idx2)[d+1] = w2gdup;

Right after:

const T w1gdup = (weightsPtr + idx1)[d+1] + g1 * g1;
const T w2gdup = (weightsPtr + idx2)[d+1] + g2 * g2;

There's a visible (~10%) slowdown on relatively large data (400k rows, 39 nodes in each row). I'm assuming this is due to how CUDA loads/stores data.

Basically, the major slowdown is coming from the if (update) { branch. Putting both weights and gradients into the same array helped quite a bit, but even on a 1080Ti this is only ~2-3x faster than the CPU implementation.
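
For what it's worth, here is a rough sketch of the reduction step suggested above (hypothetical, not code from this PR; it assumes one block handles one row and blockDim.x is a power of two). Each thread would first accumulate the pairwise terms for the node pairs it owns into a private partial sum, and the block would then combine the partials in shared memory:

// Hypothetical sketch: tree reduction of per-thread partial sums for one row.
template <typename T>
__device__ T reduceRowPartials(T partial, T *shared /* blockDim.x elements */) {
  shared[threadIdx.x] = partial;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      shared[threadIdx.x] += shared[threadIdx.x + stride];
    }
    __syncthreads();
  }
  return shared[0];  // the row's reduced sum, same value for every thread in the block
}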

…older. Missing deps in runtime Docker runtime file.
…emporarily remove OMP. Add log_loss printout.
* squash weights and gradients into single array for memory reads
* utilize shared memory for fields/features/values/scales as much as possible
* compute kappa once per node instead of rowSize times
@mdymczyk changed the base branch from build/centos-rewrite to master on June 1, 2018 04:25
@mdymczyk (author) commented:

So both the CPU and GPU implementations are there and working; the only issue left is that GPU batch mode gives slightly different results with the same number of iterations (or converges in a much larger number of iterations) compared to GPU batch mode with batch_size=1 and to the CPU mode. I'm guessing this is because we are using HOGWILD! and the order of computations during the gradient update differs (and might not be 100% correct?).

@mdymczyk (author) commented:

One more thing: this needs to be compared against bigger data (libffm_toy.zip) and the original C++ implementation (https://github.com/guestwalk/libffm - not the Python API). I think the GPU version was getting slightly different results, so it needs double checking before merging.
