{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
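        "\n# Tutorial for basic usage of ndsplines\n\nBuild an interpolating spline from samples on a 2-D grid, evaluate it on a denser grid, construct the same spline from tidy data, and evaluate derivatives and antiderivatives.\n"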
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import ndsplines\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Sample a function on a grid\n",
        "\n",
        "The spline is built from values on a coarse 2-D grid and evaluated on a denser grid to show the interpolant between the data points."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Coarse grid of independent variables on which the function is sampled.\n",
        "x = np.array([-1, -7/8, -3/4, -1/2, -1/4, -1/8, 0, 1/8, 1/4, 1/2, 3/4, 7/8, 1])*np.pi\n",
        "y = np.array([-1, -1/2, 0, 1/2, 1])\n",
        "meshx, meshy = np.meshgrid(x, y, indexing='ij')\n",
        "gridxy = np.stack((meshx, meshy), axis=-1)\n",
        "\n",
        "# Dense grid to evaluate the spline on. Each coarse interval [x[i], x[i+1]]\n",
        "# is sampled with `sparse_dense` points, so x[i] sits exactly at\n",
        "# xx[i*sparse_dense] (likewise for y); plots() relies on this to locate the\n",
        "# coarse-grid points in the dense data.\n",
        "sparse_dense = 2**7\n",
        "xx = np.concatenate([np.linspace(x[i], x[i+1], sparse_dense) for i in range(x.size-1)])\n",
        "yy = np.concatenate([np.linspace(y[i], y[i+1], sparse_dense) for i in range(y.size-1)])\n",
        "gridxxyy = np.stack(np.meshgrid(xx, yy, indexing='ij'), axis=-1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def plots(sparse_data, dense_data, ylabel='f(x,y)'):\n",
        "    \"\"\"Plot coarse-grid samples together with a spline evaluated on the dense grid.\n",
        "\n",
        "    sparse_data is indexed [x index, y index] on the (x, y) grid and dense_data\n",
        "    on the (xx, yy) grid. The left panel shows both versus x for every y; the\n",
        "    right panel shows both versus y for half of the x values, starting at x[3].\n",
        "    \"\"\"\n",
        "    fig, axes = plt.subplots(1, 2, constrained_layout=True)\n",
        "    for yidx in range(sparse_data.shape[1]):\n",
        "        # y[yidx] is yy[yidx*sparse_dense]; for the last y that index is one past\n",
        "        # the end of yy, and the clip maps it to yy[-1] == y[-1].\n",
        "        dense_yidx = np.clip(yidx*sparse_dense, 0, yy.size-1)\n",
        "        axes[0].plot(x, sparse_data[:, yidx], 'o', color='C%d'%yidx, label='y=%.1f'%y[yidx])\n",
        "        axes[0].plot(xx, dense_data[:, dense_yidx], color='C%d'%yidx)\n",
        "    axes[0].legend()\n",
        "    axes[0].set_xlabel('x')\n",
        "    axes[0].set_ylabel(ylabel)\n",
        "\n",
        "    for xidx in range(sparse_data.shape[0]//2):\n",
        "        axes[1].plot(yy, dense_data[(xidx+3)*sparse_dense, :], '--', color='C%d'%xidx)\n",
        "        axes[1].plot(y, sparse_data[xidx+3, :], 'o', color='C%d'%xidx, label='x=%.1f'%x[xidx+3])\n",
        "    axes[1].legend()\n",
        "    axes[1].set_xlabel('y')\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "def knots_match(spline_a, spline_b):\n",
        "    \"\"\"Return True if each knot vector of spline_a equals the corresponding one of spline_b.\"\"\"\n",
        "    return np.all([np.all(knot0 == knot1) for knot0, knot1 in zip(spline_a.knots, spline_b.knots)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Interpolate\n",
        "\n",
        "Fit an interpolating spline to $f(x, y) = \\sin(x)\\,(y - 3/8)^2 + 2$ on the coarse grid and evaluate it on the dense grid."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Function values on the coarse grid.\n",
        "meshf = np.sin(meshx) * (meshy-3/8)**2 + 2\n",
        "\n",
        "# Create the interpolating spline.\n",
        "interp = ndsplines.make_interp_spline(gridxy, meshf)\n",
        "\n",
        "# Evaluate the spline over the dense grid.\n",
        "meshff = interp(gridxxyy)\n",
        "\n",
        "plots(meshf, meshff)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Evaluating along one direction at a time\n",
        "\n",
        "The spline can also be evaluated on grids that are dense in only one direction: dense in $x$ with the original $y$ values (left), and the original $x$ values with dense $y$ (right)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, axes = plt.subplots(1, 2, constrained_layout=True)\n",
        "\n",
        "# Dense in x, coarse in y.\n",
        "gridxxy = np.stack(np.meshgrid(xx, y, indexing='ij'), axis=-1)\n",
        "meshf_xdense = interp(gridxxy)\n",
        "for yidx in range(meshf.shape[1]):\n",
        "    axes[0].plot(x, meshf[:, yidx], 'o', color='C%d'%yidx, label='y=%.1f'%y[yidx])\n",
        "    axes[0].plot(xx, meshf_xdense[:, yidx], color='C%d'%yidx)\n",
        "axes[0].legend()\n",
        "axes[0].set_xlabel('$x$')\n",
        "axes[0].set_ylabel('$f(x,y)$')\n",
        "\n",
        "# Coarse in x, dense in y.\n",
        "gridxyy = np.stack(np.meshgrid(x, yy, indexing='ij'), axis=-1)\n",
        "meshf_ydense = interp(gridxyy)\n",
        "for xidx in range(meshf.shape[0]//2):\n",
        "    axes[1].plot(yy, meshf_ydense[xidx+3, :], '--', color='C%d'%xidx, label='x=%.1f'%x[xidx+3])\n",
        "    axes[1].plot(y, meshf[xidx+3, :], 'o', color='C%d'%xidx)\n",
        "axes[1].legend()\n",
        "axes[1].set_xlabel('$y$')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Building the spline from tidy data\n",
        "\n",
        "The same spline can be built from data in tidy format, with one row per sample and one column per variable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# One row per grid point; the columns are x, y and f(x, y).\n",
        "tidy_data = np.dstack((gridxy, meshf)).reshape((-1,3))\n",
        "print(tidy_data[:5])  # first few rows only\n",
        "\n",
        "# Columns [0, 1] (x, y) are passed as the inputs and column [2] (f) as the output.\n",
        "tidy_interp = ndsplines.make_interp_spline_from_tidy(tidy_data, [0,1], [2])\n",
        "\n",
        "print(\"\\nCoefficients all same?\", np.all(tidy_interp.coefficients == interp.coefficients))\n",
        "print(\"Knots all same?\", knots_match(tidy_interp, interp))\n",
        "\n",
        "# TODO: refer readers to the least-squares example from here."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Derivatives\n",
        "\n",
        "A derivative can be evaluated in two ways: build a derivative spline (e.g. `interp.derivative(1)` for $\\partial/\\partial y$), or call the original spline with `nus` (e.g. `nus=np.array([0, 1])` for the same derivative)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Partial derivative with respect to y (dimension 1).\n",
        "deriv_y_interp = interp.derivative(1)\n",
        "deriv_y_sparse = deriv_y_interp(gridxy)\n",
        "deriv_y_dense = interp(gridxxyy, nus=np.array([0,1]))\n",
        "\n",
        "plots(deriv_y_sparse, deriv_y_dense, r'$\\frac{\\partial f(x,y)}{\\partial y}$')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Partial derivative with respect to x (dimension 0).\n",
        "deriv_x_interp = interp.derivative(0)\n",
        "deriv_x_sparse = deriv_x_interp(gridxy)\n",
        "deriv_x_dense = interp(gridxxyy, nus=np.array([1,0]))\n",
        "\n",
        "plots(deriv_x_sparse, deriv_x_dense, r'$\\frac{\\partial f(x,y)}{\\partial x}$')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Calculus demonstration\n",
        "\n",
        "Integrating the $x$-derivative should recover the spline up to a constant offset in its coefficients, and differentiating the $x$-antiderivative should recover the original coefficients; the knots should match in both cases."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Antiderivative of the x-derivative. The coefficient offset is computed and\n",
        "# checked for constancy rather than assumed (for this f it is expected to be\n",
        "# about 2.0).\n",
        "interp1 = deriv_x_interp.antiderivative(0)\n",
        "coeff_offset = interp.coefficients - interp1.coefficients\n",
        "print(\"\\nAntiderivative of derivative:\\n\",\"Coefficients differ by constant?\", np.allclose(coeff_offset, coeff_offset.flat[0]))\n",
        "print(\"Constant offset:\", coeff_offset.flat[0])\n",
        "print(\"Knots all same?\", knots_match(interp1, interp))\n",
        "\n",
        "# Derivative of the x-antiderivative should reproduce the original spline.\n",
        "antideriv_interp = interp.antiderivative(0)\n",
        "interp2 = antideriv_interp.derivative(0)\n",
        "print(\"\\nDerivative of antiderivative:\\n\",\"Coefficients the same?\", np.allclose(interp2.coefficients, interp.coefficients))\n",
        "print(\"Knots all same?\", knots_match(interp2, interp))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}