Go+typescript+GraphQL+react構建簡書網站(三) 編寫Model

補遺:數據庫增長Tag表

新建tag表:前端

CREATE TABLE "public"."tag" (
  "id" int8 NOT NULL,
  "name" varchar(255) NOT NULL,
  "created_at" timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
  "updated_at" timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
  "deleted_at" timestamp(6) NOT NULL,
  PRIMARY KEY ("id")
)
;

COMMENT ON COLUMN "public"."tag"."id" IS 'ID';

COMMENT ON COLUMN "public"."tag"."name" IS '標籤名';

COMMENT ON COLUMN "public"."tag"."created_at" IS '建立時間';

COMMENT ON COLUMN "public"."tag"."updated_at" IS '更新時間';

COMMENT ON COLUMN "public"."tag"."deleted_at" IS '刪除時間';

這裏不得不說一下,因爲是一邊寫代碼一邊寫文章(文章的做用只是用來給本身釐清思路),因此文章中的代碼內容極可能下一次就變了,畢竟文章中的代碼,只是我初步寫時的思路,確定存在錯漏之處,後續會慢慢完善。如要看最新的代碼,還請移步:https://github.com/unrotten/h...git

編寫CURD基礎方法

依然先看結果,修改db.go文件:github

package model

import (
    "context"
    "database/sql"
    "database/sql/driver"
    "fmt"
    "github.com/jmoiron/sqlx"
    _ "github.com/lib/pq"
    "github.com/rs/zerolog"
    "github.com/sony/sonyflake"
    "github.com/spf13/viper"
    "github.com/unrotten/builder"
    "github.com/unrotten/sqlex"
    "log"
    "os"
    "reflect"
    "time"
)

var (
    DB        *sqlx.DB
    psql      sqlex.StatementBuilderType
    idfetcher *sonyflake.Sonyflake
)

const defaultSkip int = 2

type cv map[string]interface{}

type where []sqlex.Sqlex

type result struct {
    b       builder.Builder
    success bool
}

// 初始化數據庫鏈接
func init() {
    viper.AddConfigPath("../config") // 測試使用
    viper.ReadInConfig()
    // 獲取數據庫配置信息
    user := viper.Get("storage.user")
    password := viper.Get("storage.password")
    host := viper.Get("storage.host")
    port := viper.Get("storage.port")
    dbname := viper.Get("storage.dbname")

    // 鏈接數據庫
    psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
        host, port, user, password, dbname)
    DB = sqlx.MustOpen("postgres", psqlInfo)
    if err := DB.Ping(); err != nil {
        log.Fatalf("鏈接數據庫失敗:%s", err)
    }

    // 初始化sql構建器,指定format形式
    psql = sqlex.StatementBuilder.PlaceholderFormat(sqlex.Dollar)
    sqlex.SetLogger(os.Stdout)

    // 初始化sonyflake
    st := sonyflake.Settings{
        StartTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.Local),
    }
    idfetcher = sonyflake.NewSonyflake(st)
}

func get(query *sql.Rows, columnTypes []*sql.ColumnType, logger zerolog.Logger) result {
    dest := make([]interface{}, len(columnTypes))
    for index, col := range columnTypes {
        switch col.ScanType().String() {
        case "string", "interface {}":
            dest[index] = &sql.NullString{}
        case "bool":
            dest[index] = &sql.NullBool{}
        case "float64":
            dest[index] = &sql.NullFloat64{}
        case "int32":
            dest[index] = &sql.NullInt32{}
        case "int64":
            dest[index] = &sql.NullInt64{}
        case "time.Time":
            dest[index] = &sql.NullTime{}
        default:
            dest[index] = reflect.New(col.ScanType()).Interface()
        }
    }
    err := query.Scan(dest...)
    if err != nil {
        logger.Error().Caller(2).Err(err).Send()
        return result{success: false}
    }
    build := builder.EmptyBuilder
    for index, col := range columnTypes {
        switch val := dest[index].(type) {
        case driver.Valuer:
            var value interface{}
            switch col.ScanType().String() {
            case "string", "interface {}":
                value = dest[index].(*sql.NullString).String
            case "bool":
                value = dest[index].(*sql.NullBool).Bool
            case "float64":
                value = dest[index].(*sql.NullFloat64).Float64
            case "int32":
                value = dest[index].(*sql.NullInt32).Int32
            case "int64":
                value = dest[index].(*sql.NullInt64).Int64
            case "time.Time":
                value = dest[index].(*sql.NullTime).Time
            }
            build = builder.Set(build, col.Name(), value).(builder.Builder)
        default:
            build = builder.Set(build, col.Name(), val).(builder.Builder)
        }
    }
    return result{success: true, b: build}
}

