xiaochang 5 years ago
parent
commit
ba3df493af
6 changed files with 325 additions and 19 deletions
  1. 77 0
      lib/sequence/db/sequence.go
  2. 73 0
      lib/sequence/sequence.go
  3. 157 0
      lib/shorturl/shorturl.go
  4. 2 0
      main.go
  5. 15 2
      web/api/shortUrl.go
  6. 1 17
      web/server.go

+ 77 - 0
lib/sequence/db/sequence.go

@@ -0,0 +1,77 @@
+package db
+import (
+	"database/sql"
+	"git.sfnt.net/sfnt/cnlink/lib/sequence"
+	"log"
+	"git.sfnt.net/sfnt/cnlink/conf"
+	_ "github.com/go-sql-driver/mysql"
+)
+
+type SequenceDB struct {
+	db *sql.DB
+}
+
+func (dbSeq *SequenceDB) Open() (err error) {
+	var db *sql.DB
+	db, err = sql.Open("mysql", conf.Conf.SequenceDB.DSN)
+	if err != nil {
+		log.Printf("sequence db open error. %v", err)
+		return err
+	}
+
+	err = db.Ping()
+	if err != nil {
+		log.Printf("sequence db ping error. %v", err)
+		return err
+	}
+
+	db.SetMaxIdleConns(conf.Conf.SequenceDB.MaxIdleConns)
+	db.SetMaxOpenConns(conf.Conf.SequenceDB.MaxOpenConns)
+
+	dbSeq.db = db
+	return nil
+}
+
+func (dbSeq *SequenceDB) Close() {
+	if dbSeq.db != nil {
+		dbSeq.db.Close()
+		dbSeq.db = nil
+	}
+}
+
+func (dbSeq *SequenceDB) NextSequence() (sequence uint64, err error) {
+	var stmt *sql.Stmt
+	stmt, err = dbSeq.db.Prepare(`REPLACE INTO sequence(stub) VALUES ("sequence")`)
+	if err != nil {
+		log.Printf("sequence db prepare error. %v", err)
+		return 0, err
+	}
+	defer stmt.Close()
+
+	var res sql.Result
+	res, err = stmt.Exec()
+	if err != nil {
+		log.Printf("sequence db replace into error. %v", err)
+		return 0, err
+	}
+
+	// 兼容LastInsertId方法的返回值
+	var lastID int64
+	lastID, err = res.LastInsertId()
+	if err != nil {
+		log.Printf("sequence db get LastInsertId error. %v", err)
+		return 0, err
+	} else {
+		sequence = uint64(lastID)
+		// mysql sequence will start at 1, we actually want it to be
+		// started at 0. :)
+		sequence -= 1
+		return sequence, nil
+	}
+}
+
+var dbSeq = SequenceDB{}
+
+func init() {
+	sequence.MustRegister("db", &dbSeq)
+}

+ 73 - 0
lib/sequence/sequence.go

@@ -0,0 +1,73 @@
+package sequence
+
+import (
+	"fmt"
+	"sort"
+	"sync"
+)
+
+var (
+	sequencesMu sync.RWMutex
+	sequences   = map[string]Sequence{}
+)
+
+type Sequence interface {
+	// Open opens the sequence generator.
+	Open() (err error)
+	// NextSequence generates next sequence integer(unsigned 64bit).
+	// If some error happens, err will not be nil and seq will be 0.
+	// Else, err will be nil and next valid sequence integer will be in seq.
+	NextSequence() (seq uint64, err error)
+	// Close closes the sequence generator.
+	Close()
+}
+
+// GetSequence returns corresponding sequence instance with the specified
+// sequenceType. If the specified sequenceType does not register itself, then
+// err will be non nil and sequence will be nil. Else, err will be nil and
+// sequence will be corresponding sequence instance.
+func GetSequence(sequenceType string) (sequence Sequence, err error) {
+
+	sequencesMu.RLock()
+	defer sequencesMu.RUnlock()
+
+	if value, ok := sequences[sequenceType]; ok {
+		sequence = value
+		return sequence, nil
+	} else {
+		return nil, fmt.Errorf("%v is not registered.", sequenceType)
+	}
+}
+
+// Register makes a sequence generator available by the provided sequenceType.
+// If Register is called twice with the same name or if driver is nil, it
+// panics.
+func MustRegister(sequenceType string, sequence Sequence) {
+
+	sequencesMu.Lock()
+	defer sequencesMu.Unlock()
+
+	if sequence == nil {
+		panic("sequence: Registered sequence is nil")
+	}
+
+	if _, dup := sequences[sequenceType]; dup {
+		panic("sequence: Register called twice for driver " + sequenceType)
+	}
+
+	sequences[sequenceType] = sequence
+}
+
+// Sequences returns a sorted list of the types of the registered sequences.
+func Sequences() []string {
+	sequencesMu.RLock()
+	defer sequencesMu.RUnlock()
+
+	var list []string
+	for name := range sequences {
+		list = append(list, name)
+	}
+
+	sort.Strings(list)
+	return list
+}

