package main

import (
	"context"
	"encoding/json"
	"flag"
	"log"
	"os"
	"strconv"
	"sync"
	"time"

	influxdb2 "github.com/influxdata/influxdb-client-go/v2"
)

// sdm630Register describes one entry of the sdm_registers table.
type sdm630Register struct {
	// Address is the register address as listed in sdm_registers.
	Address uint16
	// Unit is the unit string attached to the register (e.g. "V", "A", "kWh").
	Unit    string
}

// measure is a single reading: a numeric value together with its unit.
// Both fields carry JSON tags so a measure can be (de)serialized as
// {"value": ..., "unit": ...}.
type measure struct {
	// Value is the numeric reading.
	Value float32 `json:"value"`
	// Unit is the unit string belonging to Value.
	Unit  string  `json:"unit"`
}

// sampleCache holds a set of readings keyed by measurement name, together
// with the mutex that main locks while it accesses the sample field.
type sampleCache struct {
	// sample maps a measurement name to its reading.
	sample map[string]measure
	// mutex guards sample; presumably the writer side (not visible here)
	// takes the same lock - verify against the producer.
	mutex  sync.Mutex
}

// config is the runtime configuration. read_config fills it from a JSON
// file using the key names given in the struct tags.
type config struct {
	// RtuAddress is the modbus device slave address.
	RtuAddress    int    `json:"rtu_address"`
	// SampleRate is the sampling interval in seconds.
	SampleRate    int    `json:"sample_rate"`
	// ModbusHost is the hostname of the modbus tcp to rtu bridge.
	ModbusHost    string `json:"modbus_host"`
	// ModbusPort is the port of the modbus tcp to rtu bridge.
	ModbusPort    int    `json:"modbus_port"`
	// InfluxdbHost is the hostname of the influxdb server.
	InfluxdbHost  string `json:"influxdb_host"`
	// InfluxdbPort is the port of the influxdb server.
	InfluxdbPort  int    `json:"influxdb_port"`
	// InfluxdbToken is the access token for the influx db.
	InfluxdbToken string `json:"influxdb_token"`
}

// sdm_registers lists the SDM630 input registers known to this program,
// keyed by measurement name. Each entry carries the register address and the
// unit string attached to the value.
var sdm_registers = map[string]sdm630Register{
	"L1Voltage":    {Address: 0x00, Unit: "V"},
	"L2Voltage":    {Address: 0x02, Unit: "V"},
	"L3Voltage":    {Address: 0x04, Unit: "V"},
	"L1Current":    {Address: 0x06, Unit: "A"},
	"L2Current":    {Address: 0x08, Unit: "A"},
	"L3Current":    {Address: 0x0A, Unit: "A"},
	"L1PowerW":     {Address: 0x0C, Unit: "W"},
	"L2PowerW":     {Address: 0x0E, Unit: "W"},
	"L3PowerW":     {Address: 0x10, Unit: "W"},
	"L1PhaseAngle": {Address: 0x24, Unit: "°"},
	"L2PhaseAngle": {Address: 0x26, Unit: "°"},
	"L3PhaseAngle": {Address: 0x28, Unit: "°"},
	"TotalImport":  {Address: 0x48, Unit: "kWh"},
	"TotalExport":  {Address: 0x4A, Unit: "kWh"},
}

var (
	// logger is the program-wide logger, initialised as a copy of the
	// standard library's default logger.
	// NOTE(review): log.Logger contains a mutex, so `go vet` (copylocks)
	// flags this value copy; consider holding a *log.Logger instead once all
	// users of this variable have been checked.
	logger log.Logger = *log.Default()

	// Counter readings of the most recently accepted sample; sample_is_valid
	// compares new readings against them. Both start at zero.
	last_total_import float32
	last_total_export float32

	// config_path is the location of the JSON config file; main sets it from
	// the -c command line flag.
	config_path string

	// config_cache holds the active configuration. The values below are the
	// defaults used for every key the config file does not provide:
	//   - RtuAddress:    modbus device slave address (0x01)
	//   - SampleRate:    sampling interval in seconds
	//   - ModbusHost:    hostname of the modbus tcp to rtu bridge
	//   - ModbusPort:    port of the modbus tcp to rtu bridge
	//   - InfluxdbHost:  hostname of the influxdb server
	//   - InfluxdbPort:  port of the influxdb server
	//   - InfluxdbToken: access token for the influx db
	config_cache = config{
		RtuAddress:    1,
		SampleRate:    1,
		ModbusHost:    "",
		ModbusPort:    502,
		InfluxdbHost:  "",
		InfluxdbPort:  8086,
		InfluxdbToken: "",
	}
)

// read_config loads the JSON file at config_path into config_cache.
//
// Keys that are missing from the file keep the value config_cache already
// holds. Failures are not fatal: the cause is logged and the function
// returns, leaving the program to run with whatever config_cache contains.
func read_config() {
	data, err := os.ReadFile(config_path)
	if err != nil {
		// Include the cause so a missing file can be told apart from e.g. a
		// permission problem.
		logger.Printf("Unable to read %s: %v", config_path, err)
		return
	}

	// NOTE: on a type mismatch json.Unmarshal keeps decoding the remaining
	// keys and reports the first such error, so config_cache may already
	// hold some values from the file when this branch is taken.
	if err := json.Unmarshal(data, &config_cache); err != nil {
		logger.Printf("Unable to evaluate config data: %v", err)
		return
	}
}

