Skip to content

Commit

Permalink
use new DRP chart serializers
Browse files (browse the repository at this point in the history)
  • Loading branch information
sheppard committed Sep 16, 2015
1 parent a8cba00 commit b4582d7
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 271 deletions.
104 changes: 8 additions & 96 deletions contrib/chart/serializers.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,14 @@
from rest_framework import serializers
from rest_pandas import PandasSerializer
from wq.db.rest.serializers import LocalDateTimeField
from django.db.models.fields import DateTimeField, FieldDoesNotExist


class ChartModelSerializer(serializers.ModelSerializer):
key_fields = ["series", "date"]
parameter_fields = ["parameter", "units"]
value_field = "value"
class Meta:
pandas_index = ["date"]
pandas_unstacked_header = ["series", "units", "parameter"]

@property
def key_lookups(self):
return self.key_fields
pandas_scatter_coord = ["units", "parameter"]
pandas_scatter_header = ["series"]

@property
def parameter_lookups(self):
return self.parameter_fields

@property
def value_lookup(self):
return self.value_field

@property
def key_model(self):
return self.Meta.model

def get_fields(self):
fields = super(ChartModelSerializer, self).get_fields()
value_kwargs = {}
if self.value_field != self.value_lookup:
value_kwargs['source'] = self.value_lookup
fields[self.value_field] = serializers.ReadOnlyField(**value_kwargs)

for key, lookup in zip(self.parameter_fields, self.parameter_lookups):
param_kwargs = {}
if key != lookup:
param_kwargs['source'] = lookup
fields[key] = serializers.ReadOnlyField(**param_kwargs)

for key, lookup in zip(self.key_fields, self.key_lookups):
try:
field = self.key_model._meta.get_field_by_name(key)[0]
except FieldDoesNotExist:
field = None
key_kwargs = {}
if key != lookup:
key_kwargs['source'] = lookup
if isinstance(field, DateTimeField):
fields[key] = LocalDateTimeField(**key_kwargs)
else:
fields[key] = serializers.ReadOnlyField(**key_kwargs)

return fields


class ChartPandasSerializer(PandasSerializer):
    """
    Pandas serializer that pivots chart data so that the parameter fields and
    the first key field end up as column headers.

    The key and parameter field names are read from the wrapped model
    serializer (see ``model_serializer``).
    """

    # Consumed by the PandasSerializer base class; presumably the placeholder
    # used for empty index values -- confirm against rest_pandas.
    index_none_value = "-"

    @property
    def model_serializer(self):
        """
        Return the serializer that declares ``key_fields`` and
        ``parameter_fields``: ``self.child`` when present, otherwise ``self``.
        """
        # Compatibility with ListSerializer and ModelSerializer
        return getattr(self, 'child', self)

    def get_key_fields(self):
        """Return ``key_fields`` from the model serializer."""
        return self.model_serializer.key_fields

    def get_parameter_fields(self):
        """Return ``parameter_fields`` from the model serializer."""
        return self.model_serializer.parameter_fields

    def get_index(self, dataframe):
        """
        By default, all key fields need to be included or pivoting may break
        due to non-unique values.  Move first item in index (which would
        usually be the most important key) to the end to facilitate unstacking.

        Key fields listed in ``Meta.exclude`` are left out.  The result is the
        remaining key fields (first one moved to the end) followed by the
        parameter fields.  ``dataframe`` is accepted for interface
        compatibility and is not inspected.
        """
        meta = getattr(self, 'Meta', None)
        # Tolerate a missing Meta, a missing ``exclude`` and ``exclude = None``
        excluded = getattr(meta, 'exclude', None) or ()
        index_fields = [
            key for key in self.get_key_fields() if key not in excluded
        ]

        # Slicing (rather than ``index_fields[0]``) keeps this from raising
        # IndexError when every key field has been excluded; list() allows
        # parameter_fields to be declared as a tuple.
        return (
            index_fields[1:] + index_fields[:1]
            + list(self.get_parameter_fields())
        )

    def get_dataframe(self, data):
        """
        Unstack the dataframe so parameter fields and most important key field
        are columns.
        """
        dataframe = super(ChartPandasSerializer, self).get_dataframe(data)
        dataframe.columns.name = ""

        # One unstack per parameter field, plus one for the key field that
        # get_index() moved to the end of the key portion of the index.
        # NOTE(review): the flattened source does not show whether the dropna
        # cleanup sat inside this loop; it is applied after each unstack
        # here -- confirm against the repository history.
        for _ in range(len(self.get_parameter_fields()) + 1):
            dataframe = dataframe.unstack()
            # Discard rows and columns that contain no values at all
            dataframe = (
                dataframe
                .dropna(axis=0, how='all')
                .dropna(axis=1, how='all')
            )
        return dataframe