+ 157 - 0
lib/shorturl/shorturl.go

@@ -0,0 +1,157 @@
+package shorturl
+
+import (
+	"database/sql"
+	"errors"
+	"fmt"
+	"git.sfnt.net/sfnt/cnlink/lib/sequence"
+	_ "git.sfnt.net/sfnt/cnlink/lib/sequence/db"
+	"git.sfnt.net/sfnt/cnlink/lib"
+
+	"git.sfnt.net/sfnt/cnlink/conf"
+	"log"
+	_ "github.com/go-sql-driver/mysql"
+)
+
+type shorter struct {
+	readDB   *sql.DB
+	writeDB  *sql.DB
+	sequence sequence.Sequence
+}
+
+// connect will panic when it can not connect to DB server.
+func (shorter *shorter) mustConnect() {
+	db, err := sql.Open("mysql", conf.Conf.ShortDB.ReadDSN)
+	if err != nil {
+		log.Panicf("short read db open error. %v", err)
+	}
+
+	err = db.Ping()
+	if err != nil {
+		log.Panicf("short read db ping error. %v", err)
+	}
+
+	db.SetMaxIdleConns(conf.Conf.ShortDB.MaxIdleConns)
+	db.SetMaxOpenConns(conf.Conf.ShortDB.MaxOpenConns)
+
+	shorter.readDB = db
+
+	db, err = sql.Open("mysql", conf.Conf.ShortDB.WriteDSN)
+	if err != nil {
+		log.Panicf("short write db open error. %v", err)
+	}
+
+	err = db.Ping()
+	if err != nil {
+		log.Panicf("short write db ping error. %v", err)
+	}
+
+	db.SetMaxIdleConns(conf.Conf.ShortDB.MaxIdleConns)
+	db.SetMaxOpenConns(conf.Conf.ShortDB.MaxOpenConns)
+
+	shorter.writeDB = db
+}
+
+// initSequence will panic when it can not open the sequence successfully.
+func (shorter *shorter) mustInitSequence() {
+
+	sequence, err := sequence.GetSequence("db")
+	if err != nil {
+		log.Panicf("get sequence instance error. %v", err)
+	}
+
+	err = sequence.Open()
+	if err != nil {
+		log.Panicf("open sequence instance error. %v", err)
+	}
+
+	shorter.sequence = sequence
+}
+
+func (shorter *shorter) close() {
+	if shorter.readDB != nil {
+		shorter.readDB.Close()
+		shorter.readDB = nil
+	}
+
+	if shorter.writeDB != nil {
+		shorter.writeDB.Close()
+		shorter.writeDB = nil
+	}
+}
+
+func (shorter *shorter) Expand(shortURL string) (longURL string, err error) {
+	selectSQL := fmt.Sprintf(`SELECT long_url FROM short WHERE short_url=?`)
+
+	var rows *sql.Rows
+	rows, err = shorter.readDB.Query(selectSQL, shortURL)
+	if err != nil {
+		log.Printf("short read db query error. %v", err)
+		return "", errors.New("short read db query error")
+	}
+
+	defer rows.Close()
+
+	for rows.Next() {
+		err = rows.Scan(&longURL)
+		if err != nil {
+			log.Printf("short read db query rows scan error. %v", err)
+			return "", errors.New("short read db query rows scan error")
+		}
+	}
+
+	err = rows.Err()
+	if err != nil {
+		log.Printf("short read db query rows iterate error. %v", err)
+		return "", errors.New("short read db query rows iterate error")
+	}
+
+	return longURL, nil
+}
+func (shorter *shorter) GetSequence()(seq uint64, err error){
+	seq, err = shorter.sequence.NextSequence()
+	return seq,nil
+}
+func (shorter *shorter) Short(longURL string) (shortURL string, err error) {
+	for {
+		var seq uint64
+		seq, err = shorter.sequence.NextSequence()
+		if err != nil {
+			log.Printf("get next sequence error. %v", err)
+			return "", errors.New("get next sequence error")
+		}
+
+		shortURL = lib.Int2String(seq)
+		if _, exists := conf.Conf.Common.BlackShortURLsMap[shortURL]; exists {
+			continue
+		} else {
+			break
+		}
+	}
+
+	insertSQL := fmt.Sprintf(`INSERT INTO short(long_url, short_url) VALUES(?, ?)`)
+
+	var stmt *sql.Stmt
+	stmt, err = shorter.writeDB.Prepare(insertSQL)
+	if err != nil {
+		log.Printf("short write db prepares error. %v", err)
+		return "", errors.New("short write db prepares error")
+	}
+	defer stmt.Close()
+
+	_, err = stmt.Exec(longURL, shortURL)
+	if err != nil {
+		log.Printf("short write db insert error. %v", err)
+		return "", errors.New("short write db insert error")
+	}
+
+	return shortURL, nil
+}
+
+var Shorter shorter
+
+func Start() {
+	Shorter.mustConnect()
+	Shorter.mustInitSequence()
+	log.Println("shorter starts")
+}