// init tags every log line with the program name and announces start-up.
func init() {
	// The prefix has to be in place before the first message is written.
	logger.SetPrefix("PowerMeterSDM630: ")
	logger.Print("Starting")
}

// sample_is_valid reports whether sample is plausible enough to be stored.
//
// A sample is rejected (and the reason logged) when
//   - it contains no readings,
//   - every reading is exactly 0.0, or
//   - its TotalImport or TotalExport reading is lower than the one of the
//     last accepted sample; these counters can never decrease.
//
// Only when the sample is accepted are last_total_import and
// last_total_export advanced to the sample's counter readings, so a rejected
// sample never moves the baselines.
func sample_is_valid(sample map[string]measure) bool {
	if len(sample) == 0 {
		logger.Print("Sample invalid - contains no data")
		return false
	}

	// Inspect the readings one by one instead of summing them: a sum would
	// also be zero when non-zero readings cancel each other (e.g. +100 W on
	// one phase and -100 W on another).
	allZero := true
	for _, reading := range sample {
		if reading.Value != 0.0 {
			allZero = false
			break
		}
	}
	if allZero {
		logger.Print("Sample invalid - all values equals 0.0")
		return false
	}

	// total import and total export can never decrease
	totalImport, hasImport := sample["TotalImport"]
	totalExport, hasExport := sample["TotalExport"]
	if hasImport && totalImport.Value < last_total_import {
		logger.Printf("Sample invalid - Total import lower than previous one (%f vs. %f)", totalImport.Value, last_total_import)
		return false
	}
	if hasExport && totalExport.Value < last_total_export {
		logger.Printf("Sample invalid - Total export lower than previous one (%f vs. %f)", totalExport.Value, last_total_export)
		return false
	}

	// Both counters passed: commit the new baselines together.
	if hasImport {
		last_total_import = totalImport.Value
	}
	if hasExport {
		last_total_export = totalExport.Value
	}
	return true
}

// main is the program entry point.
//
// It reads the config file named by the -c flag, starts the two modbus
// goroutines and then loops forever: once per sampling interval it takes the
// current content of sample_cache, validates it with sample_is_valid and, if
// valid, writes it as one point of the "power" measurement to InfluxDB.
// Write errors are logged and the loop carries on with the next interval.
func main() {
	flag.StringVar(&config_path, "c", "./config/config.json", "Specify path to find the config file. Default is ./config/config.json")
	flag.Parse()
	read_config()

	// Both workers are defined outside the part of the file shown here.
	go keep_mbus_connection()
	go collect_mbus_samples()

	influxdb_url := "http://" + config_cache.InfluxdbHost + ":" + strconv.Itoa(config_cache.InfluxdbPort)
	client := influxdb2.NewClient(influxdb_url, config_cache.InfluxdbToken)
	// The loop below never returns, so this only runs if that ever changes.
	defer client.Close()

	// Blocking write client (organisation "tkl", bucket "home"): WritePoint
	// reports the outcome of the write through its error return.
	writeAPI := client.WriteAPIBlocking("tkl", "home")

	ctx := context.Background()

	for {
		// Convert to time.Duration before scaling: the former
		// SampleRate * 1000000000 was evaluated in int and overflowed on
		// 32-bit platforms for rates of 3 s and above.
		interval := time.Duration(config_cache.SampleRate) * time.Second
		if interval <= 0 {
			// A zero or negative rate would turn this into a busy loop.
			interval = time.Second
		}
		time.Sleep(interval)

		// Copy the readings while holding the lock. Assigning the map alone
		// would only copy the reference, and the readings are used after the
		// lock has been released.
		// NOTE(review): presumably the collector updates sample_cache.sample
		// under the same mutex - confirm against collect_mbus_samples.
		sample_cache.mutex.Lock()
		sample := make(map[string]measure, len(sample_cache.sample))
		for key, value := range sample_cache.sample {
			sample[key] = value
		}
		sample_cache.mutex.Unlock()

		if !sample_is_valid(sample) {
			continue
		}

		point := influxdb2.NewPointWithMeasurement("power")
		point.AddTag("sensor", "powermeter")
		for key, value := range sample {
			point.AddField(key, value.Value)
		}

		// calculate TotalPowerW
		total_power := sample["L1PowerW"].Value + sample["L2PowerW"].Value + sample["L3PowerW"].Value
		point.AddField("TotalPowerW", total_power)

		// A point has a single "unit" field. The previous code set it once
		// per reading and finally to "W", so "W" is what ended up stored;
		// that result is kept here.
		// NOTE(review): per-reading units are not persisted - decide whether
		// they should be (e.g. as separate fields) or whether "unit" can go.
		point.AddField("unit", "W")

		point.SetTime(time.Now())
		if err := writeAPI.WritePoint(ctx, point); err != nil {
			logger.Printf("Unable to write sample to InfluxDB: %v", err)
			continue
		}
		if err := writeAPI.Flush(ctx); err != nil {
			logger.Printf("Unable to flush InfluxDB writes: %v", err)
		}
	}
}