Skip to content

Commit

Permalink
Merge pull request #1604 from anywhy/indicator_adx
Browse files Browse the repository at this point in the history
FEATURE:[indicator] add adx indicator
  • Loading branch information
c9s committed Mar 26, 2024
2 parents f3f21f3 + 88281c1 commit 6ac642b
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pkg/bbgo/indicator_set.go
Expand Up @@ -3,7 +3,7 @@ package bbgo
import (
"github.com/sirupsen/logrus"

"github.com/c9s/bbgo/pkg/indicator/v2"
indicatorv2 "github.com/c9s/bbgo/pkg/indicator/v2"
"github.com/c9s/bbgo/pkg/types"
)

Expand Down Expand Up @@ -108,3 +108,7 @@ func (i *IndicatorSet) ATR(interval types.Interval, window int) *indicatorv2.ATR
func (i *IndicatorSet) ATRP(interval types.Interval, window int) *indicatorv2.ATRPStream {
return indicatorv2.ATRP2(i.KLines(interval), window)
}

func (i *IndicatorSet) ADX(interval types.Interval, window int) *indicatorv2.ADXStream {
return indicatorv2.ADX(i.KLines(interval), window)
}
72 changes: 72 additions & 0 deletions pkg/indicator/v2/adx.go
@@ -0,0 +1,72 @@
package indicatorv2

import (
"math"

"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
)

type ADXStream struct {
*RMAStream

Plus, Minus *types.Float64Series

window int
prevHigh, prevLow fixedpoint.Value
}

func ADX(source KLineSubscription, window int) *ADXStream {
var (
atr = ATR2(source, window)
dmp = types.NewFloat64Series()
dmn = types.NewFloat64Series()
adx = types.NewFloat64Series()
sdmp = RMA2(dmp, window, true)
sdmn = RMA2(dmn, window, true)
s = &ADXStream{
window: window,
Plus: types.NewFloat64Series(),
Minus: types.NewFloat64Series(),
prevHigh: fixedpoint.Zero,
prevLow: fixedpoint.Zero,
RMAStream: RMA2(adx, window, true),
}
)

source.AddSubscriber(func(k types.KLine) {
if s.prevHigh.IsZero() || s.prevLow.IsZero() {
s.prevHigh, s.prevLow = k.High, k.Low
return
}

up, dn := k.High.Sub(s.prevHigh), s.prevLow.Sub(k.Low)
if up.Compare(dn) > 0 && up.Float64() > 0 {
dmp.PushAndEmit(up.Float64())
} else {
dmp.PushAndEmit(0.0)
}
if dn.Compare(up) > 0 && dn.Float64() > 0 {
dmn.PushAndEmit(dn.Float64())
} else {
dmn.PushAndEmit(0.0)
}

s.Plus.PushAndEmit(sdmp.Last(0) * 100 / atr.Last(0))
s.Minus.PushAndEmit(sdmn.Last(0) * 100 / atr.Last(0))
dx := math.Abs(s.Plus.Last(0)-s.Minus.Last(0)) / (s.Plus.Last(0) + s.Minus.Last(0))
if !math.IsNaN(dx) {
adx.PushAndEmit(dx * 100.0)
} else {
adx.PushAndEmit(0.0)
}

s.prevHigh, s.prevLow = k.High, k.Low
s.Truncate()
})
return s
}

func (s *ADXStream) Truncate() {
s.Slice = s.Slice.Truncate(MaxNumOfRMA)
}
77 changes: 77 additions & 0 deletions pkg/indicator/v2/adx_test.go
@@ -0,0 +1,77 @@
package indicatorv2

import (
"encoding/json"
"math"
"testing"

"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
)

/*
import pandas as pd
import pandas_ta as ta
data = {
"high": [40145.0, 40186.36, 40196.39, 40344.6, 40245.48, 40273.24, 40464.0, 40699.0, 40627.48, 40436.31, 40370.0, 40376.8, 40227.03, 40056.52, 39721.7, 39597.94, 39750.15, 39927.0, 40289.02, 40189.0],
"low": [39870.71, 39834.98, 39866.31, 40108.31, 40016.09, 40094.66, 40105.0, 40196.48, 40154.99, 39800.0, 39959.21, 39922.98, 39940.02, 39632.0, 39261.39, 39254.63, 39473.91, 39555.51, 39819.0, 40006.84],
"close": [40105.78, 39935.23, 40183.97, 40182.03, 40212.26, 40149.99, 40378.0, 40618.37, 40401.03, 39990.39, 40179.13, 40097.23, 40014.72, 39667.85, 39303.1, 39519.99,39693.79, 39827.96, 40074.94, 40059.84]
}
high = pd.Series(data['high'])
low = pd.Series(data['low'])
close = pd.Series(data['close'])
result = ta.adx(high, low, close, length=7)
print(result)
*/
func Test_ADX(t *testing.T) {
var bytes = []byte(`{
"high": [40145.0, 40186.36, 40196.39, 40344.6, 40245.48, 40273.24, 40464.0, 40699.0, 40627.48, 40436.31, 40370.0, 40376.8, 40227.03, 40056.52, 39721.7, 39597.94, 39750.15, 39927.0, 40289.02, 40189.0],
"low": [39870.71, 39834.98, 39866.31, 40108.31, 40016.09, 40094.66, 40105.0, 40196.48, 40154.99, 39800.0, 39959.21, 39922.98, 39940.02, 39632.0, 39261.39, 39254.63, 39473.91, 39555.51, 39819.0, 40006.84],
"close": [40105.78, 39935.23, 40183.97, 40182.03, 40212.26, 40149.99, 40378.0, 40618.37, 40401.03, 39990.39, 40179.13, 40097.23, 40014.72, 39667.85, 39303.1, 39519.99,39693.79, 39827.96, 40074.94, 40059.84]
}`)

var buildKLines = func(bytes []byte) (kLines []types.KLine) {
var prices map[string][]fixedpoint.Value
_ = json.Unmarshal(bytes, &prices)
for i, h := range prices["high"] {
kLine := types.KLine{High: h, Low: prices["low"][i], Close: prices["close"][i]}
kLines = append(kLines, kLine)
}
return kLines
}

tests := []struct {
name string
kLines []types.KLine
window int
want float64
}{
{
name: "test_binance_btcusdt_1h",
kLines: buildKLines(bytes),
window: 7,
want: 34.83952,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stream := &types.StandardStream{}

kLines := KLines(stream, "", "")
adx := ADX(kLines, tt.window)

for _, k := range tt.kLines {
stream.EmitKLineClosed(k)
}

got := adx.Last(0)
diff := math.Trunc((got-tt.want)*100) / 100
if diff != 0 {
t.Errorf("ADX() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 6ac642b

Please sign in to comment.