+ 2 - 0
main.go

@@ -4,6 +4,7 @@ import (
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"git.sfnt.net/sfnt/cnlink/conf"
 	"git.sfnt.net/sfnt/cnlink/conf"
+	"git.sfnt.net/sfnt/cnlink/lib/shorturl"
 	"git.sfnt.net/sfnt/cnlink/web"
 	"git.sfnt.net/sfnt/cnlink/web"
 	"os"
 	"os"
 )
 )
@@ -21,5 +22,6 @@ func main() {
 
 
 	// parse config
 	// parse config
 	conf.MustParseConfig(*cfgFile)
 	conf.MustParseConfig(*cfgFile)
+	shorturl.Start()
 	web.Start()
 	web.Start()
 }
 }

+ 15 - 2
web/api/shortUrl.go

@@ -1,7 +1,20 @@
 package api
 package api
 
 
-import "github.com/gin-gonic/gin"
+import (
+	"git.sfnt.net/sfnt/cnlink/lib"
+	"git.sfnt.net/sfnt/cnlink/lib/shorturl"
+	"github.com/gin-gonic/gin"
+)
 
 
-func redirect(c *gin.Context){
+func Redirect(c *gin.Context){
 
 
+}
+func Short(c *gin.Context){
+	sid := c.Params.ByName("sid")
+	sq, _ := shorturl.Shorter.GetSequence()
+	c.JSON(200, gin.H{
+		"message": lib.String2Int(sid),
+		"sequence": lib.Int2String(sq),
+		"sequence_id": sq,
+	})
 }
 }

+ 1 - 17
web/server.go

@@ -1,7 +1,6 @@
 package web
 package web
 
 
 import (
 import (
-	"git.sfnt.net/sfnt/cnlink/lib"
 	"git.sfnt.net/sfnt/cnlink/web/api"
 	"git.sfnt.net/sfnt/cnlink/web/api"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
@@ -16,23 +15,8 @@ func Start() {
 	})
 	})
 	r.GET("/ver", api.CheckVersion)
 	r.GET("/ver", api.CheckVersion)
 	r.GET("/health", api.CheckHealth)
 	r.GET("/health", api.CheckHealth)
-	/*
-	r.GET("/create", func(c *gin.Context) {
-		c.JSON(200, gin.H{
-			"message": "create page",
-		})
-	})
-	*/
-	r.GET("/u:sid", func(c *gin.Context) {
-		sid := c.Params.ByName("sid")
-		//intSid,_ := strconv.ParseUint(sid,10,64)
-
+	r.GET("/u:sid", api.Redirect)
 
 
-		c.JSON(200, gin.H{
-			"message": lib.String2Int(sid),
-		})
-		//c.Redirect(http.StatusFound,"")
-	})
 
 
 	r.Run()
 	r.Run()
 }
 }