pandas_boxplot_group = "series"
pandas_boxplot_date = "date"
pandas_boxplot_header = ["units", "parameter"]
173 changes: 8 additions & 165 deletions contrib/chart/views.py
Original file line number Diff line number Diff line change
@@ -1,189 +1,32 @@
from rest_pandas import PandasView
from rest_pandas import (
PandasView, PandasUnstackedSerializer, PandasScatterSerializer,
PandasBoxplotSerializer,
)
from wq.db.patterns.identify.filters import IdentifierFilterBackend
from .serializers import ChartModelSerializer, ChartPandasSerializer
from .serializers import ChartModelSerializer


class ChartView(PandasView):
serializer_class = ChartModelSerializer
pandas_serializer_class = ChartPandasSerializer
filter_backends = [IdentifierFilterBackend]


class TimeSeriesMixin(object):
"""
For use with chart.timeSeries() in wq/chart.js
"""

def transform_dataframe(self, df):
"""
The dataframe is already in a timeseries format.
"""
return df
pandas_serializer_class = PandasUnstackedSerializer


class ScatterMixin(object):
"""
For use with chart.scatter() in wq/chart.js
"""

def transform_dataframe(self, df):
"""
Transform timeseries dataframe into a format suitable for plotting two
values against each other.
"""

serializer = self.get_serializer()
value_column = serializer.value_field
series_column = serializer.key_fields[0]
parameter_column = serializer.parameter_fields[0]

# Only use primary 'value' column, ignoring any other result fields
# that may have been added to a serializer subclass.
for key in df.columns.levels[0]:
if key != value_column:
df = df.drop(key, axis=1)

# Remove all indexes/columns except for parameter (which will become
# the new 'value' field) and series (to allow distinguishing between
# scatterplot data for each series).
for name in df.columns.names:
if name not in (series_column, parameter_column):
df.columns = df.columns.droplevel(name)

# Rename columns ('value'/parameter column should be nameless)
df.columns.names = ["", series_column]

# Only include dates that have data for all parameters
df = df.dropna(axis=0, how='any')
return df
pandas_serializer_class = PandasScatterSerializer


class BoxPlotMixin(object):
    """
    For use with chart.boxplot() in wq/chart.js
    """

    # Renames applied to the keys returned by matplotlib's boxplot_stats();
    # keys not listed here are passed through unchanged.
    NAME_MAP = {
        'q1': 'p25',
        'q3': 'p75',
        'med': 'median',
        'whishi': 'max',
        'whislo': 'min',
    }

    def transform_dataframe(self, df):
        """
        Use matplotlib to compute boxplot statistics on timeseries data.

        The grouping mode comes from get_grouping(): when it contains
        "index", statistics are computed separately for each column of
        ``df``; otherwise a single set is computed over the whole dataset.
        When it contains "year" or "month", statistics are additionally
        split by that attribute of the index values.
        """
        from pandas import DataFrame
        group = self.get_grouping(len(df.columns))
        serializer = self.get_serializer()
        value_col = serializer.value_field
        series_col = serializer.key_fields[0]
        # list() so a tuple declaration can still be concatenated below
        param_cols = list(serializer.parameter_fields)
        ncols = 1 + len(param_cols)

        if "index" in group:
            # Separate stats for each column in dataset
            groups = {
                col: df[col]
                for col in df.columns
            }
        else:
            # Stats for entire dataset: fold the series/parameter column
            # levels back into the index (one stack per level, rather than a
            # hard-coded three) so only the value columns remain.
            for _ in range(ncols):
                df = df.stack()
            df.reset_index(inplace=True)
            index = serializer.get_index(df)
            df.set_index(index[0], inplace=True)
            groups = {
                # Look up the configured value column instead of assuming it
                # is literally named "value".
                (value_col,) + ('all',) * ncols: df[value_col]
            }

        by_date = "year" in group or "month" in group

        # Compute stats for each column, potentially grouped by year
        all_stats = []
        for g, series in groups.items():
            # Only the primary value column is summarized
            if g[0] != value_col:
                continue
            series_info = g[-1]
            param_info = list(reversed(g[1:-1]))
            if by_date:
                groupby = "year" if "year" in group else "month"
                dstats = self.compute_boxplots(series, groupby)
            else:
                dstats = [self.compute_boxplot(series)]
            # Tag every stats record with the series/parameters it describes
            for stats in dstats:
                stats[series_col] = series_info
                for pname, pval in zip(param_cols, param_info):
                    stats[pname] = pval
            all_stats += dstats

        df = DataFrame(all_stats)
        index = [series_col] + param_cols
        if "year" in group:
            index = ['year'] + index
        elif "month" in group:
            index = ['month'] + index

        # DataFrame.sort() was removed in pandas 0.20; prefer sort_values()
        # and fall back to sort() on releases that predate it.
        sort_rows = getattr(df, 'sort_values', None)
        if sort_rows is None:
            sort_rows = df.sort
        sort_rows(index, inplace=True)

        df.set_index(index, inplace=True)
        df.columns.name = ""
        df = df.unstack().unstack()
        if by_date:
            df = df.unstack()
        return df

    def get_grouping(self, sets):
        """
        Return the grouping mode for the boxplot.

        An explicit, non-empty ``group`` query parameter wins.  Otherwise the
        mode is chosen from ``sets`` (the number of dataframe columns):
        more than 20 gives "year", more than 10 gives "index", anything else
        gives "year-index".
        """
        group = self.request.GET.get('group', None)
        if group:
            return group
        elif sets > 20:
            return "year"
        elif sets > 10:
            return "index"
        else:
            return "year-index"

    def compute_boxplots(self, series, groupby):
        """
        Compute one boxplot per distinct ``groupby`` attribute (e.g. "year"
        or "month") of the series' index values.

        Returns a list of stats dicts as produced by compute_boxplot(), each
        with an extra ``groupby`` key holding the group's value.
        """
        def groups(d):
            # For tuple index entries, the first item is used for grouping
            if isinstance(d, tuple):
                d = d[0]
            return getattr(d, groupby)

        dstats = []
        for name, g in series.groupby(groups).groups.items():
            stats = self.compute_boxplot(series[g])
            stats[groupby] = name
            dstats.append(stats)
        return dstats

    def compute_boxplot(self, series):
        """
        Compute boxplot for given pandas Series.

        Null values are ignored.  Returns an empty dict when no values
        remain; otherwise the boxplot_stats() result with keys renamed via
        NAME_MAP, plus ``count`` and with ``fliers`` joined into a single
        "|"-separated string.
        """
        from matplotlib.cbook import boxplot_stats
        series = series[series.notnull()]
        if len(series.values) == 0:
            return {}
        stats = boxplot_stats(list(series.values))[0]
        stats = {
            self.NAME_MAP.get(key, key): value
            for key, value in stats.items()
        }
        stats['count'] = len(series.values)
        stats['fliers'] = "|".join(map(str, stats['fliers']))
        return stats


