Skip to content

Commit

Permalink
Merge pull request #49 from trivago/couple_fixes
Browse files Browse the repository at this point in the history
Couple fixes
  • Loading branch information
wching committed Feb 7, 2020
2 parents d4bccc3 + 9074cc5 commit 9da491f
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 48 deletions.
Expand Up @@ -4,21 +4,21 @@ import com.google.gson.annotations.SerializedName
import java.io.Serializable
import java.util.Calendar

class OAuth2AccessToken(
data class OAuth2AccessToken(
/**
* REQUIRED
* The type of the token issued as described in https://tools.ietf.org/html/rfc6749#section-7.1.
* Value is case insensitive.
*/
@SerializedName("token_type")
var tokenType: String? = null,
val tokenType: String = "",

/**
* REQUIRED
* The access token issued by the authorization server.
*/
@SerializedName("access_token")
var accessToken: String? = null,
val accessToken: String = "",

/**
* OPTIONAL
Expand All @@ -27,7 +27,7 @@ class OAuth2AccessToken(
* in https://tools.ietf.org/html/rfc6749#section-6.
*/
@SerializedName("refresh_token")
var refreshToken: String? = null,
val refreshToken: String? = null,

/**
* RECOMMENDED
Expand All @@ -38,13 +38,13 @@ class OAuth2AccessToken(
* expiration time via other means or document the default value.
*/
@SerializedName("expires_in")
var expiresIn: Int? = null,
val expiresIn: Int? = null,

/**
* The expiration date used by Heimdall.
*/
@SerializedName("heimdall_expiration_date")
var expirationDate: Calendar? = null
val expirationDate: Calendar? = null
) : Serializable {

// Public API
Expand All @@ -58,19 +58,4 @@ class OAuth2AccessToken(
expirationDate != null &&
Calendar.getInstance().after(expirationDate)


override fun equals(other: Any?): Boolean =
when {
this === other -> true
other !is OAuth2AccessToken -> false
else -> {
accessToken.equals(other.accessToken) && tokenType.equals(other.accessToken)
}
}


override fun hashCode(): Int =
tokenType.hashCode().let {
31 * it + accessToken.hashCode()
}
}
Expand Up @@ -29,15 +29,22 @@ open class OAuth2AccessTokenManager(
calendar: Calendar = Calendar.getInstance()
): Single<OAuth2AccessToken> =
grant.grantNewAccessToken()
.doOnSuccess { token ->
token.expiresIn?.let {
.map {
if (it.expiresIn != null) {
val newExpirationDate = (calendar.clone() as Calendar).apply {
add(Calendar.SECOND, it)
add(Calendar.SECOND, it.expiresIn)
}
token.expirationDate = newExpirationDate
it.copy(expirationDate = newExpirationDate)
} else {
it
}
mStorage.storeAccessToken(token)
}.cache()
}
.doOnSuccess { token ->
mStorage.storeAccessToken(
token = token
)
}
.cache()

/**
* Returns an Observable emitting an unexpired access token.
Expand Down
Expand Up @@ -45,9 +45,9 @@ abstract class OAuth2AuthorizationCodeGrant(

// Constants
companion object {
@JvmStatic
@JvmField
val RESPONSE_TYPE = "code"
@JvmStatic
@JvmField
val GRANT_TYPE = "authorization_code"
private const val UTF_8 = "UTF-8"
}
Expand Down
Expand Up @@ -16,7 +16,7 @@ abstract class OAuth2ClientCredentialsGrant(
* REQUIRED
* The OAuth2 "grant_type".
*/
@JvmStatic
@JvmField
val GRANT_TYPE = "client_credentials"
}
}
Expand Up @@ -44,7 +44,7 @@ abstract class OAuth2ImplicitGrant(
* The "response_type" which MUST be "token".
*/
companion object {
@JvmStatic
@JvmField
val RESPONSE_TYPE = "token"
}
}
Expand Up @@ -21,7 +21,7 @@ abstract class OAuth2RefreshAccessTokenGrant(
* The OAuth2 "grant_type".
*/
companion object {
@JvmStatic
@JvmField
val GRANT_TYPE = "refresh_token"
}
}
Expand Up @@ -21,7 +21,7 @@ abstract class OAuth2ResourceOwnerPasswordCredentialsGrant(
) : OAuth2Grant {

companion object {
@JvmStatic
@JvmField
val GRANT_TYPE = "password"
}
}
Expand Up @@ -17,11 +17,14 @@ class OAuth2AccessTokenManagerGrantNewAccessTokenTest {
val accessToken = OAuth2AccessToken(
expirationDate = null
)
accessToken.expiresIn = 3
val changedAccessToken = accessToken.copy(
expiresIn = 3
)


// and a grant that emits that token
val grant = mock<OAuth2Grant>().apply {
whenever(grantNewAccessToken()).thenReturn(Single.just(accessToken))
whenever(grantNewAccessToken()).thenReturn(Single.just(changedAccessToken))
}

// and a tokenManager
Expand Down Expand Up @@ -49,11 +52,13 @@ class OAuth2AccessTokenManagerGrantNewAccessTokenTest {
val accessToken = OAuth2AccessToken(
expirationDate = null
)
accessToken.expiresIn = null
val changedAccessToken = accessToken.copy(
expiresIn = null
)

// and a grant that emits that token
val grant = mock<OAuth2Grant>().apply {
whenever(grantNewAccessToken()).thenReturn(Single.just(accessToken))
whenever(grantNewAccessToken()).thenReturn(Single.just(changedAccessToken))
}

// and a tokenManager
Expand Down
Expand Up @@ -28,10 +28,14 @@ class OAuth2AccessTokenSerializationTest {
)
val expirationDate = Calendar.getInstance()
expirationDate.timeInMillis = 0
accessToken.expirationDate = expirationDate

// and an updated expiration date
val changedAccessToken = accessToken.copy(
expirationDate = expirationDate
)

// when it gets serialized with Gson
val json = Gson().toJson(accessToken)
val json = Gson().toJson(changedAccessToken)

// then the json should be written correctly
assertEquals(
Expand Down
Expand Up @@ -29,7 +29,7 @@ public URL buildAuthorizationUrl() {
.buildUpon()
.appendQueryParameter("client_id", getClientId())
.appendQueryParameter("redirect_uri", getRedirectUri())
.appendQueryParameter("response_type", OAuth2AuthorizationCodeGrant.getRESPONSE_TYPE())
.appendQueryParameter("response_type", OAuth2AuthorizationCodeGrant.RESPONSE_TYPE)
.build()
.toString()
);
Expand All @@ -41,7 +41,7 @@ public URL buildAuthorizationUrl() {
@Override
public Observable<OAuth2AccessToken> exchangeTokenUsingCode(String code) {
AccessTokenRequestBody body = new AccessTokenRequestBody(
code, getClientId(), getRedirectUri(), clientSecret, getGRANT_TYPE()
code, getClientId(), getRedirectUri(), clientSecret, GRANT_TYPE
);
return TraktTvApiFactory.newApiService().grantNewAccessToken(body);
}
Expand Down
Expand Up @@ -9,7 +9,7 @@
/**
* TraktTv refresh token grant as described in http://docs.trakt.apiary.io/#reference/authentication-oauth/token/exchange-refresh_token-for-access_token.
*/
public class TraktTvRefreshAccessTokenGrant extends OAuth2RefreshAccessTokenGrant<OAuth2AccessToken> {
public class TraktTvRefreshAccessTokenGrant extends OAuth2RefreshAccessTokenGrant {

// Properties

Expand All @@ -21,7 +21,7 @@ public class TraktTvRefreshAccessTokenGrant extends OAuth2RefreshAccessTokenGran

@Override
public Single<OAuth2AccessToken> grantNewAccessToken() {
RefreshTokenRequestBody body = new RefreshTokenRequestBody(getRefreshToken(), clientId, clientSecret, redirectUri, getGRANT_TYPE());
RefreshTokenRequestBody body = new RefreshTokenRequestBody(getRefreshToken(), clientId, clientSecret, redirectUri, GRANT_TYPE);
return TraktTvApiFactory.newApiService().refreshAccessToken(body).singleOrError();
}
}
Expand Up @@ -14,7 +14,7 @@
*
* @param <TAccessToken> The access token type.
*/
public class SharedPreferencesOAuth2AccessTokenStorage<TAccessToken extends OAuth2AccessToken> implements OAuth2AccessTokenStorage<TAccessToken> {
public class SharedPreferencesOAuth2AccessTokenStorage<TAccessToken extends OAuth2AccessToken> implements OAuth2AccessTokenStorage {

// Constants

Expand Down Expand Up @@ -50,19 +50,18 @@ public SharedPreferencesOAuth2AccessTokenStorage(SharedPreferences sharedPrefere

// OAuth2AccessTokenStorage

@SuppressWarnings("unchecked")
@Override
public Single<TAccessToken> getStoredAccessToken() {
public Single<OAuth2AccessToken> getStoredAccessToken() {
return Single
.just(mSharedPreferences.getString(ACCESS_TOKEN_PREFERENCES_KEY, null))
.map(json -> (TAccessToken) new Gson().fromJson(json, mTokenClass));
}

@Override
public void storeAccessToken(TAccessToken accessToken) {
public void storeAccessToken(OAuth2AccessToken token) {
mSharedPreferences
.edit()
.putString(ACCESS_TOKEN_PREFERENCES_KEY, new Gson().toJson(accessToken))
.putString(ACCESS_TOKEN_PREFERENCES_KEY, new Gson().toJson(token))
.apply();
}

Expand Down

0 comments on commit 9da491f

Please sign in to comment.