From 90d31043e7caec32bd86de430c26ccf2adc18696 Mon Sep 17 00:00:00 2001 From: rofinn Date: Thu, 26 May 2022 17:25:26 -0700 Subject: [PATCH] Add support for table types. --- Project.toml | 1 + src/MLUtils.jl | 1 + src/observation.jl | 8 ++++---- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 1d151e5..aafaea4 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] ChainRulesCore = "1.0" diff --git a/src/MLUtils.jl b/src/MLUtils.jl index ff6684c..9e4332b 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -13,6 +13,7 @@ import ChainRulesCore: rrule using ChainRulesCore: @non_differentiable, unthunk, AbstractZero, NoTangent, ZeroTangent, ProjectTo +using Tables: istable, rows include("observation.jl") export numobs, diff --git a/src/observation.jl b/src/observation.jl index f831a56..a2ad7e9 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -16,7 +16,7 @@ See also [`getobs`](@ref) function numobs end # Generic Fallbacks -numobs(data) = length(data) +numobs(data) = istable(data) ? length(rows(data)) : length(data) """ getobs(data, [idx]) @@ -40,13 +40,13 @@ Every author behind some custom data container can make this decision themselves. The output should be consistent when `idx` is a scalar vs vector. -See also [`getobs!`](@ref) and [`numobs`](@ref) +See also [`getobs!`](@ref) and [`numobs`](@ref) """ function getobs end # Generic Fallbacks getobs(data) = data -getobs(data, idx) = data[idx] +getobs(data, idx) = istable(data) ? collect(rows(data))[idx] : data[idx] """ getobs!(buffer, data, idx) @@ -82,7 +82,7 @@ Base.lastindex(x::AbstractDataContainer) = numobs(x) # -------------------------------------------------------------------- # Arrays # We are very opinionated with arrays: the observation dimension -# is th last dimension. For different behavior wrap the array in +# is th last dimension. For different behavior wrap the array in # a custom type, e.g. with Tables.table.