func selectList(ctx context.Context, table string, where where, columns ...string) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)

    var selectBuilder sqlex.SelectBuilder
    if len(columns) > 0 {
        selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is null")
    } else {
        selectBuilder = psql.Select("*").From(table).Where("deleted_at is null")
    }
    for _, arg := range where {
        selectBuilder = selectBuilder.Where(arg)
    }
    query, err := selectBuilder.RunWith(tx).Query()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }

    columnTypes, err := query.ColumnTypes()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }
    var resultSlice []interface{}
    for query.Next() {
        r := get(query, columnTypes, logger)
        if !r.success {
            return r
        }
        resultSlice = append(resultSlice, r.b)
    }
    return result{success: true, b: builder.Set(builder.EmptyBuilder, "list", resultSlice).(builder.Builder)}
}

func selectOne(ctx context.Context, table string, where where, columns ...string) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)

    var selectBuilder sqlex.SelectBuilder
    if len(columns) > 0 {
        selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is null").Limit(1)
    } else {
        selectBuilder = psql.Select("*").From(table).Where("deleted_at is null").Limit(1)
    }
    for _, arg := range where {
        selectBuilder = selectBuilder.Where(arg)
    }
    query, err := selectBuilder.RunWith(tx).Query()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }

    columnTypes, err := query.ColumnTypes()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }

    if query.Next() {
        return get(query, columnTypes, logger)
    }
    return result{success: false}
}

func selectReal(ctx context.Context, table string, where where, columns ...string) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)

    var selectBuilder sqlex.SelectBuilder
    if len(columns) > 0 {
        selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is not null")
    } else {
        selectBuilder = psql.Select("*").From(table).Where("deleted_at is not null")
    }
    for _, arg := range where {
        selectBuilder = selectBuilder.Where(arg)
    }
    query, err := selectBuilder.RunWith(tx).Query()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }

    columnTypes, err := query.ColumnTypes()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }
    var resultSlice []interface{}
    for query.Next() {
        r := get(query, columnTypes, logger)
        if !r.success {
            return r
        }
        resultSlice = append(resultSlice, r.b)
    }
    return result{success: true, b: builder.Set(builder.EmptyBuilder, "list", resultSlice).(builder.Builder)}
}

func insertOne(ctx context.Context, table string, cv cv) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)
    build := builder.EmptyBuilder
    cv["created_at"], cv["updated_at"] = time.Now(), time.Now()
    columns, values := make([]string, 0, len(cv)), make([]interface{}, 0, len(cv))
    for col, value := range cv {
        build = builder.Set(build, col, value).(builder.Builder)
        columns, values = append(columns, col), append(values, value)
    }
    r, err := psql.Insert(table).Columns(columns...).Values(values...).RunWith(tx).Exec()
    return assertSqlResult(r, err, logger)
}

