fly 1 жил өмнө
parent
commit
6824d4ddee

+ 61 - 0
algorithm/encryption/AesCBCWithNoPadding.go

@@ -0,0 +1,61 @@
+package encryption
+
+import (
+	"bytes"
+	"crypto/aes"
+	"crypto/cipher"
+	"fmt"
+)
+
+type aesCBCWithNoPadding struct {
+	key []byte
+	iv  []byte
+}
+
+func NewAesCBCWithNoPadding(key, iv []byte) *aesCBCWithNoPadding {
+	return &aesCBCWithNoPadding{
+		key: key,
+		iv:  iv,
+	}
+}
+
+// AES加密函数
+func (a *aesCBCWithNoPadding) Encrypt(plaintext []byte) ([]byte, error) {
+	if len(plaintext)%aes.BlockSize != 0 {
+		padding := aes.BlockSize - len(plaintext)%aes.BlockSize
+		plaintext = append(plaintext, bytes.Repeat([]byte{byte(padding)}, padding)...)
+	}
+
+	block, err := aes.NewCipher(a.key)
+	if err != nil {
+		return nil, err
+	}
+
+	//if len(plaintext)%aes.BlockSize != 0 {
+	//	return nil, fmt.Errorf("plaintext is not a multiple of the block size")
+	//}
+
+	ciphertext := make([]byte, len(plaintext))
+	mode := cipher.NewCBCEncrypter(block, a.iv)
+	mode.CryptBlocks(ciphertext, plaintext)
+
+	return ciphertext, nil
+}
+
+// AES解密函数
+func (a *aesCBCWithNoPadding) Decrypt(ciphertext []byte) ([]byte, error) {
+	block, err := aes.NewCipher(a.key)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(ciphertext)%aes.BlockSize != 0 {
+		return nil, fmt.Errorf("ciphertext is not a multiple of the block size")
+	}
+
+	plaintext := make([]byte, len(ciphertext))
+	mode := cipher.NewCBCDecrypter(block, a.iv)
+	mode.CryptBlocks(plaintext, ciphertext)
+
+	return plaintext, nil
+}

+ 114 - 0
algorithm/encryption/AesCBCWithNoPadding_test.go

@@ -0,0 +1,114 @@
+package encryption
+
+import (
+	"bytes"
+	"crypto/aes"
+	"encoding/hex"
+	"reflect"
+	"testing"
+)
+
+func Test_aesCBCWithNoPadding_Decrypt(t *testing.T) {
+	type fields struct {
+		key []byte
+		iv  []byte
+	}
+	type args struct {
+		ciphertext []byte
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		args    args
+		want    []byte
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+		{
+			name: "test",
+			fields: fields{
+				key: []byte("a2b1805169887fd9bca40e45"),
+				iv:  []byte("69887fd9bca40e45"),
+			},
+			args: args{
+				ciphertext: []byte("69001dd634e3deb6357df6e6d7a867db"),
+			},
+			want:    []byte("hello"),
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			b, _ := hex.DecodeString(string(tt.args.ciphertext))
+			a := &aesCBCWithNoPadding{
+				key: tt.fields.key,
+				iv:  tt.fields.iv,
+			}
+			got, err := a.Decrypt(b)
+
+			//got, err := aesDecrypt(b, tt.fields.key, tt.fields.iv)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("Decrypt() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			c := hex.EncodeToString(got)
+			if !reflect.DeepEqual([]byte(c), tt.want) {
+				t.Errorf("Decrypt() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_aesCBCWithNoPadding_Encrypt(t *testing.T) {
+	type fields struct {
+		key []byte
+		iv  []byte
+	}
+	type args struct {
+		plaintext []byte
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		args    args
+		want    []byte
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+		{
+			name: "test",
+			fields: fields{
+				key: []byte("a2b1805169887fd9bca40e45"),
+				iv:  []byte("69887fd9bca40e45"),
+			},
+			args: args{
+				[]byte("hello"),
+			},
+			want:    []byte("69001dd634e3deb6357df6e6d7a867db"),
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if len(tt.args.plaintext)%aes.BlockSize != 0 {
+				padding := aes.BlockSize - len(tt.args.plaintext)%aes.BlockSize
+				tt.args.plaintext = append(tt.args.plaintext, bytes.Repeat([]byte{byte(0)}, padding)...)
+			}
+			//got, err := aesEncrypt(tt.args.plaintext, tt.fields.key, tt.fields.iv)
+			a := &aesCBCWithNoPadding{
+				key: tt.fields.key,
+				iv:  tt.fields.iv,
+			}
+			got, err := a.Encrypt(tt.args.plaintext)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("Encrypt() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			c := hex.EncodeToString(got)
+			t.Log("密文:", c, string(tt.want))
+			if !reflect.DeepEqual([]byte(c), tt.want) {
+				t.Errorf("Encrypt() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}