src.dualinventive.com/go/companies-service/internal/storage/gorm/companyRepository.go

98 lines
2.5 KiB
Go

package gorm
import (
"github.com/jinzhu/gorm"
"src.dualinventive.com/go/companies-service/internal/domain"
"src.dualinventive.com/go/companies-service/internal/storage"
"src.dualinventive.com/go/lib/dilog"
)
// CompanyRepository is the gorm-backed implementation handed out as a
// storage.CompanyRepository by NewCompanyRepository.
type CompanyRepository struct {
	// DB is the gorm handle all company queries are issued through.
	DB *gorm.DB
	// logger is injected by NewCompanyRepository.
	logger dilog.Logger
}
// sortMapping translates a domain-level sort column into the ORDER BY
// fragment passed to gorm. Columns that are not listed here are not
// translated; callers must choose their own fallback.
var sortMapping = map[domain.SortCol]SortColQuery{ //nolint: gochecknoglobals
	// Ascending orderings.
	domain.SortColIDAsc:   SortColQueryIDAsc,
	domain.SortColNameAsc: SortColQueryNameAsc,
	// Descending orderings.
	domain.SortColIDDesc:   SortColQueryIDDesc,
	domain.SortColNameDesc: SortColQueryNameDesc,
}
// NewCompanyRepository obtains a *gorm.DB by calling openDB with the given
// connection parameters and wraps it, together with logger, in a gorm based
// storage.CompanyRepository. If openDB fails, its error is returned unchanged
// and the repository is nil.
func NewCompanyRepository(logger dilog.Logger, host, port, name, user, pass string) (storage.CompanyRepository, error) {
	conn, openErr := openDB(host, port, name, user, pass)
	if openErr != nil {
		return nil, openErr
	}
	return &CompanyRepository{logger: logger, DB: conn}, nil
}
// GetCompanyByID returns the company with the given companyID, or (nil, nil)
// when no such company exists. Any other database error is returned as-is.
//
// companyID 0 never matches. gorm drops zero-valued fields from struct
// conditions, so querying with Company{ID: 0} would produce no WHERE clause.
// First would then return some unrelated company instead of reporting
// "not found".
func (r *CompanyRepository) GetCompanyByID(companyID uint64) (*domain.Company, error) {
	id := uint(companyID)
	// Reject zero (see above) and IDs that do not survive the conversion to
	// uint (possible on 32-bit platforms). A truncated ID would silently look
	// up a different company.
	if id == 0 || uint64(id) != companyID {
		return nil, nil
	}

	var company domain.Company
	err := r.DB.
		Where(&domain.Company{ID: id}).
		First(&company).Error
	if err == gorm.ErrRecordNotFound {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &company, nil
}
// GetCompanies returns one page of companies together with the total number of
// companies matched by the unpaged query.
//
// page is 1-based; page 0 is treated as page 1. perPage is the page size.
// sort selects the ORDER BY fragment via sortMapping; an unknown column falls
// back to SortColQueryIDAsc. If gorm reports gorm.ErrRecordNotFound the result
// is (nil, 0, nil); any other database error is returned as-is.
func (r *CompanyRepository) GetCompanies(
	page uint64, perPage uint64, sort domain.SortCol) (
	[]domain.Company, uint64, error) {
	var companies []domain.Company
	var count uint64
	order, ok := sortMapping[sort]
	if !ok {
		order = SortColQueryIDAsc
	}
	// page is unsigned: for page 0, page-1 wraps around and the computed
	// offset becomes meaningless. Treat page 0 as the first page.
	if page == 0 {
		page = 1
	}
	// The zero-valued struct condition adds no WHERE clause (gorm ignores
	// zero-valued struct fields).
	//
	// The Count must not inherit the page window. An OFFSET past the first
	// row breaks Count (https://github.com/jinzhu/gorm/issues/1752), and a
	// LIMIT 0 (perPage == 0) leaves the count(*) query with no row to scan.
	// Offset(0) resets the offset; Limit(-1) cancels the limit.
	err := r.DB.
		Where(&domain.Company{ID: uint(0)}).
		Order(order.String()).
		Offset((page - 1) * perPage).
		Limit(perPage).
		Find(&companies).
		Offset(0).
		Limit(-1).
		Count(&count).
		Error
	if err == gorm.ErrRecordNotFound {
		return nil, 0, nil
	}
	if err != nil {
		return nil, 0, err
	}
	return companies, count, nil
}
// SortColQuery is an ORDER BY fragment (column plus direction) used to sort
// company queries.
type SortColQuery string

// Supported ORDER BY fragments, one per sortable column and direction.
const (
	// SortColQueryIDAsc orders by company_id, ascending.
	SortColQueryIDAsc SortColQuery = "company_id asc"
	// SortColQueryIDDesc orders by company_id, descending.
	SortColQueryIDDesc SortColQuery = "company_id desc"
	// SortColQueryNameAsc orders by company_name, ascending.
	SortColQueryNameAsc SortColQuery = "company_name asc"
	// SortColQueryNameDesc orders by company_name, descending.
	SortColQueryNameDesc SortColQuery = "company_name desc"
)

// String returns the raw ORDER BY fragment, so SortColQuery satisfies
// fmt.Stringer.
func (scq SortColQuery) String() string {
	return string(scq)
}