func update(ctx context.Context, table string, cv cv, where where, directSet ...string) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)
    cv["updated_at"] = time.Now()
    updateBuilder := psql.Update(table).SetMap(cv).Where("deleted_at is null")
    for _, set := range directSet {
        updateBuilder = updateBuilder.DirectSet(set)
    }
    for _, arg := range where {
        updateBuilder = updateBuilder.Where(arg)
    }
    r, err := updateBuilder.RunWith(tx).Exec()
    return assertSqlResult(r, err, logger)
}

// note: if where is null,then will delete the whole table
func remove(ctx context.Context, table string, where where) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)

    updateBuilder := psql.Update(table).Set("deleted_at", time.Now()).Where("deleted_at is null")
    for _, arg := range where {
        updateBuilder = updateBuilder.Where(arg)
    }
    r, err := updateBuilder.RunWith(tx).Exec()
    return assertSqlResult(r, err, logger)
}

func assertSqlResult(r sql.Result, err error, logger zerolog.Logger, skip ...int) result {
    sk := defaultSkip
    if len(skip) > 0 {
        sk += skip[0]
    }
    if err != nil {
        logger.Error().Caller(sk).Err(err).Send()
        return result{success: false}
    }
    affected, err := r.RowsAffected()
    if err != nil {
        logger.Error().Caller(2).Err(err).Send()
        return result{success: false}
    }
    if affected == 0 {
        return result{success: false}
    }
    return result{success: true}
}

在這裏咱們只看查詢,selectList和selectOne依託於get方法實現,而get的核心就是設值。由於在數據庫中,數據存在NULL的狀況,而Go中的基礎類型如string,int64等並不支持,因此咱們必須使用其對應的sql.NullString等類型去scan。做者這裏爲了保持model中定義的struct可以繼續使用string等基礎類型,在get中進行了類型的判斷,不可空的基礎類型經過兩次switch轉換,最終即使對於NULL值,也會獲得基礎類型的默認空值。web

在get方法中,咱們使用reflect.New(col.ScanType()).Interface()方法,得到字段對應的指針值,這裏使用了反射,效果等同於new()。sql

在記錄錯誤日誌logger.Error().Caller(sk).Err(err).Send()時,咱們先指定了日誌的類別爲Error,再調用了Caller(sk),獲取運行時上下文。Caller的原理是調用runtime.Caller(skip)方法,以獲取指定的代碼段位置。最終效果就是一般咱們程序報錯時,在控制檯可以看到的,各個文件的指定行。數據庫

在get方法的最後,咱們經過builder.Set(build, col.Name(), value).(builder.Builder)這樣的代碼段,將數據對應的名字和值存入指定的builer中。builder的效果相似於map,只是使用builder庫能夠更方便直接將map轉爲指定的struct。json

再把目光轉到selectOne方法,能夠看到咱們從上下文context中獲取了logger和事務tx,這裏是方便後續的工做。咱們須要注意的是,sqlex庫進行sql構建時,嚴格按照了sql語法的規定,固然where和from之間的順序在這裏能夠不用管。咱們在初始化selectBuilder的時候,Where("1=1")給定了一個初始的where條件,這樣作的用意是,因爲sqlex庫提供了IF操做,譬如:數組

psql.Select("*").From("user").Where(sqlex.IF{Condition: "a" == "", Sq: sqlex.Eq{"a": "3"}})

這樣的代碼,因爲「a」==「」不知足,因此IF中的」a」==「3」並不會被歸入構建器中,但是也由於調用了Where,因此構建器中sql中必然會增長一個where,最終獲得錯誤的sql:SELECT * FROM "user" WHERE微信

編寫Model

model目錄下新建user.go文件:app

package model

import (
    "context"
    "errors"
    "github.com/unrotten/builder"
    "time"
)

type User struct {
    Id        int64     `json:"id" db:"id"`
    Username  string    `json:"username" db:"username"`
    Email     string    `json:"email" db:"email"`
    Password  string    `json:"password" db:"password"`
    Avatar    string    `json:"avatar" db:"avatar"`
    Gender    string    `json:"gender" db:"gender"`
    Introduce string    `json:"introduce" db:"introduce"`
    State     string    `json:"state" db:"state"`
    Root      bool      `json:"root" db:"root"`
    CreatedAt time.Time `json:"createdAt" db:"created_at"`
    UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
    DeletedAt time.Time `json:"deletedAt" db:"deleted_at"`
}

