Skip to content

Commit

Permalink
Little refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
lalinsky committed Feb 11, 2024
1 parent 7646325 commit 8b02f50
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
7 changes: 7 additions & 0 deletions common/config.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package common

import (
"database/sql"
"fmt"
"net"
"net/url"
"os"
"strconv"

_ "github.com/lib/pq"
)

type DatabaseConfig struct {
Expand Down Expand Up @@ -50,6 +53,10 @@ func (cfg *DatabaseConfig) URL() *url.URL {
return &u
}

func (cfg *DatabaseConfig) Connect() (*sql.DB, error) {
return sql.Open("postgres", cfg.URL().String())
}

func (cfg *DatabaseConfig) readEnv(prefix string) {
name := os.Getenv(prefix + "NAME")
if name != "" {
Expand Down
8 changes: 4 additions & 4 deletions fpstore/cmd/fpstore/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"database/sql"
"os"

_ "github.com/lib/pq"
Expand Down Expand Up @@ -98,11 +97,12 @@ func PrepareFingerprintStore(c *cli.Context) (fpstore.FingerprintStore, error) {
config.Password = c.String(PostgresPassword.Name)
config.Database = c.String(PostgresDatabase.Name)

db, err := sql.Open("postgres", config.URL().String())
db, err := config.Connect()
if err != nil {
return nil, err
return nil, errors.WithMessage(err, "failed to connect to database")
}
return fpstore.NewSqlFingerprintStore(db), nil

return fpstore.NewPostgresFingerprintStore(db), nil
}

func PrepareAndRunServer(c *cli.Context) error {
Expand Down
37 changes: 28 additions & 9 deletions fpstore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -67,15 +68,15 @@ func (a *Uint32Array) scanString(src string) error {
return nil
}

type SqlFingerprintStore struct {
type PostgresFingerprintStore struct {
db *sql.DB
}

func NewSqlFingerprintStore(db *sql.DB) *SqlFingerprintStore {
return &SqlFingerprintStore{db: db}
func NewPostgresFingerprintStore(db *sql.DB) *PostgresFingerprintStore {
return &PostgresFingerprintStore{db: db}
}

func (s *SqlFingerprintStore) Insert(ctx context.Context, fp *pb.Fingerprint) (uint64, error) {
func (s *PostgresFingerprintStore) Insert(ctx context.Context, fp *pb.Fingerprint) (uint64, error) {
data, err := EncodeFingerprint(fp)
if err != nil {
return 0, err
Expand All @@ -88,12 +89,30 @@ func (s *SqlFingerprintStore) Insert(ctx context.Context, fp *pb.Fingerprint) (u
return id, nil
}

func (s *SqlFingerprintStore) Delete(ctx context.Context, id uint64) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM fingerprint_v2 WHERE id = $1", id)
var ErrCannotDeleteLegacyFingerprint = errors.New("cannot delete legacy fingerprint")

func (s *PostgresFingerprintStore) Delete(ctx context.Context, id uint64) error {
existsAsV1, err := s.checkV1(ctx, id)
if err != nil {
return err
}
if existsAsV1 {
return ErrCannotDeleteLegacyFingerprint
}
_, err = s.db.ExecContext(ctx, "DELETE FROM fingerprint_v2 WHERE id = $1", id)
return err
}

func (s *SqlFingerprintStore) getV1(ctx context.Context, id uint64) (*pb.Fingerprint, error) {
func (s *PostgresFingerprintStore) checkV1(ctx context.Context, id uint64) (bool, error) {
var count int
err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM fingerprint WHERE id = $1", id).Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}

func (s *PostgresFingerprintStore) getV1(ctx context.Context, id uint64) (*pb.Fingerprint, error) {
var hashes Uint32Array
query := "SELECT fingerprint FROM fingerprint WHERE id = $1"
err := s.db.QueryRowContext(ctx, query, id).Scan(&hashes)
Expand All @@ -106,7 +125,7 @@ func (s *SqlFingerprintStore) getV1(ctx context.Context, id uint64) (*pb.Fingerp
return &pb.Fingerprint{Hashes: hashes}, nil
}

func (s *SqlFingerprintStore) getV2(ctx context.Context, id uint64) (*pb.Fingerprint, error) {
func (s *PostgresFingerprintStore) getV2(ctx context.Context, id uint64) (*pb.Fingerprint, error) {
var data []byte
query := "SELECT data FROM fingerprint_v2 WHERE id = $1"
err := s.db.QueryRowContext(ctx, query, id).Scan(&data)
Expand All @@ -119,7 +138,7 @@ func (s *SqlFingerprintStore) getV2(ctx context.Context, id uint64) (*pb.Fingerp
return DecodeFingerprint(data)
}

func (s *SqlFingerprintStore) Get(ctx context.Context, id uint64) (*pb.Fingerprint, error) {
func (s *PostgresFingerprintStore) Get(ctx context.Context, id uint64) (*pb.Fingerprint, error) {
fp, err := s.getV2(ctx, id)
if err != nil {
log.Warnf("failed to get fingerprint from v2 table: %v", err)
Expand Down

0 comments on commit 8b02f50

Please sign in to comment.