diff --git a/backend/app/adapter/db/url.go b/backend/app/adapter/db/url.go index b1fa14b97..978d3387c 100644 --- a/backend/app/adapter/db/url.go +++ b/backend/app/adapter/db/url.go @@ -3,6 +3,7 @@ package db import ( "database/sql" "fmt" + "strings" "github.com/short-d/short/app/adapter/db/table" "github.com/short-d/short/app/entity" @@ -96,6 +97,79 @@ WHERE "%s"=$1;`, return url, nil } +// GetByAliases finds URLs for a list of aliases +func (u URLSql) GetByAliases(aliases []string) ([]entity.URL, error) { + parameterStr := u.composeParamList(len(aliases)) + + // create a list of interface{} to hold aliases for db.Query() + aliasesInterface := []interface{}{} + for _, alias := range aliases { + aliasesInterface = append(aliasesInterface, alias) + } + + var urls []entity.URL + + // TODO: compare performance between Query and QueryRow. Prefer QueryRow for readability + statement := fmt.Sprintf(` +SELECT "%s","%s","%s","%s","%s" +FROM "%s" +WHERE "%s" IN (%s);`, + table.URL.ColumnAlias, + table.URL.ColumnOriginalURL, + table.URL.ColumnExpireAt, + table.URL.ColumnCreatedAt, + table.URL.ColumnUpdatedAt, + table.URL.TableName, + table.URL.ColumnAlias, + parameterStr, + ) + + stmt, err := u.db.Prepare(statement) + if err != nil { + return urls, err + } + defer stmt.Close() + + rows, err := stmt.Query(aliasesInterface...) + if err != nil { + return urls, nil + } + + defer rows.Close() + for rows.Next() { + url := entity.URL{} + err := rows.Scan( + &url.Alias, + &url.OriginalURL, + &url.ExpireAt, + &url.CreatedAt, + &url.UpdatedAt, + ) + if err != nil { + return urls, err + } + + url.CreatedAt = utc(url.CreatedAt) + url.UpdatedAt = utc(url.UpdatedAt) + url.ExpireAt = utc(url.ExpireAt) + + urls = append(urls, url) + } + + return urls, nil +} + +// composeParamList converts an slice to a parameters string with format: $1, $2, $3, ... +func (u URLSql) composeParamList(numParams int) string { + params := make([]string, 0, numParams) + for i := 0; i < numParams; i++ { + params = append(params, fmt.Sprintf("$%d", i+1)) + } + + parameterStr := strings.Join(params, ", ") + return parameterStr +} + // NewURLSql creates URLSql func NewURLSql(db *sql.DB) *URLSql { return &URLSql{ diff --git a/backend/app/adapter/db/url_integration_test.go b/backend/app/adapter/db/url_integration_test.go index 80886bebe..e68a565b8 100644 --- a/backend/app/adapter/db/url_integration_test.go +++ b/backend/app/adapter/db/url_integration_test.go @@ -242,6 +242,87 @@ func TestURLSql_Create(t *testing.T) { } } +func TestURLSql_GetByAliases(t *testing.T) { + twoYearsAgo := mustParseTime(t, "2017-05-01T08:02:16-07:00") + now := mustParseTime(t, "2019-05-01T08:02:16-07:00") + + testCases := []struct { + name string + tableRows []urlTableRow + aliases []string + hasErr bool + expectedURLs []entity.URL + }{ + { + name: "alias not found", + tableRows: []urlTableRow{}, + aliases: []string{"220uFicCJj"}, + hasErr: false, + }, + { + name: "found url", + tableRows: []urlTableRow{ + { + alias: "220uFicCJj", + longLink: "http://www.google.com", + createdAt: &twoYearsAgo, + expireAt: &now, + updatedAt: &now, + }, + { + alias: "yDOBcj5HIPbUAsw", + longLink: "http://www.facebook.com", + createdAt: &twoYearsAgo, + expireAt: &now, + updatedAt: &now, + }, + }, + aliases: []string{"220uFicCJj", "yDOBcj5HIPbUAsw"}, + hasErr: false, + expectedURLs: []entity.URL{ + entity.URL{ + Alias: "220uFicCJj", + OriginalURL: "http://www.google.com", + CreatedAt: &twoYearsAgo, + ExpireAt: &now, + UpdatedAt: &now, + }, + entity.URL{ + Alias: "yDOBcj5HIPbUAsw", + OriginalURL: "http://www.facebook.com", + CreatedAt: &twoYearsAgo, + ExpireAt: &now, + UpdatedAt: &now, + }, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mdtest.AccessTestDB( + dbConnector, + dbMigrationTool, + dbMigrationRoot, + dbConfig, + func(sqlDB *sql.DB) { + insertURLTableRows(t, sqlDB, testCase.tableRows) + + urlRepo := db.NewURLSql(sqlDB) + urls, err := urlRepo.GetByAliases(testCase.aliases) + + if testCase.hasErr { + mdtest.NotEqual(t, nil, err) + return + } + mdtest.Equal(t, nil, err) + mdtest.Equal(t, testCase.expectedURLs, urls) + }, + ) + }) + } +} + func insertURLTableRows(t *testing.T, sqlDB *sql.DB, tableRows []urlTableRow) { for _, tableRow := range tableRows { _, err := sqlDB.Exec( diff --git a/backend/app/usecase/repository/url.go b/backend/app/usecase/repository/url.go index 792dcd587..1d79dbfc4 100644 --- a/backend/app/usecase/repository/url.go +++ b/backend/app/usecase/repository/url.go @@ -7,4 +7,5 @@ type URL interface { IsAliasExist(alias string) (bool, error) GetByAlias(alias string) (entity.URL, error) Create(url entity.URL) error + GetByAliases(aliases []string) ([]entity.URL, error) } diff --git a/backend/app/usecase/repository/url_fake.go b/backend/app/usecase/repository/url_fake.go index ff1f6dd7c..dcd4f7e67 100644 --- a/backend/app/usecase/repository/url_fake.go +++ b/backend/app/usecase/repository/url_fake.go @@ -45,6 +45,21 @@ func (u URLFake) GetByAlias(alias string) (entity.URL, error) { return url, nil } +// GetByAliases finds all URL for a list of aliases +func (u URLFake) GetByAliases(aliases []string) ([]entity.URL, error) { + var urls []entity.URL + + for _, alias := range aliases { + url, err := u.GetByAlias(alias) + + if err != nil { + return urls, err + } + urls = append(urls, url) + } + return urls, nil +} + // NewURLFake creates in memory URL repository func NewURLFake(urls map[string]entity.URL) URLFake { return URLFake{