func GetUsers(ctx context.Context, where where) ([]User, error) {
    result := selectList(ctx, `"user"`, where)
    if !result.success {
        return nil, errors.New("獲取用戶列表失敗")
    }
    list, ok := builder.Get(result.b, "list")
    if !ok {
        return nil, errors.New("獲取用戶列表失敗")
    }
    users := make([]User, 0, len(list.([]interface{})))
    for _, item := range list.([]interface{}) {
        users = append(users, builder.GetStructLikeByTag(item.(builder.Builder), User{}, "db").(User))
    }
    return users, nil
}

func GetUser(ctx context.Context, where where) (User, error) {
    result := selectOne(ctx, `"user"`, where)
    if !result.success {
        return User{}, errors.New("查詢用戶數據失敗")
    }
    return builder.GetStructLikeByTag(result.b, User{}, "db").(User), nil
}

func InsertUser(ctx context.Context, cv map[string]interface{}) (User, error) {
    id, err := idfetcher.NextID()
    if err != nil {
        return User{}, err
    }

    cv["id"] = int64(id)
    result := insertOne(ctx, `"user"`, cv)
    if !result.success {
        return User{}, errors.New("插入用戶數據失敗")
    }
    return builder.GetStructLikeByTag(result.b, User{}, "db").(User), nil
}

func UpdateUser(ctx context.Context, cv cv, where where) error {
    result := update(ctx, `"user"`, cv, where)
    if !result.success {
        return errors.New("更新用戶數據失敗")
    }
    return nil
}

這裏惟一須要注意的是,咱們使用builder.GetStructLikeByTag(result.b, User{}, "db").(User)方法,將CURD中得到的Builder根據指定的tag內容,轉化爲對應結構體。

接下來,就是繼續完善其餘的model。

userCount.go:

package model

import (
    "context"
    "errors"
    "github.com/unrotten/builder"
    "github.com/unrotten/sqlex"
    "time"
)

type UserCount struct {
    Uid        int64     `json:"uid" db:"uid"`
    FansNum    int32     `json:"fansNum" db:"fans_num"`
    FollowNum  int32     `json:"followNum" db:"follow_num"`
    ArticleNum int32     `json:"articleNum" db:"article_num"`
    Words      int32     `json:"words" db:"words"`
    ZanNum     int32     `json:"zanNum" db:"zan_num"`
    CreatedAt  time.Time `json:"createdAt" db:"created_at"`
    UpdatedAt  time.Time `json:"updatedAt" db:"updated_at"`
    DeletedAt  time.Time `json:"deletedAt" db:"deleted_at"`
}

func GetUserCount(ctx context.Context, uid int64, columns ...string) (UserCount, error) {
    result := selectOne(ctx, "user_count", append(where{}, sqlex.Eq{"uid": uid}), columns...)
    if !result.success {
        return UserCount{}, errors.New("查詢用戶計數失敗")
    }
    return builder.GetStructLikeByTag(result.b, UserCount{}, "db").(UserCount), nil
}

func InsertUserCount(ctx context.Context, uid int64) error {
    result := insertOne(ctx, "user_count", cv{"uid": uid})
    if !result.success {
        return errors.New("保存用戶計數表失敗")
    }
    return nil
}

func UpdateUserCount(ctx context.Context, uid int64, add bool, columns ...string) error {
    directSets, directSet := make([]string, 0, len(columns)), " + 1"
    if !add {
        directSet = " - 1"
    }
    for _, col := range columns {
        directSets = append(directSets, col+directSet)
    }
    if !update(ctx, "user_count", cv{}, where{sqlex.Eq{"uid": uid}}, directSets...).success {
        return errors.New("增長用戶計數失敗")
    }
    return nil
}

