package companies

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/require"
	"src.dualinventive.com/go/companies-service/internal/domain"
	"src.dualinventive.com/go/mtinfo-go"
)

// companyService is the service under test; before replaces it at the start of every test.
var companyService *CompanyService //nolint:gochecknoglobals

// companyRepository is the scripted repository backing companyService; before replaces it
// at the start of every test.
var companyRepository *TestCompanyRepository //nolint:gochecknoglobals

func TestGetCompanyByID_GivenUnknownID_ReturnsCompanyNotFoundError(t *testing.T) {
	before(t)
	// given: the repository finds nothing and reports no error
	companyRepository.returnMe = []*domain.Company{nil}
	companyRepository.returnErr = []error{nil}

	// when
	company, err := companyService.GetCompanyByID("", 123)

	// then
	require.EqualError(t, err, "company not found")
	require.Nil(t, company)
}

func TestGetCompanyByID_GivenKnownID_ReturnsCompany(t *testing.T) {
	before(t)
	// given
	companyRepository.returnMe = []*domain.Company{{ID: 123, Name: "someCompanyname"}}
	companyRepository.returnErr = []error{nil}

	// when
	company, err := companyService.GetCompanyByID("", 123)

	// then
	require.NoError(t, err)
	require.NotNil(t, company)
	require.Equal(t, "someCompanyname", company.Name) // testify order: expected, actual
}

func TestGetCompanyByID_WhenRepositoryError_ReturnsRepositoryError(t *testing.T) {
	before(t)
	// given
	companyRepository.returnMe = []*domain.Company{nil}
	companyRepository.returnErr = []error{errors.New("some repository error")}

	// when
	company, err := companyService.GetCompanyByID("", 123)

	// then
	require.EqualError(t, err, "failed to fetch company: some repository error")
	require.Nil(t, company)
}

func TestGetCompanies_GivenCorrectParams_ReturnsCompanies(t *testing.T) {
	before(t)
	// given
	companyRepository.returnMeMulti = [][]domain.Company{{
		{ID: 123, Name: "someCompanyname"},
		{ID: 321, Name: "anotherCompanyname"},
	}}
	companyRepository.returnErr = []error{nil}

	// when
	companies, count, err := companyService.GetCompanies("", 1, 1, "")

	// then
	require.NoError(t, err)
	// Len (not just NotNil) guards the companies[0] index below.
	require.Len(t, companies, 2)
	require.Equal(t, uint64(2), count)
	require.Equal(t, "someCompanyname", companies[0].Name)
}

func TestGetCompanies_GivenZeroPage_ReturnsInvalidArgument(t *testing.T) { //nolint:dupl
	before(t)
	// given
	companyRepository.returnMeMulti = [][]domain.Company{nil}
	companyRepository.returnErr = []error{nil}

	// when
	companies, count, err := companyService.GetCompanies("", 0, 1, "")

	// then
	require.Nil(t, companies)
	require.Equal(t, uint64(0), count)
	require.IsType(t, new(ErrInvalidArgument), err)
}

func TestGetCompanies_GivenZeroPerPage_ReturnsInvalidArgument(t *testing.T) { //nolint:dupl
	before(t)
	// given
	companyRepository.returnMeMulti = [][]domain.Company{nil}
	companyRepository.returnErr = []error{nil}

	// when
	companies, count, err := companyService.GetCompanies("", 1, 0, "")

	// then
	require.Nil(t, companies)
	require.Equal(t, uint64(0), count)
	require.IsType(t, new(ErrInvalidArgument), err)
}

func TestGetCompanies_GivenTooLargePerPage_ReturnsInvalidArgument(t *testing.T) { //nolint:dupl
	before(t)
	// given
	companyRepository.returnMeMulti = [][]domain.Company{nil}
	companyRepository.returnErr = []error{nil}

	// when
	companies, count, err := companyService.GetCompanies("", 1, 201, "")

	// then
	require.Nil(t, companies)
	require.Equal(t, uint64(0), count)
	require.IsType(t, new(ErrInvalidArgument), err)
}

func TestGetCompanies_WhenRepositoryError_ReturnsRepositoryError(t *testing.T) {
	before(t)
	// given
	companyRepository.returnMeMulti = [][]domain.Company{nil}
	companyRepository.returnErr = []error{errors.New("some repository error")}

	// when
	companies, count, err := companyService.GetCompanies("", 1, 1, "")

	// then
	require.Nil(t, companies)
	require.Equal(t, uint64(0), count)
	require.IsType(t, new(ErrCompanyRepositoryErr), err)
}

// before resets the package-level fixtures for test t: a fresh TestCompanyRepository
// and a CompanyService wired to it, whose mtinfo client uses the stub testAuthClient.
func before(t *testing.T) {
	companyRepository = NewTestCompanyRepository(t)
	companyService = &CompanyService{
		CompanyRepository: companyRepository,
		Mtinfo: &mtinfo.Client{
			Auth: &testAuthClient{},
		},
	}
}

// NewTestCompanyRepository returns a mock company repository bound to t. By default its
// first GetCompanyByID call yields (nil, nil); tests overwrite the return slices to script
// other responses.
func NewTestCompanyRepository(t *testing.T) *TestCompanyRepository { return &TestCompanyRepository{t: t, callCount: 0, returnMe: []*domain.Company{nil}, returnErr: []error{nil}} } type TestCompanyRepository struct { t *testing.T callCount int returnMe []*domain.Company returnMeMulti [][]domain.Company returnErr []error } func (r TestCompanyRepository) GetCompanyByID(companyID uint64) (*domain.Company, error) { returnMe := r.returnMe[r.callCount] returnErr := r.returnErr[r.callCount] r.callCount++ return returnMe, returnErr } func (r TestCompanyRepository) GetCompanies( page uint64, perPage uint64, sort domain.SortCol) ( []domain.Company, uint64, error) { returnMeMulti := r.returnMeMulti[r.callCount] returnErr := r.returnErr[r.callCount] r.callCount++ return returnMeMulti, uint64(len(returnMeMulti)), returnErr } //testAuthClient contains authentication related operations over GRPC type testAuthClient struct{} //VerifyToken verifies the given token to see if its valid. //When a public key is configured, token is locally verified using JWT. //When no public key is configured, token is remotely verified using GRPC. //Returns false when token signed portion is invalid, or token is expired. 
func (as *testAuthClient) VerifyToken(token string) (bool, error) { return true, nil } func (as *testAuthClient) VerifyTokenRemotely(token string) (bool, error) { return true, nil } func (as *testAuthClient) Login(username, companyCode, password string) (string, error) { return "", nil } func (as *testAuthClient) Logout(token string) error { return nil } func (as *testAuthClient) Me(token string) (*mtinfo.User, error) { return nil, nil } func (as *testAuthClient) RequestPasswordReset(username string) error { return nil } func (as *testAuthClient) RedeemPasswordReset(username, resetCode, password, passwordVerify string) error { return nil } func (as *testAuthClient) ListTokens(token string) (mtinfo.OpaqueTokens, error) { return nil, nil } func (as *testAuthClient) DeleteToken(token string, opaqueToken string) error { return nil } func (as *testAuthClient) UserAgent() string { return "" } func (as *testAuthClient) SetUserAgent(string) {}