{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
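        "\n# 2-Dimensional Interpolation of Tidy Data\n\nBuild 2-D interpolating splines from tidy (long-format) data, where each row holds the sample coordinates `x`, `y` and the value `z`, and compare them with `scipy.interpolate.RectBivariateSpline` on several test functions.\n"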
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport matplotlib.pyplot as plt\nfrom scipy import interpolate\nfrom scipy.stats import norm\nfrom mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib\n\nimport pandas as pd\n\nimport itertools\n\nimport ndsplines\n\n\ndef gaussian(x_in):\n    \"\"\"Gaussian bump on [0, 1] covering the central 99% of a standard normal.\"\"\"\n    z = norm.ppf(.995)\n    x = z*(2*x_in-1)\n    return norm.pdf(x)\n\n\ndef sin(x_in):\n    \"\"\"Half period of a sine on [0, 1], rising from -1 to 1.\"\"\"\n    x = np.pi*(x_in-0.5)\n    return np.sin(x)\n\n\ndef tanh(x_in):\n    \"\"\"Smooth tanh step on [0, 1], centred at 0.5.\"\"\"\n    x = 2*np.pi*(x_in-0.5)\n    return np.tanh(x)\n\n\ndef dist(x_in, y_in):\n    \"\"\"Euclidean distance from (0.25, 0.25): a cone with a kink at its apex.\"\"\"\n    return np.sqrt((x_in-0.25)**2 + (y_in-0.25)**2)\n\n\ndef wrap2d(funcx, funcy):\n    \"\"\"Return the separable 2-D function f(x, y) = funcx(x) * funcy(y).\"\"\"\n    def func2d(x_in, y_in):\n        return funcx(x_in)*funcy(y_in)\n    func2d.__name__ = '_'.join([funcx.__name__, funcy.__name__])\n    return func2d\n\n\n# Every unordered pair of 1-D test functions, plus one non-separable function.\nfuncs_1d = [gaussian, sin, tanh]\nfuncs = [wrap2d(*pair) for pair in itertools.combinations_with_replacement(funcs_1d, r=2)]\nfuncs.append(dist)\n\n# Coarse 7x7 sample grid the splines are built from.\nx = np.linspace(0, 1, 7)\ny = np.linspace(0, 1, 7)\n\n# Fine evaluation grid; it extends past [0, 1] to show extrapolation.\nxx = np.linspace(-.25, 1.25, 64)\nyy = np.linspace(-.25, 1.25, 64)\n\nmeshx, meshy = np.meshgrid(x, y, indexing='ij')\ngridxy = np.stack((meshx, meshy), axis=-1)  # shape (7, 7, 2)\ntidyxy = gridxy.reshape((-1, 2))  # one (x, y) row per sample\n\nmeshxx, meshyy = np.meshgrid(xx, yy, indexing='ij')\ngridxxyy = np.stack((meshxx, meshyy), axis=-1)\n\nfor func in funcs:\n    fvals = func(meshx, meshy)\n    truef = func(meshxx, meshyy)\n\n    # Tidy layout: columns z, x, y with one row per sample point.\n    tidy_array = np.concatenate((fvals.reshape((-1, 1)), tidyxy), axis=1)\n    tidy_df = pd.DataFrame(tidy_array, columns=['z', 'x', 'y'])\n\n    grid_spline = ndsplines.make_interp_spline(gridxy, fvals[:, :, None])\n    df_spline = ndsplines.make_interp_spline_from_tidy(tidy_df, ['x', 'y'], ['z'])\n    array_spline = ndsplines.make_interp_spline_from_tidy(tidy_array, [1, 2], [0])\n    rect_spline = interpolate.RectBivariateSpline(x, y, fvals)\n\n    df_vals = df_spline(gridxxyy)\n    # Building from a DataFrame, a plain array, or gridded data must agree.\n    assert np.allclose(array_spline(gridxxyy), df_vals), func.__name__\n    assert np.allclose(grid_spline(gridxxyy), df_vals), func.__name__\n\n    fig = plt.figure()\n    ax = fig.add_subplot(111, projection='3d')\n    ax.plot_wireframe(meshxx, meshyy, truef, alpha=0.25, color='C0', label='true')\n    ax.plot_wireframe(meshxx, meshyy, df_vals[..., 0], color='C1', label='ndsplines (tidy)')\n    ax.plot_wireframe(meshxx, meshyy, rect_spline(meshxx, meshyy, grid=False), color='C2', label='RectBivariateSpline')\n    ax.set(title=func.__name__, xlabel='x', ylabel='y', zlabel='z')\n    ax.legend()\n    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}