咱們爲了改變userCount中的計數值,定義了方法UpdateUserCount。能夠經過指定加減和相應字段來實現計數值的加減。咱們能夠注意到了,這裏在調用update的時候,傳入了directSets,最終將經過update中的:

for _, set := range directSet {
        updateBuilder = updateBuilder.DirectSet(set)
}

將設置好的值構建到SQL中。DirectSet目的是構建無參數的set語句,因此並不建議暴露給從接口傳入的參數,不然會有SQL注入的風險。

userFollow.go

package model

import (
    "context"
    "errors"
    "github.com/unrotten/builder"
    "github.com/unrotten/sqlex"
    "time"
)

type UserFollow struct {
    Id        int64     `json:"id" db:"id"`
    Uid       int64     `json:"uid" db:"uid"`
    Fuid      int64     `json:"fuid" db:"fuid"`
    CreatedAt time.Time `json:"createdAt" db:"created_at"`
    UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
    DeletedAt time.Time `json:"deletedAt" db:"deleted_at"`
}

func InsertUserFollow(ctx context.Context, uid, fuid int64) error {
    id, err := idfetcher.NextID()
    if err != nil {
        return err
    }
    if result := insertOne(ctx, "user_follow", cv{"id": int64(id), "uid": uid, "fuid": fuid}); !result.success {
        return errors.New("插入用戶關注表失敗")
    }
    return nil
}

func RemoveUserFollow(ctx context.Context, uid, fuid int64) error {
    if !remove(ctx, "user_follow", where{sqlex.Eq{"uid": uid, "fuid": fuid}}).success {
        return errors.New("刪除用戶關注失敗")
    }
    return nil
}

// 獲取用戶關注列表
func GetUserFollowList(ctx context.Context, fuid int64) ([]int64, error) {
    result := selectList(ctx, "user_follow", where{sqlex.Eq{"fuid": fuid}}, "uid")
    if !result.success {
        return nil, errors.New("獲取用戶關注列表失敗")
    }
    b, _ := builder.Get(result.b, "list")
    list := b.([]interface{})
    userList := make([]int64, 0, len(list))
    for _, item := range list {
        uid, _ := builder.Get(item.(builder.Builder), "uid")
        userList = append(userList, uid.(int64))
    }
    return userList, nil
}

// 獲取用戶粉絲列表
func GetFollowUserList(ctx context.Context, uid int64) ([]int64, error) {
    result := selectList(ctx, "user_follow", where{sqlex.Eq{"uid": uid}}, "fuid")
    if !result.success {
        return nil, errors.New("獲取用戶關注列表失敗")
    }
    b, _ := builder.Get(result.b, "list")
    list := b.([]interface{})
    userList := make([]int64, 0, len(list))
    for _, item := range list {
        uid, _ := builder.Get(item.(builder.Builder), "fuid")
        userList = append(userList, uid.(int64))
    }
    return userList, nil
}

在這裏不管是粉絲列表仍是關注列表,咱們都指定了獲取對應的userId列表,而非UserFollow數組。這是爲了便於後續dataloader的使用,之後會提到。

到這裏用戶相關的model就編寫完了,後面真正與前端一塊兒聯調時,定還有許多更改。而其餘諸如文章,評論等的model,便再也不贅述。用戶相關的model,已經將基本的CURD涵蓋。

看完這裏,咱們能夠發現,對於user的擴展表user_count 和 user_follow, 咱們並無在model層面去設計他們的關係,在數據的獲取,新增,修改上,也都是獨立的。這是由於咱們全部定義的數據之間的關係,都交由GraphQL去描述了,在數據層咱們反而不用多在乎這些關係的實現。


做者我的博客地址:https://unrotten.org
做者微信公衆號地址:
WechatIMG2.jpeg

相關文章
相關標籤/搜索