A paginate method using Go 1.18 type parameters
The wind of change is blowing in the Go community as type parameters (also known as generics) are slowly but surely getting ready to be released.
Indeed, the Go team has been working hard to make Go 1.18 Beta 1 available to us all. It is therefore time to get our hands dirty and check out what the biggest Go release of the past 8 years has in store for us!
Installing Go 1.18 Beta 1
Let's get started by installing the beta. To do so, simply run the go install command with the following arguments:
go install golang.org/dl/go1.18beta1@latest && \
go1.18beta1 download
If everything went well, the download command prints its progress and ends with a success message, and the beta toolchain is then available as the go1.18beta1 command (provided you've previously installed Go on your machine and the directory go install writes binaries to, $GOPATH/bin by default, is on your PATH).
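To check that the beta toolchain actually accepts type parameters, you could save a tiny program like the following as main.go and run it with go1.18beta1 run main.go. The Map helper is only an illustration and is not part of the project built below.

```go
package main

import "fmt"

// Map applies f to every element of s and returns the results.
// It only compiles with a toolchain that supports type parameters (Go 1.18+).
func Map[T, U any](s []T, f func(T) U) []U {
	out := make([]U, 0, len(s))
	for _, v := range s {
		out = append(out, f(v))
	}
	return out
}

func main() {
	lengths := Map([]string{"go", "generics"}, func(s string) int { return len(s) })
	fmt.Println(lengths) // prints [2 8]
}
```

If the older, non-generic toolchain is picked up instead, the compiler rejects the [T, U any] syntax, which is a quick way to spot a PATH problem.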
What we’ll be trying out
Throughout my career as a Go developer, there have been a few times when I wished the language came with generics support.
One of the very first was when I tried to implement a generic function for fetching paginated rows from a SQL database. Basically, I wanted a private method declared on my concrete Repository implementation that every listing method could reuse.
Though it was pretty trivial to implement this "generic method", the problem I faced was scanning the rows into their corresponding structs. There was simply no way to do so directly within the method body, so I had to resort to returning the SQL rows and delegating the scanning to the caller…
This led to a lot of boilerplate code whenever I used this private method in my listing methods. Boilerplate code that could have been avoided with type parameters!
Anyway, after this somewhat lengthy and boring explanation, let's see if the upcoming Go release is up to the task of helping us tackle this common issue.
Traditional approach
Let’s get started by looking at our models. We’ll define a structure representing an HTTP request for fetching a paginated resource along with a paginated response and two dummy resources, namely a blog post and a blog category.
package models

import "time"

// PaginationRequest is a struct that represents a pagination request
type PaginationRequest struct {
	Page     int      `json:"page" query:"page"`
	PerPage  int      `json:"perPage" query:"per_page"`
	OrderBy  []string `json:"orderBy" query:"order_by"`
	OrderDir string   `json:"orderDir" query:"order_dir"`
}

// PaginationResponse is a struct that represents the response of a paginated request
type PaginationResponse struct {
	Page       int         `json:"page"`
	Items      interface{} `json:"items"`
	PerPage    int         `json:"perPage"`
	PrevPage   int         `json:"prevPage"`
	NextPage   int         `json:"nextPage"`
	TotalPage  int         `json:"totalPage"`
	TotalItems int64       `json:"totalItems"`
}

// BlogPost is a struct that represents a blog post
type BlogPost struct {
	ID        int64     `json:"id" db:"id"`
	Title     string    `json:"title" db:"title"`
	Body      string    `json:"body" db:"body"`
	CreatedAt time.Time `json:"createdAt" db:"created_at"`
	UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
}

// BlogCategory is a struct that represents a blog category
type BlogCategory struct {
	ID        int64     `json:"id" db:"id"`
	Title     string    `json:"title" db:"title"`
	Body      string    `json:"body" db:"body"`
	CreatedAt time.Time `json:"createdAt" db:"created_at"`
	UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
}
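The derived fields of PaginationResponse (TotalPage, PrevPage, NextPage) and the SQL OFFSET all follow from the request with a little arithmetic. The stdlib-only sketch below mirrors what the repository's helpers further down compute, using integer ceiling division instead of math.Ceil and the same clamping; computePageMeta and pageMeta are hypothetical names, and it assumes Page and PerPage have already been validated to be at least 1.

```go
package main

import "fmt"

// pageMeta holds the values derived from a pagination request.
type pageMeta struct {
	Offset, TotalPage, PrevPage, NextPage int
}

// computePageMeta is a hypothetical helper: TotalPage = ceil(totalItems / perPage),
// OFFSET = perPage * (page - 1), PrevPage/NextPage clamped to the valid range.
// It assumes page >= 1 and perPage >= 1 have been validated by the caller.
func computePageMeta(page, perPage int, totalItems int64) pageMeta {
	total := int((totalItems + int64(perPage) - 1) / int64(perPage)) // integer ceiling division
	prev := page
	if page >= 2 {
		prev = page - 1
	}
	next := page
	if page < total {
		next = page + 1
	}
	return pageMeta{Offset: perPage * (page - 1), TotalPage: total, PrevPage: prev, NextPage: next}
}

func main() {
	// 45 items, 10 per page, requesting page 3
	fmt.Printf("%+v\n", computePageMeta(3, 10, 45))
	// prints {Offset:20 TotalPage:5 PrevPage:2 NextPage:4}
}
```

Integer division avoids the float round trip, which also sidesteps converting an infinite float to int if PerPage ever reaches the helpers as 0.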
Next, we'll define a very basic store package with its corresponding Repository interface.
package store

import (
	"context"

	"github.com/henripqt/lab/pagination/pkg/models"
)

// Repository is an interface that defines the methods that a store must implement
type Repository interface {
	GetBlogPosts(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse, error)
	GetBlogCategories(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse, error)
	Close() error
}

// repository is the concrete implementation of the Repository interface;
// it delegates every call to the wrapped Repository
type repository struct {
	repository Repository
}

// NewRepository returns a new instance of the Repository interface
func NewRepository(r Repository) Repository {
	return &repository{
		repository: r,
	}
}

// Compile-time check that *repository implements Repository
var _ Repository = (*repository)(nil)

// GetBlogPosts returns paginated blog posts from the underlying Repository
func (r *repository) GetBlogPosts(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse, error) {
	return r.repository.GetBlogPosts(ctx, paginationReq)
}

// GetBlogCategories returns paginated blog categories from the underlying Repository
func (r *repository) GetBlogCategories(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse, error) {
	return r.repository.GetBlogCategories(ctx, paginationReq)
}

// Close closes the underlying Repository connection
func (r *repository) Close() error {
	return r.repository.Close()
}
Finally, we'll write a very basic concrete implementation of the store.Repository interface for a PostgreSQL database. You'll notice that we're using two neat packages for interacting with SQL databases in Go: sqlx along with squirrel.
package store

import (
	"context"
	"fmt"
	"log"
	"math"
	"strconv"
	"strings"

	"github.com/Masterminds/squirrel"
	"github.com/henripqt/lab/pagination/pkg/models"
	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq" // registers the "postgres" driver
	"golang.org/x/sync/errgroup"
)

// PGRepository is a PostgreSQL-backed implementation of the Repository interface
type PGRepository struct {
	sq squirrel.StatementBuilderType
	db *sqlx.DB
}

var _ Repository = (*PGRepository)(nil)

// NewPGRepository connects to PostgreSQL and returns a Repository backed by it
func NewPGRepository(userName, password, dbName string) Repository {
	db, err := sqlx.Connect("postgres", fmt.Sprintf("user=%v password=%v dbname=%v sslmode=disable", userName, password, dbName))
	if err != nil {
		log.Fatalln(err)
	}
	return &PGRepository{
		// PostgreSQL uses $1, $2, ... placeholders
		sq: squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar),
		db: db,
	}
}
// GetBlogPosts returns paginated blog posts
func (r *PGRepository) GetBlogPosts(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse, error) {
	query, queryArgs, err := r.sq.Select("*").From("blog_posts").ToSql()
	if err != nil {
		return nil, err
	}
	countQuery, countQueryArgs, err := r.sq.Select("count(*)").From("blog_posts").ToSql()
	if err != nil {
		return nil, err
	}
	rows, pRes, err := r.paginate(
		ctx,
		query,
		queryArgs,
		countQuery,
		countQueryArgs,
		paginationReq,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// The scanning has to happen here, because paginate cannot know the item type
	blogPosts := make([]models.BlogPost, 0)
	for rows.Next() {
		var blogPost models.BlogPost
		err := rows.StructScan(&blogPost)
		if err != nil {
			return nil, err
		}
		blogPosts = append(blogPosts, blogPost)
	}
	// rows.Next also returns false on failure, so check rows.Err after the loop
	if err := rows.Err(); err != nil {
		return nil, err
	}
	pRes.Items = blogPosts
	return pRes, nil
}
// GetBlogCategories returns paginated blog categories
func (r *PGRepository) GetBlogCategories(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse, error) {
	query, queryArgs, err := r.sq.Select("*").From("blog_categories").ToSql()
	if err != nil {
		return nil, err
	}
	countQuery, countQueryArgs, err := r.sq.Select("count(*)").From("blog_categories").ToSql()
	if err != nil {
		return nil, err
	}
	rows, pRes, err := r.paginate(
		ctx,
		query,
		queryArgs,
		countQuery,
		countQueryArgs,
		paginationReq,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Same scanning loop as in GetBlogPosts, only the item type differs
	blogCategories := make([]models.BlogCategory, 0)
	for rows.Next() {
		var blogCategory models.BlogCategory
		err := rows.StructScan(&blogCategory)
		if err != nil {
			return nil, err
		}
		blogCategories = append(blogCategories, blogCategory)
	}
	// rows.Next also returns false on failure, so check rows.Err after the loop
	if err := rows.Err(); err != nil {
		return nil, err
	}
	pRes.Items = blogCategories
	return pRes, nil
}
// Close allows for closing the database connection
func (r *PGRepository) Close() error {
	return r.db.Close()
}
// paginate is a helper function for fetching paginated resources.
// The caller is responsible for scanning and closing the returned rows.
func (r *PGRepository) paginate(
	ctx context.Context,
	query string,
	queryArgs []interface{},
	countQuery string,
	countQueryArgs []interface{},
	paginationReq models.PaginationRequest,
) (*sqlx.Rows, *models.PaginationResponse, error) {
	paginationRes := models.PaginationResponse{
		Page:    paginationReq.Page,
		PerPage: paginationReq.PerPage,
	}

	// Run the count query and the page query concurrently. A plain errgroup.Group is
	// used on purpose: the context derived by errgroup.WithContext is canceled when
	// Wait returns, which would also cancel the rows the caller still has to iterate.
	var g errgroup.Group

	// Retrieve the total number of items
	g.Go(func() error {
		return r.db.GetContext(
			ctx,
			&paginationRes.TotalItems,
			countQuery,
			countQueryArgs...,
		)
	})

	// Retrieve the items
	var rows *sqlx.Rows
	g.Go(func() error {
		var err error
		rows, err = r.db.QueryxContext(
			ctx,
			r.decoratePaginatedQuery(query, paginationReq),
			queryArgs...,
		)
		return err
	})

	if err := g.Wait(); err != nil {
		// The items query may have succeeded even though the count failed: release its connection
		if rows != nil {
			rows.Close()
		}
		return nil, nil, err
	}

	paginationRes.TotalPage = r.getTotalPage(int(paginationRes.TotalItems), paginationRes.PerPage)
	paginationRes.PrevPage = r.getPrevPage(paginationRes.Page)
	paginationRes.NextPage = r.getNextPage(paginationRes.Page, paginationRes.TotalPage)
	return rows, &paginationRes, nil
}
// getTotalPage is a helper function for getting the total number of pages
func (r *PGRepository) getTotalPage(totalItems, perPage int) int {
	if perPage <= 0 {
		return 0 // avoid dividing by zero when per_page is missing
	}
	return int(math.Ceil(float64(totalItems) / float64(perPage)))
}

// getPrevPage is a helper function for getting the previous page
func (r *PGRepository) getPrevPage(currentPage int) int {
	if currentPage >= 2 {
		return currentPage - 1
	}
	return currentPage
}

// getNextPage is a helper function for getting the next page
func (r *PGRepository) getNextPage(currentPage, totalPage int) int {
	if currentPage >= totalPage {
		return currentPage
	}
	return currentPage + 1
}
// decoratePaginatedQuery appends the ORDER BY, LIMIT and OFFSET clauses to query.
// Warning: OrderBy and OrderDir come straight from the request and are written into
// the SQL as-is, so they must be validated against an allow-list beforehand.
// Note that in SQL the direction only applies to the column it follows (here, the last one).
func (r *PGRepository) decoratePaginatedQuery(query string, pReq models.PaginationRequest) string {
	q := strings.Builder{}
	q.WriteString(query)
	if len(pReq.OrderBy) > 0 {
		// ORDER BY instructions
		q.WriteRune(' ')
		q.WriteString("ORDER BY")
		for i, orderBy := range pReq.OrderBy {
			if i > 0 {
				q.WriteRune(',')
			}
			q.WriteRune(' ')
			q.WriteString(orderBy)
		}
		q.WriteRune(' ')
		if len(pReq.OrderDir) == 0 {
			q.WriteString("DESC")
		} else {
			q.WriteString(pReq.OrderDir)
		}
	}
	// LIMIT instruction
	q.WriteRune(' ')
	q.WriteString("LIMIT")
	q.WriteRune(' ')
	q.WriteString(strconv.Itoa(pReq.PerPage))
	// OFFSET instruction
	q.WriteRune(' ')
	q.WriteString("OFFSET")
	q.WriteRune(' ')
	q.WriteString(strconv.Itoa(pReq.PerPage * (pReq.Page - 1)))
	return q.String()
}
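Since order_by and order_dir are read from the query string, decoratePaginatedQuery as written lets a client splice arbitrary SQL into the query. One way to close that hole, sketched below under the assumption that each resource knows which columns are sortable, is to check the columns against an allow-list and restrict the direction to ASC or DESC. orderByClause and the allowed map are hypothetical names; unlike the code above, the sketch repeats the direction after every column, because in SQL a direction only applies to the column it follows.

```go
package main

import (
	"fmt"
	"strings"
)

// orderByClause is a hypothetical, safer replacement for the ORDER BY part of
// decoratePaginatedQuery: columns must be in the allow-list and the direction
// must be ASC or DESC, so request data is never written into the SQL unchecked.
func orderByClause(orderBy []string, orderDir string, allowed map[string]bool) (string, error) {
	if len(orderBy) == 0 {
		return "", nil
	}
	dir := strings.ToUpper(orderDir)
	switch dir {
	case "":
		dir = "DESC" // same default as decoratePaginatedQuery
	case "ASC", "DESC":
	default:
		return "", fmt.Errorf("invalid order direction %q", orderDir)
	}
	parts := make([]string, 0, len(orderBy))
	for _, col := range orderBy {
		if !allowed[col] {
			return "", fmt.Errorf("cannot order by %q", col)
		}
		// Repeat the direction: in "ORDER BY a, b DESC" only b is sorted descending
		parts = append(parts, col+" "+dir)
	}
	return " ORDER BY " + strings.Join(parts, ", "), nil
}

func main() {
	allowed := map[string]bool{"title": true, "created_at": true}
	clause, err := orderByClause([]string{"created_at", "title"}, "asc", allowed)
	fmt.Printf("%q %v\n", clause, err) // prints " ORDER BY created_at ASC, title ASC" <nil>
	_, err = orderByClause([]string{"title; DROP TABLE blog_posts"}, "", allowed)
	fmt.Println(err != nil) // prints true
}
```

An alternative worth considering is squirrel's own builder methods for ordering, limiting and offsetting, though the column names would still need the same allow-list check.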
The part that we should pay attention to here is the paginate method and, more specifically, its return signature: (*sqlx.Rows, *models.PaginationResponse, error).
As you can see, we end up returning the *sqlx.Rows along with a *models.PaginationResponse.
We do so because there's simply no easy way for the paginate method to scan our sqlx.Rows into the appropriate structures, so we have to let the caller handle the scanning:
blogPosts := make([]models.BlogPost, 0)
for rows.Next() {
	var blogPost models.BlogPost
	err := rows.StructScan(&blogPost)
	if err != nil {
		return nil, err
	}
	blogPosts = append(blogPosts, blogPost)
}
if err := rows.Err(); err != nil {
	return nil, err
}

blogCategories := make([]models.BlogCategory, 0)
for rows.Next() {
	var blogCategory models.BlogCategory
	err := rows.StructScan(&blogCategory)
	if err != nil {
		return nil, err
	}
	blogCategories = append(blogCategories, blogCategory)
}
if err := rows.Err(); err != nil {
	return nil, err
}
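Without type parameters, the most the loop itself can be shared is through a per-row callback, which is roughly the workaround described at the beginning of this post. The sketch below shows that shape; rowIterator is a hypothetical interface covering the *sqlx.Rows methods used here (which *sqlx.Rows should satisfy), and fakeRows is an in-memory test double so the code runs without PostgreSQL. Notice that every caller still has to declare its concrete type and write the append logic.

```go
package main

import (
	"fmt"
	"reflect"
)

// rowIterator is a hypothetical interface covering the *sqlx.Rows methods used here.
type rowIterator interface {
	Next() bool
	StructScan(dest interface{}) error
	Err() error
	Close() error
}

// scanEach shares the iteration loop without type parameters: the callback has to
// declare a value of the concrete type, scan into it and store it.
func scanEach(rows rowIterator, onRow func(rowIterator) error) error {
	defer rows.Close()
	for rows.Next() {
		if err := onRow(rows); err != nil {
			return err
		}
	}
	return rows.Err()
}

// fakeRows is an in-memory test double: StructScan fills a string field named Title.
type fakeRows struct {
	titles []string
	pos    int
	closed bool
}

func (f *fakeRows) Next() bool {
	f.pos++
	return f.pos <= len(f.titles)
}

func (f *fakeRows) Err() error   { return nil }
func (f *fakeRows) Close() error { f.closed = true; return nil }

func (f *fakeRows) StructScan(dest interface{}) error {
	v := reflect.ValueOf(dest).Elem()
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("%T is not a pointer to a struct", dest)
	}
	field := v.FieldByName("Title")
	if !field.IsValid() || field.Kind() != reflect.String {
		return fmt.Errorf("%T has no string Title field", dest)
	}
	field.SetString(f.titles[f.pos-1])
	return nil
}

type BlogPost struct{ Title string }

func main() {
	var posts []BlogPost
	err := scanEach(&fakeRows{titles: []string{"a", "b"}}, func(rows rowIterator) error {
		var p BlogPost // every caller still spells out its own type
		if err := rows.StructScan(&p); err != nil {
			return err
		}
		posts = append(posts, p)
		return nil
	})
	fmt.Println(posts, err) // prints [{a} {b}] <nil>
}
```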
Type parameters to the rescue
Now let's see how type parameters can help us write more scalable code.
We'll start by refactoring our PaginationResponse model so that its Items field uses a type parameter constrained by any (the new predeclared alias for interface{}) instead of interface{} itself.
package models

// PaginationResponse is a struct that represents the response of a paginated request
type PaginationResponse[T any] struct {
	Page       int   `json:"page"`
	Items      T     `json:"items"`
	PerPage    int   `json:"perPage"`
	PrevPage   int   `json:"prevPage"`
	NextPage   int   `json:"nextPage"`
	TotalPage  int   `json:"totalPage"`
	TotalItems int64 `json:"totalItems"`
}
The notable difference here is the addition of the [T any] type parameter list to the struct declaration, along with the use of T in lieu of interface{} for the Items field.
Next, we'll update the signatures of our store.Repository interface, and of the wrapper implementing it, to reflect the changes made to the model:
// Repository is an interface that defines the methods that a store must implement
type Repository interface {
	GetBlogPosts(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse[[]models.BlogPost], error)
	GetBlogCategories(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse[[]models.BlogCategory], error)
	Close() error
}

// ...

func (r *repository) GetBlogPosts(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse[[]models.BlogPost], error) {
	return r.repository.GetBlogPosts(ctx, paginationReq)
}

func (r *repository) GetBlogCategories(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse[[]models.BlogCategory], error) {
	return r.repository.GetBlogCategories(ctx, paginationReq)
}
As you can see, we now instantiate the pagination response with the type of its underlying items. We can already see the benefit of type parameters: our method signatures no longer show a PaginationResponse with an opaque interface{} for its items, but clearly state what we'll be paginating.
Finally, let's refactor our concrete implementation. Here comes the tricky part (keep in mind I'm still learning the ins and outs of type parameters, so I might get something wrong here): at first, one may be tempted to simply add a [T any] type parameter to the paginate method, something like the following:
func (r *PGRepository) paginate[T any](
	ctx context.Context,
	query string,
	queryArgs []interface{},
	countQuery string,
	countQueryArgs []interface{},
	paginationReq models.PaginationRequest,
) (*models.PaginationResponse[[]T], error)
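However, this won't compile: the compiler rejects it with a syntax error stating that methods must have no type parameters. Go simply doesn't allow methods to declare their own type parameters; as the type parameters proposal explains, supporting them would raise hard questions about how such methods interact with the way Go types implicitly implement interfaces.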
We'll therefore have to resort to a less elegant solution: a custom type carrying the type parameter (known as the facilitator pattern):
type PGPaginator[T any] []T
We'll then implement the paginate method on the newly declared PGPaginator type and, now that we have a way to pass in the struct type into which the sqlx.Rows need to be scanned, we'll also move the scanning logic into the method's body.
// paginate is a helper function for fetching paginated resources and scanning them into T
func (r PGPaginator[T]) paginate(
	db *sqlx.DB,
	ctx context.Context,
	query string,
	queryArgs []interface{},
	countQuery string,
	countQueryArgs []interface{},
	paginationReq models.PaginationRequest,
) (*models.PaginationResponse[[]T], error) {
	paginationRes := models.PaginationResponse[[]T]{
		Page:    paginationReq.Page,
		PerPage: paginationReq.PerPage,
	}

	// A plain errgroup.Group is used on purpose: the context derived by
	// errgroup.WithContext is canceled when Wait returns, which would also
	// cancel the rows iterated below.
	var g errgroup.Group

	// Retrieve the total number of items
	g.Go(func() error {
		return db.GetContext(
			ctx,
			&paginationRes.TotalItems,
			countQuery,
			countQueryArgs...,
		)
	})

	// Retrieve the items
	var rows *sqlx.Rows
	g.Go(func() error {
		var err error
		rows, err = db.QueryxContext(
			ctx,
			r.decoratePaginatedQuery(query, paginationReq),
			queryArgs...,
		)
		return err
	})

	if err := g.Wait(); err != nil {
		// The items query may have succeeded even though the count failed: release its connection
		if rows != nil {
			rows.Close()
		}
		return nil, err
	}
	defer rows.Close()

	// T is known here, so the scanning finally lives in one place
	items := make([]T, 0)
	for rows.Next() {
		var item T
		err := rows.StructScan(&item)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	// rows.Next also returns false on failure, so check rows.Err after the loop
	if err := rows.Err(); err != nil {
		return nil, err
	}

	paginationRes.Items = items
	paginationRes.TotalPage = r.getTotalPage(int(paginationRes.TotalItems), paginationRes.PerPage)
	paginationRes.PrevPage = r.getPrevPage(paginationRes.Page)
	paginationRes.NextPage = r.getNextPage(paginationRes.Page, paginationRes.TotalPage)
	return &paginationRes, nil
}
// getTotalPage is a helper function for getting the total number of pages
func (r PGPaginator[T]) getTotalPage(totalItems, perPage int) int {
	if perPage <= 0 {
		return 0 // avoid dividing by zero when per_page is missing
	}
	return int(math.Ceil(float64(totalItems) / float64(perPage)))
}

// getPrevPage is a helper function for getting the previous page
func (r PGPaginator[T]) getPrevPage(currentPage int) int {
	if currentPage >= 2 {
		return currentPage - 1
	}
	return currentPage
}

// getNextPage is a helper function for getting the next page
func (r PGPaginator[T]) getNextPage(currentPage, totalPage int) int {
	if currentPage >= totalPage {
		return currentPage
	}
	return currentPage + 1
}
// decoratePaginatedQuery appends the ORDER BY, LIMIT and OFFSET clauses to query.
// Warning: OrderBy and OrderDir come straight from the request and are written into
// the SQL as-is, so they must be validated against an allow-list beforehand.
func (r PGPaginator[T]) decoratePaginatedQuery(query string, pReq models.PaginationRequest) string {
	q := strings.Builder{}
	q.WriteString(query)
	if len(pReq.OrderBy) > 0 {
		// ORDER BY instructions
		q.WriteRune(' ')
		q.WriteString("ORDER BY")
		for i, orderBy := range pReq.OrderBy {
			if i > 0 {
				q.WriteRune(',')
			}
			q.WriteRune(' ')
			q.WriteString(orderBy)
		}
		q.WriteRune(' ')
		if len(pReq.OrderDir) == 0 {
			q.WriteString("DESC")
		} else {
			q.WriteString(pReq.OrderDir)
		}
	}
	// LIMIT instruction
	q.WriteRune(' ')
	q.WriteString("LIMIT")
	q.WriteRune(' ')
	q.WriteString(strconv.Itoa(pReq.PerPage))
	// OFFSET instruction
	q.WriteRune(' ')
	q.WriteString("OFFSET")
	q.WriteRune(' ')
	q.WriteString(strconv.Itoa(pReq.PerPage * (pReq.Page - 1)))
	return q.String()
}
Notice that we've added a new argument to the paginate method, as we now need to inject the *sqlx.DB (PGPaginator doesn't hold a reference to it).
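The facilitator type is one way around the restriction. Since it only applies to methods, a top-level generic function that receives what it needs as arguments works too. The sketch below extracts just the scanning part as such a function; structScanner is a hypothetical interface covering the *sqlx.Rows methods used (which *sqlx.Rows should satisfy), and fakeRows is an in-memory test double so the code runs without PostgreSQL.

```go
package main

import (
	"fmt"
	"reflect"
)

// structScanner is a hypothetical interface covering the *sqlx.Rows methods used below.
type structScanner interface {
	Next() bool
	StructScan(dest interface{}) error
	Err() error
	Close() error
}

// scanAll scans every remaining row into a fresh T. Top-level functions, unlike
// methods, may declare type parameters, so no facilitator type is needed.
func scanAll[T any](rows structScanner) ([]T, error) {
	defer rows.Close()
	items := make([]T, 0)
	for rows.Next() {
		var item T
		if err := rows.StructScan(&item); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	// Errors that end the iteration early are only reported by Err
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// fakeRows is an in-memory test double: StructScan fills a string field named Title.
type fakeRows struct {
	titles []string
	pos    int
	closed bool
}

func (f *fakeRows) Next() bool {
	f.pos++
	return f.pos <= len(f.titles)
}

func (f *fakeRows) Err() error   { return nil }
func (f *fakeRows) Close() error { f.closed = true; return nil }

func (f *fakeRows) StructScan(dest interface{}) error {
	v := reflect.ValueOf(dest).Elem()
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("%T is not a pointer to a struct", dest)
	}
	field := v.FieldByName("Title")
	if !field.IsValid() || field.Kind() != reflect.String {
		return fmt.Errorf("%T has no string Title field", dest)
	}
	field.SetString(f.titles[f.pos-1])
	return nil
}

type BlogPost struct{ Title string }
type BlogCategory struct{ Title string }

func main() {
	// T cannot be inferred from the rows argument, so it is passed explicitly
	posts, _ := scanAll[BlogPost](&fakeRows{titles: []string{"p1", "p2"}})
	categories, _ := scanAll[BlogCategory](&fakeRows{titles: []string{"c1"}})
	fmt.Println(posts, categories) // prints [{p1} {p2}] [{c1}]
}
```

Whether this reads better than the PGPaginator type is largely a matter of taste; the function form simply avoids a type whose only purpose is to carry T.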
Finally, let's refactor and clean up our GetBlogPosts and GetBlogCategories methods:
// GetBlogPosts returns paginated blog posts
func (r *PGRepository) GetBlogPosts(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse[[]models.BlogPost], error) {
	query, queryArgs, err := r.sq.Select("*").From("blog_posts").ToSql()
	if err != nil {
		return nil, err
	}
	countQuery, countQueryArgs, err := r.sq.Select("count(*)").From("blog_posts").ToSql()
	if err != nil {
		return nil, err
	}
	return PGPaginator[models.BlogPost]{}.paginate(
		r.db,
		ctx,
		query,
		queryArgs,
		countQuery,
		countQueryArgs,
		paginationReq,
	)
}

// GetBlogCategories returns paginated blog categories
func (r *PGRepository) GetBlogCategories(ctx context.Context, paginationReq models.PaginationRequest) (*models.PaginationResponse[[]models.BlogCategory], error) {
	query, queryArgs, err := r.sq.Select("*").From("blog_categories").ToSql()
	if err != nil {
		return nil, err
	}
	countQuery, countQueryArgs, err := r.sq.Select("count(*)").From("blog_categories").ToSql()
	if err != nil {
		return nil, err
	}
	return PGPaginator[models.BlogCategory]{}.paginate(
		r.db,
		ctx,
		query,
		queryArgs,
		countQuery,
		countQueryArgs,
		paginationReq,
	)
}
As you can see, we no longer need to implement the scanning logic in each method, leaving the lazy developers that we all are happy and free to get another cup of coffee!
But there's more! Remember how our models.PaginationResponse used an interface{} for its items before? That left us unable to know what was inside our struct without type-asserting the items, adding up to more boilerplate, more error checking, more chances to f**k something up, etc.
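To make the contrast concrete, here is a small sketch with two hypothetical stand-ins for the old and new response types: the interface{} version needs a type assertion plus a runtime error path, while the generic version hands back a []BlogPost directly, and passing the wrong item type becomes a compile-time error.

```go
package main

import (
	"fmt"
	"strings"
)

type BlogPost struct{ Title string }

// untypedResponse mimics the old PaginationResponse: Items is an interface{}.
type untypedResponse struct{ Items interface{} }

// typedResponse mimics the generic PaginationResponse[T].
type typedResponse[T any] struct{ Items T }

// titlesUntyped must assert the dynamic type of Items and handle a mismatch at run time.
func titlesUntyped(res untypedResponse) (string, error) {
	posts, ok := res.Items.([]BlogPost)
	if !ok {
		return "", fmt.Errorf("unexpected items type %T", res.Items)
	}
	titles := make([]string, 0, len(posts))
	for _, p := range posts {
		titles = append(titles, p.Title)
	}
	return strings.Join(titles, ", "), nil
}

// titlesTyped receives a []BlogPost directly; no assertion, no error path.
func titlesTyped(res typedResponse[[]BlogPost]) string {
	titles := make([]string, 0, len(res.Items))
	for _, p := range res.Items {
		titles = append(titles, p.Title)
	}
	return strings.Join(titles, ", ")
}

func main() {
	posts := []BlogPost{{Title: "Generics"}, {Title: "Pagination"}}
	fmt.Println(titlesTyped(typedResponse[[]BlogPost]{Items: posts})) // prints Generics, Pagination
	_, err := titlesUntyped(untypedResponse{Items: []string{"oops"}})
	fmt.Println(err) // prints unexpected items type []string
}
```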
Now that we use type parameters, we don't need any of this, as we know exactly the type of our items! Let's demonstrate this by adding specific methods to our resource models:
package models

//...

func (b BlogPost) BlogPostMethod() string {
	return fmt.Sprintf("I'm a blog post ! Title: %v", b.Title)
}

//...

func (b BlogCategory) BlogCategoryMethod() string {
	return fmt.Sprintf("I'm a blog category ! Title: %v", b.Title)
}
We can now call these gloriously useless methods without having to type-assert our PaginationResponse.Items:
package main

//...

func (a *API) blogCategoriesHandler(w http.ResponseWriter, r *http.Request) {
	paginationReq := a.parsePaginationReq(r)
	paginationResponse, err := a.repository.GetBlogCategories(r.Context(), paginationReq)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Items is a []models.BlogCategory: no type assertion needed
	for _, category := range paginationResponse.Items {
		fmt.Println(category.BlogCategoryMethod())
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// The 200 status has already been sent, so an encoding error can no longer become a 500
	_ = json.NewEncoder(w).Encode(paginationResponse)
}

func (a *API) blogPostsHandler(w http.ResponseWriter, r *http.Request) {
	paginationReq := a.parsePaginationReq(r)
	paginationResponse, err := a.repository.GetBlogPosts(r.Context(), paginationReq)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Items is a []models.BlogPost: no type assertion needed
	for _, post := range paginationResponse.Items {
		fmt.Println(post.BlogPostMethod())
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// The 200 status has already been sent, so an encoding error can no longer become a 500
	_ = json.NewEncoder(w).Encode(paginationResponse)
}
Conclusion
Type parameters in Go, while having some drawbacks, are still a welcome addition to an already great language, and I for one cannot wait to see how the community will make use of them.