class TimeSeriesView(ChartView, TimeSeriesMixin):
pass


class ScatterView(ChartView, ScatterMixin):
pass


class BoxPlotView(ChartView, BoxPlotMixin):
pass
pandas_serializer_class = PandasBoxplotSerializer
12 changes: 7 additions & 5 deletions tests/chart_app/views.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from rest_framework import serializers
from wq.db.contrib.chart import views as chart
from wq.db.contrib.chart.serializers import ChartModelSerializer
from .models import Value


class ValueSerializer(ChartModelSerializer):
key_lookups = ['series.primary_identifier.slug', 'date']
series = serializers.ReadOnlyField(source="series.primary_identifier.slug")

class Meta:
class Meta(ChartModelSerializer.Meta):
model = Value
fields = ['series', 'date', 'parameter', 'units', 'value']


class ChartView(chart.ChartView):
Expand All @@ -21,13 +23,13 @@ def filter_by_extra(self, qs, *extra):
return qs.filter(parameter__in=extra[0])


class TimeSeriesView(ChartView, chart.TimeSeriesMixin):
class TimeSeriesView(chart.TimeSeriesMixin, ChartView):
pass


class ScatterView(ChartView, chart.ScatterMixin):
class ScatterView(chart.ScatterMixin, ChartView):
pass


class BoxPlotView(ChartView, chart.BoxPlotMixin):
class BoxPlotView(chart.BoxPlotMixin, ChartView):
pass
10 changes: 5 additions & 5 deletions tests/test_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def test_chart_scatter(self):

d4 = dataset['data'][4]
self.assertEqual(d4['date'], '2014-01-05')
self.assertEqual(d4['temp'], 0.2)
self.assertEqual(d4['snow'], 0.0)
self.assertEqual(d4['temp-value'], 0.2)
self.assertEqual(d4['snow-value'], 0.0)

@unittest.skipUnless(boxplot_stats, "test requires matplotlib 1.4+")
def test_chart_boxplot(self):
Expand All @@ -90,9 +90,9 @@ def test_chart_boxplot(self):

stats = dataset['data'][0]
self.assertEqual(stats['year'], '2014')
self.assertEqual(stats['min'], 0.1)
self.assertEqual(stats['mean'], 0.36)
self.assertEqual(stats['max'], 0.6)
self.assertEqual(stats['value-whislo'], 0.1)
self.assertEqual(stats['value-mean'], 0.36)
self.assertEqual(stats['value-whishi'], 0.6)

def parse_csv(self, response):
    """Decode the response body as UTF-8 and hand it to parse_csv()."""
    text = response.content.decode('utf-8')
    return parse_csv(text)

0 comments on commit b4582d7

Please sign in to comment.