diff --git a/.github/workflows/flake8.yml b/.github/workflows/flake8.yml index fdf2697..7cd301a 100644 --- a/.github/workflows/flake8.yml +++ b/.github/workflows/flake8.yml @@ -16,12 +16,12 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.6] + python-version: ['3.10'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index 57ccd14..831dfa7 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -20,7 +20,7 @@ jobs: steps: - uses: actions/checkout@v2 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: '3.x' - name: Install dependencies diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml deleted file mode 100644 index c2ffc18..0000000 --- a/.github/workflows/pytest.yml +++ /dev/null @@ -1,38 +0,0 @@ -# This workflow will install Python dependencies, run tests with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: pytest - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - build: - - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.6, 3.9] - torch: [1.7.1, 1.10.0] - pytorch-lightning: [1.1.8, 1.5.2] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip setuptools wheel - python -m pip install --upgrade pytest - python -m pip install torch==${{ matrix.torch }} - python -m pip install pytorch-lightning==${{ matrix.pytorch-lightning }} - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Test with pytest - run: | - python -m pytest diff --git a/.github/workflows/pytest_pip.yml b/.github/workflows/pytest_pip.yml index 3b225cc..852cabd 100644 --- a/.github/workflows/pytest_pip.yml +++ b/.github/workflows/pytest_pip.yml @@ -18,19 +18,19 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.6] + python-version: ['3.8'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install build==0.4.0 - python -m pip install --upgrade setuptools wheel + python -m pip install --upgrade setuptools==59.5.0 wheel python -m pip install --upgrade pytest - name: Install package and remove local dir run: | diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml new file mode 100644 index 0000000..79bd512 --- /dev/null +++ b/.github/workflows/python.yml @@ -0,0 +1,41 @@ +# This workflow will install Python dependencies, run tests with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: python + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.8', '3.9', '3.10'] + + steps: + - uses: actions/checkout@v3 + - name: Add conda to system path + run: | + # $CONDA is an environment variable pointing to the root of the miniconda directory + echo $CONDA/bin >> $GITHUB_PATH + - name: Replace python + uses: jacobtomlinson/gha-find-replace@v2 + with: + find: "python>=3.8,<=3.10" + replace: "python==${{ matrix.python-version }}" + regex: false + include: "environment.yml" + - name: Install dependencies + run: | + conda env update --file environment.yml --name base + python -m pip install --upgrade pip setuptools==59.5.0 wheel + python -m pip install --upgrade pytest + - name: Test with pytest + run: | + python -m pytest diff --git a/.github/workflows/pytorch-lightning.yml b/.github/workflows/pytorch-lightning.yml new file mode 100644 index 0000000..df4fbef --- /dev/null +++ b/.github/workflows/pytorch-lightning.yml @@ -0,0 +1,41 @@ +# This workflow will install Python dependencies, run tests with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: pytorch-lightning + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + pytorch-lightning: [1.1.8, 1.2.10, 1.5.10, 1.6.5] + + steps: + - uses: actions/checkout@v3 + - name: Add conda to system path + run: | + # $CONDA is an environment variable pointing to the root of the miniconda directory + echo $CONDA/bin >> $GITHUB_PATH + - name: Replace pytorch-lightning + uses: jacobtomlinson/gha-find-replace@v2 + with: + find: "pytorch-lightning>=1.1.0,<=1.6" + replace: "pytorch-lightning==${{ matrix.pytorch-lightning }}" + regex: false + include: "environment.yml" + - name: Install dependencies + run: | + conda env update --file environment.yml --name base + python -m pip install --upgrade pip setuptools==59.5.0 wheel + python -m pip install --upgrade pytest + - name: Test with pytest + run: | + python -m pytest diff --git a/.github/workflows/pytorch.yml b/.github/workflows/pytorch.yml new file mode 100644 index 0000000..97ee16f --- /dev/null +++ b/.github/workflows/pytorch.yml @@ -0,0 +1,41 @@ +# This workflow will install Python dependencies, run tests with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: pytorch + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + pytorch: [1.8.1, 1.9.1, 1.10.1, 1.11.0, 1.12.0] + + steps: + - uses: actions/checkout@v3 + - name: Add conda to system path + run: | + # $CONDA is an environment variable pointing to the root of the miniconda directory + echo $CONDA/bin >> $GITHUB_PATH + - name: Replace pytorch + uses: jacobtomlinson/gha-find-replace@v2 + with: + find: "pytorch>=1.8.1,<=1.12" + replace: "pytorch==${{ matrix.pytorch }}" + regex: false + include: "environment.yml" + - name: Install dependencies + run: | + conda env update --file environment.yml --name base + python -m pip install --upgrade pip setuptools==59.5.0 wheel + python -m pip install --upgrade pytest + - name: Test with pytest + run: | + python -m pytest diff --git a/README.md b/README.md index 598fa96..c8b3b74 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ # PyTorch Lightning Optical Flow ![GitHub CI flake8 status](https://github.com/hmorimitsu/ptlflow/actions/workflows/flake8.yml/badge.svg) -![GitHub CI pytest status](https://github.com/hmorimitsu/ptlflow/actions/workflows/pytest.yml/badge.svg) +![GitHub CI python status](https://github.com/hmorimitsu/ptlflow/actions/workflows/python.yml/badge.svg) +![GitHub CI pytorch status](https://github.com/hmorimitsu/ptlflow/actions/workflows/pytorch.yml/badge.svg) +![GitHub CI pytorch-lightning status](https://github.com/hmorimitsu/ptlflow/actions/workflows/pytorch-lightning.yml/badge.svg) ![GitHub CI pytest pip status](https://github.com/hmorimitsu/ptlflow/actions/workflows/pytest_pip.yml/badge.svg) [![DOI](https://zenodo.org/badge/375416785.svg)](https://zenodo.org/badge/latestdoi/375416785) @@ -13,6 +15,7 @@ The work and code from many others are present here. I tried to make sure everyt This is still under development, so some things may not work as intended. I plan to add more models in the future, as well keep improving the platform. +- [What's new](#whats-new) - [Available models](#available-models) - [Results](#results) - [Getting started](#getting-started) @@ -21,13 +24,31 @@ This is still under development, so some things may not work as intended. I plan - [Citing](#citing) - [Acknowledgements](#acknowledgements) +## What's new + +### July 30, 2022 - v0.2.6 + +- Added new models: + - CRAFT [https://arxiv.org/abs/2203.16896](https://arxiv.org/abs/2203.16896) + - CSFlow [https://arxiv.org/abs/2202.00909](https://arxiv.org/abs/2202.00909) + - FlowFormer [https://arxiv.org/abs/2203.16194](https://arxiv.org/abs/2203.16194) + - GMFlow [https://arxiv.org/abs/2111.13680](https://arxiv.org/abs/2111.13680) + - GMFlowNet [https://arxiv.org/abs/2203.11335](https://arxiv.org/abs/2203.11335) +- Added support for AutoFlow dataset [https://arxiv.org/abs/2104.14544](https://arxiv.org/abs/2104.14544) +- Fix Compatibility with Pytorch Lightning 1.6 + ## Available models +- CRAFT [https://arxiv.org/abs/2203.16896](https://arxiv.org/abs/2203.16896) +- CSFlow [https://arxiv.org/abs/2202.00909](https://arxiv.org/abs/2202.00909) - DICL-Flow [https://arxiv.org/abs/2010.14851](https://arxiv.org/abs/2010.14851) - FastFlowNet [https://arxiv.org/abs/2103.04524](https://arxiv.org/abs/2103.04524) +- FlowFormer [https://arxiv.org/abs/2203.16194](https://arxiv.org/abs/2203.16194) - FlowNet - [https://arxiv.org/abs/1504.06852](https://arxiv.org/abs/1504.06852) - FlowNet2 - [https://arxiv.org/abs/1612.01925](https://arxiv.org/abs/1612.01925) - GMA - [https://arxiv.org/abs/2104.02409](https://arxiv.org/abs/2104.02409) +- GMFlow [https://arxiv.org/abs/2111.13680](https://arxiv.org/abs/2111.13680) +- GMFlowNet [https://arxiv.org/abs/2203.11335](https://arxiv.org/abs/2203.11335) - HD3 - [https://arxiv.org/abs/1812.06264](https://arxiv.org/abs/1812.06264) - IRR - [https://arxiv.org/abs/1904.05290](https://arxiv.org/abs/1904.05290) - LCV - [https://arxiv.org/abs/2007.11431](https://arxiv.org/abs/2007.11431) diff --git a/docs/source/_static/kitti_2012_epe_outlier-drop_kitti_sintel.html b/docs/source/_static/kitti_2012_epe_outlier-drop_kitti_sintel.html index ed3f8a7..984fa9d 100644 --- a/docs/source/_static/kitti_2012_epe_outlier-drop_kitti_sintel.html +++ b/docs/source/_static/kitti_2012_epe_outlier-drop_kitti_sintel.html @@ -3,30 +3,30 @@
+"use strict";var n,i="";e.exports=function(t,e){if("string"!=typeof t)throw new TypeError("expected a string");if(1===e)return t;if(2===e)return t+t;var r=t.length*e;if(n!==t||void 0===n)n=t,i="";else if(i.length>=r)return i.substr(0,r);for(;r>i.length&&e>1;)1&e&&(i+=t),e>>=1,t+=t;return i=(i+=t).substr(0,r)}},{}],278:[function(t,e,r){(function(t){(function(){e.exports=t.performance&&t.performance.now?function(){return performance.now()}:Date.now||function(){return+new Date}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}],279:[function(t,e,r){"use strict";e.exports=function(t){for(var e=t.length,r=t[t.length-1],n=e,i=e-2;i>=0;--i){var a=r,o=t[i];(l=o-((r=a+o)-a))&&(t[--n]=r,r=l)}var s=0;for(i=n;i0){if(a<=0)return o;n=i+a}else{if(!(i<0))return o;if(a>=0)return o;n=-(i+a)}var s=33306690738754716e-32*n;return o>=s||o<=-s?o:f(t,e,r)},function(t,e,r,n){var i=t[0]-n[0],a=e[0]-n[0],o=r[0]-n[0],s=t[1]-n[1],l=e[1]-n[1],c=r[1]-n[1],u=t[2]-n[2],f=e[2]-n[2],p=r[2]-n[2],d=a*c,m=o*l,g=o*s,v=i*c,y=i*l,x=a*s,b=u*(d-m)+f*(g-v)+p*(y-x),_=7771561172376103e-31*((Math.abs(d)+Math.abs(m))*Math.abs(u)+(Math.abs(g)+Math.abs(v))*Math.abs(f)+(Math.abs(y)+Math.abs(x))*Math.abs(p));return b>_||-b>_?b:h(t,e,r,n)}];function d(t){var e=p[t.length];return e||(e=p[t.length]=u(t.length)),e.apply(void 0,t)}function m(t,e,r,n,i,a,o){return function(e,r,s,l,c){switch(arguments.length){case 0:case 1:return 0;case 2:return n(e,r);case 3:return i(e,r,s);case 4:return a(e,r,s,l);case 5:return o(e,r,s,l,c)}for(var u=new Array(arguments.length),f=0;f0&&o>0||a<0&&o<0)return!1;var s=n(r,t,e),l=n(i,t,e);if(s>0&&l>0||s<0&&l<0)return!1;if(0===a&&0===o&&0===s&&0===l)return function(t,e,r,n){for(var i=0;i<2;++i){var a=t[i],o=e[i],s=Math.min(a,o),l=Math.max(a,o),c=r[i],u=n[i],f=Math.min(c,u);if(Math.max(c,u)=n?(i=f,(l+=1)=n?(i=f,(l+=1)>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,s=a(t[o],e);s<=0?(0===s&&(i=o),r=o+1):s>0&&(n=o-1)}return i}function u(t,e){for(var r=new Array(t.length),i=0,o=r.length;i=t.length||0!==a(t[g],s)););}return r}function f(t,e){if(e<0)return[];for(var r=[],i=(1<>>u&1&&c.push(i[u]);e.push(c)}return s(e)},r.skeleton=f,r.boundary=function(t){for(var e=[],r=0,n=t.length;r>1:(t>>1)-1}function x(t){for(var e=v(t);;){var r=e,n=2*t+1,i=2*(t+1),a=t;if(n0;){var r=y(t);if(r>=0)if(e0){var t=k[0];return g(0,M-1),M-=1,x(0),t}return-1}function w(t,e){var r=k[t];return c[r]===e?t:(c[r]=-1/0,b(t),_(),c[r]=e,b((M+=1)-1))}function T(t){if(!u[t]){u[t]=!0;var e=s[t],r=l[t];s[r]>=0&&(s[r]=e),l[e]>=0&&(l[e]=r),A[e]>=0&&w(A[e],m(e)),A[r]>=0&&w(A[r],m(r))}}var k=[],A=new Array(a);for(f=0;f>1;f>=0;--f)x(f);for(;;){var S=_();if(S<0||c[S]>r)break;T(S)}var E=[];for(f=0;f=0&&r>=0&&e!==r){var n=A[e],i=A[r];n!==i&&C.push([n,i])}})),i.unique(i.normalize(C)),{positions:E,edges:C}};var n=t("robust-orientation"),i=t("simplicial-complex")},{"robust-orientation":284,"simplicial-complex":295}],298:[function(t,e,r){"use strict";e.exports=function(t,e){var r,a,o,s;if(e[0][0]e[1][0]))return i(e,t);r=e[1],a=e[0]}if(t[0][0]t[1][0]))return-i(t,e);o=t[1],s=t[0]}var l=n(r,a,s),c=n(r,a,o);if(l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;if(l=n(s,o,a),c=n(s,o,r),l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;return a[0]-s[0]};var n=t("robust-orientation");function i(t,e){var r,i,a,o;if(e[0][0]e[1][0])){var s=Math.min(t[0][1],t[1][1]),l=Math.max(t[0][1],t[1][1]),c=Math.min(e[0][1],e[1][1]),u=Math.max(e[0][1],e[1][1]);return lu?s-u:l-u}r=e[1],i=e[0]}t[0][1]0)if(e[0]!==o[1][0])r=t,t=t.right;else{if(l=c(t.right,e))return l;t=t.left}else{if(e[0]!==o[1][0])return t;var l;if(l=c(t.right,e))return l;t=t.left}}return r}function u(t,e,r,n){this.y=t,this.index=e,this.start=r,this.closed=n}function f(t,e,r,n){this.x=t,this.segment=e,this.create=r,this.index=n}s.prototype.castUp=function(t){var e=n.le(this.coordinates,t[0]);if(e<0)return-1;this.slabs[e];var r=c(this.slabs[e],t),i=-1;if(r&&(i=r.value),this.coordinates[e]===t[0]){var s=null;if(r&&(s=r.key),e>0){var u=c(this.slabs[e-1],t);u&&(s?o(u.key,s)>0&&(s=u.key,i=u.value):(i=u.value,s=u.key))}var f=this.horizontal[e];if(f.length>0){var h=n.ge(f,t[1],l);if(h=f.length)return i;p=f[h]}}if(p.start)if(s){var d=a(s[0],s[1],[t[0],p.y]);s[0][0]>s[1][0]&&(d=-d),d>0&&(i=p.index)}else i=p.index;else p.y!==t[1]&&(i=p.index)}}}return i}},{"./lib/order-segments":298,"binary-search-bounds":31,"functional-red-black-tree":69,"robust-orientation":284}],300:[function(t,e,r){"use strict";var n=t("robust-dot-product"),i=t("robust-sum");function a(t,e){var r=i(n(t,e),[e[e.length-1]]);return r[r.length-1]}function o(t,e,r,n){var i=-e/(n-e);i<0?i=0:i>1&&(i=1);for(var a=1-i,o=t.length,s=new Array(o),l=0;l0||i>0&&u<0){var f=o(s,u,l,i);r.push(f),n.push(f.slice())}u<0?n.push(l.slice()):u>0?r.push(l.slice()):(r.push(l.slice()),n.push(l.slice())),i=u}return{positive:r,negative:n}},e.exports.positive=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c>=0&&r.push(s.slice()),n=c}return r},e.exports.negative=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c<=0&&r.push(s.slice()),n=c}return r}},{"robust-dot-product":281,"robust-sum":289}],301:[function(t,e,r){!function(){"use strict";var t={not_string:/[^s]/,not_bool:/[^t]/,not_type:/[^T]/,not_primitive:/[^v]/,number:/[diefg]/,numeric_arg:/[bcdiefguxX]/,json:/[j]/,not_json:/[^j]/,text:/^[^\x25]+/,modulo:/^\x25{2}/,placeholder:/^\x25(?:([1-9]\d*)\$|\(([^)]+)\))?(\+)?(0|'[^$])?(-)?(\d+)?(?:\.(\d+))?([b-gijostTuvxX])/,key:/^([a-z_][a-z_\d]*)/i,key_access:/^\.([a-z_][a-z_\d]*)/i,index_access:/^\[(\d+)\]/,sign:/^[+-]/};function e(t){return i(o(t),arguments)}function n(t,r){return e.apply(null,[t].concat(r||[]))}function i(r,n){var i,a,o,s,l,c,u,f,h,p=1,d=r.length,m="";for(a=0;a=0),s.type){case"b":i=parseInt(i,10).toString(2);break;case"c":i=String.fromCharCode(parseInt(i,10));break;case"d":case"i":i=parseInt(i,10);break;case"j":i=JSON.stringify(i,null,s.width?parseInt(s.width):0);break;case"e":i=s.precision?parseFloat(i).toExponential(s.precision):parseFloat(i).toExponential();break;case"f":i=s.precision?parseFloat(i).toFixed(s.precision):parseFloat(i);break;case"g":i=s.precision?String(Number(i.toPrecision(s.precision))):parseFloat(i);break;case"o":i=(parseInt(i,10)>>>0).toString(8);break;case"s":i=String(i),i=s.precision?i.substring(0,s.precision):i;break;case"t":i=String(!!i),i=s.precision?i.substring(0,s.precision):i;break;case"T":i=Object.prototype.toString.call(i).slice(8,-1).toLowerCase(),i=s.precision?i.substring(0,s.precision):i;break;case"u":i=parseInt(i,10)>>>0;break;case"v":i=i.valueOf(),i=s.precision?i.substring(0,s.precision):i;break;case"x":i=(parseInt(i,10)>>>0).toString(16);break;case"X":i=(parseInt(i,10)>>>0).toString(16).toUpperCase()}t.json.test(s.type)?m+=i:(!t.number.test(s.type)||f&&!s.sign?h="":(h=f?"+":"-",i=i.toString().replace(t.sign,"")),c=s.pad_char?"0"===s.pad_char?"0":s.pad_char.charAt(1):" ",u=s.width-(h+i).length,l=s.width&&u>0?c.repeat(u):"",m+=s.align?h+i+l:"0"===c?h+l+i:l+h+i)}return m}var a=Object.create(null);function o(e){if(a[e])return a[e];for(var r,n=e,i=[],o=0;n;){if(null!==(r=t.text.exec(n)))i.push(r[0]);else if(null!==(r=t.modulo.exec(n)))i.push("%");else{if(null===(r=t.placeholder.exec(n)))throw new SyntaxError("[sprintf] unexpected placeholder");if(r[2]){o|=1;var s=[],l=r[2],c=[];if(null===(c=t.key.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");for(s.push(c[1]);""!==(l=l.substring(c[0].length));)if(null!==(c=t.key_access.exec(l)))s.push(c[1]);else{if(null===(c=t.index_access.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");s.push(c[1])}r[2]=s}else o|=2;if(3===o)throw new Error("[sprintf] mixing positional and named placeholders is not (yet) supported");i.push({placeholder:r[0],param_no:r[1],keys:r[2],sign:r[3],pad_char:r[4],align:r[5],width:r[6],precision:r[7],type:r[8]})}n=n.substring(r[0].length)}return a[e]=i}void 0!==r&&(r.sprintf=e,r.vsprintf=n),"undefined"!=typeof window&&(window.sprintf=e,window.vsprintf=n)}()},{}],302:[function(t,e,r){"use strict";e.exports=function(t,e){if(t.dimension<=0)return{positions:[],cells:[]};if(1===t.dimension)return function(t,e){for(var r=i(t,e),n=r.length,a=new Array(n),o=new Array(n),s=0;sn|0},vertex:function(t,e,r,n,i,a,o,s,l,c,u,f,h){var p=(o<<0)+(s<<1)+(l<<2)+(c<<3)|0;if(0!==p&&15!==p)switch(p){case 0:u.push([t-.5,e-.5]);break;case 1:u.push([t-.25-.25*(n+r-2*h)/(r-n),e-.25-.25*(i+r-2*h)/(r-i)]);break;case 2:u.push([t-.75-.25*(-n-r+2*h)/(n-r),e-.25-.25*(a+n-2*h)/(n-a)]);break;case 3:u.push([t-.5,e-.5-.5*(i+r+a+n-4*h)/(r-i+n-a)]);break;case 4:u.push([t-.25-.25*(a+i-2*h)/(i-a),e-.75-.25*(-i-r+2*h)/(i-r)]);break;case 5:u.push([t-.5-.5*(n+r+a+i-4*h)/(r-n+i-a),e-.5]);break;case 6:u.push([t-.5-.25*(-n-r+a+i)/(n-r+i-a),e-.5-.25*(-i-r+a+n)/(i-r+n-a)]);break;case 7:u.push([t-.75-.25*(a+i-2*h)/(i-a),e-.75-.25*(a+n-2*h)/(n-a)]);break;case 8:u.push([t-.75-.25*(-a-i+2*h)/(a-i),e-.75-.25*(-a-n+2*h)/(a-n)]);break;case 9:u.push([t-.5-.25*(n+r+-a-i)/(r-n+a-i),e-.5-.25*(i+r+-a-n)/(r-i+a-n)]);break;case 10:u.push([t-.5-.5*(-n-r-a-i+4*h)/(n-r+a-i),e-.5]);break;case 11:u.push([t-.25-.25*(-a-i+2*h)/(a-i),e-.75-.25*(i+r-2*h)/(r-i)]);break;case 12:u.push([t-.5,e-.5-.5*(-i-r-a-n+4*h)/(i-r+a-n)]);break;case 13:u.push([t-.75-.25*(n+r-2*h)/(r-n),e-.25-.25*(-a-n+2*h)/(a-n)]);break;case 14:u.push([t-.25-.25*(-n-r+2*h)/(n-r),e-.25-.25*(-i-r+2*h)/(i-r)]);break;case 15:u.push([t-.5,e-.5])}},cell:function(t,e,r,n,i,a,o,s,l){i?s.push([t,e]):s.push([e,t])}});return function(t,e){var r=[],i=[];return n(t,r,i,e),{positions:r,cells:i}}}};var o={}},{"ndarray-extract-contour":251,"zero-crossings":318}],303:[function(t,e,r){(function(r){(function(){"use strict";e.exports=function t(e,r,i){i=i||{};var o=a[e];o||(o=a[e]={" ":{data:new Float32Array(0),shape:.2}});var s=o[r];if(!s)if(r.length<=1||!/\d/.test(r))s=o[r]=function(t){for(var e=t.cells,r=t.positions,n=new Float32Array(6*e.length),i=0,a=0,o=0;o0&&(f+=.02);var p=new Float32Array(u),d=0,m=-.5*f;for(h=0;hMath.max(r,n)?i[2]=1:r>Math.max(e,n)?i[0]=1:i[1]=1;for(var a=0,o=0,l=0;l<3;++l)a+=t[l]*t[l],o+=i[l]*t[l];for(l=0;l<3;++l)i[l]-=o/a*t[l];return s(i,i),i}function h(t,e,r,i,a,o,s,l){this.center=n(r),this.up=n(i),this.right=n(a),this.radius=n([o]),this.angle=n([s,l]),this.angle.bounds=[[-1/0,-Math.PI/2],[1/0,Math.PI/2]],this.setDistanceLimits(t,e),this.computedCenter=this.center.curve(0),this.computedUp=this.up.curve(0),this.computedRight=this.right.curve(0),this.computedRadius=this.radius.curve(0),this.computedAngle=this.angle.curve(0),this.computedToward=[0,0,0],this.computedEye=[0,0,0],this.computedMatrix=new Array(16);for(var c=0;c<16;++c)this.computedMatrix[c]=.5;this.recalcMatrix(0)}var p=h.prototype;p.setDistanceLimits=function(t,e){t=t>0?Math.log(t):-1/0,e=e>0?Math.log(e):1/0,e=Math.max(e,t),this.radius.bounds[0][0]=t,this.radius.bounds[1][0]=e},p.getDistanceLimits=function(t){var e=this.radius.bounds[0];return t?(t[0]=Math.exp(e[0][0]),t[1]=Math.exp(e[1][0]),t):[Math.exp(e[0][0]),Math.exp(e[1][0])]},p.recalcMatrix=function(t){this.center.curve(t),this.up.curve(t),this.right.curve(t),this.radius.curve(t),this.angle.curve(t);for(var e=this.computedUp,r=this.computedRight,n=0,i=0,a=0;a<3;++a)i+=e[a]*r[a],n+=e[a]*e[a];var l=Math.sqrt(n),u=0;for(a=0;a<3;++a)r[a]-=e[a]*i/n,u+=r[a]*r[a],e[a]/=l;var f=Math.sqrt(u);for(a=0;a<3;++a)r[a]/=f;var h=this.computedToward;o(h,e,r),s(h,h);var p=Math.exp(this.computedRadius[0]),d=this.computedAngle[0],m=this.computedAngle[1],g=Math.cos(d),v=Math.sin(d),y=Math.cos(m),x=Math.sin(m),b=this.computedCenter,_=g*y,w=v*y,T=x,k=-g*x,A=-v*x,M=y,S=this.computedEye,E=this.computedMatrix;for(a=0;a<3;++a){var L=_*r[a]+w*h[a]+T*e[a];E[4*a+1]=k*r[a]+A*h[a]+M*e[a],E[4*a+2]=L,E[4*a+3]=0}var C=E[1],P=E[5],I=E[9],O=E[2],z=E[6],D=E[10],R=P*D-I*z,F=I*O-C*D,B=C*z-P*O,N=c(R,F,B);R/=N,F/=N,B/=N,E[0]=R,E[4]=F,E[8]=B;for(a=0;a<3;++a)S[a]=b[a]+E[2+4*a]*p;for(a=0;a<3;++a){u=0;for(var j=0;j<3;++j)u+=E[a+4*j]*S[j];E[12+a]=-u}E[15]=1},p.getMatrix=function(t,e){this.recalcMatrix(t);var r=this.computedMatrix;if(e){for(var n=0;n<16;++n)e[n]=r[n];return e}return r};var d=[0,0,0];p.rotate=function(t,e,r,n){if(this.angle.move(t,e,r),n){this.recalcMatrix(t);var i=this.computedMatrix;d[0]=i[2],d[1]=i[6],d[2]=i[10];for(var o=this.computedUp,s=this.computedRight,l=this.computedToward,c=0;c<3;++c)i[4*c]=o[c],i[4*c+1]=s[c],i[4*c+2]=l[c];a(i,i,n,d);for(c=0;c<3;++c)o[c]=i[4*c],s[c]=i[4*c+1];this.up.set(t,o[0],o[1],o[2]),this.right.set(t,s[0],s[1],s[2])}},p.pan=function(t,e,r,n){e=e||0,r=r||0,n=n||0,this.recalcMatrix(t);var i=this.computedMatrix,a=(Math.exp(this.computedRadius[0]),i[1]),o=i[5],s=i[9],l=c(a,o,s);a/=l,o/=l,s/=l;var u=i[0],f=i[4],h=i[8],p=u*a+f*o+h*s,d=c(u-=a*p,f-=o*p,h-=s*p),m=(u/=d)*e+a*r,g=(f/=d)*e+o*r,v=(h/=d)*e+s*r;this.center.move(t,m,g,v);var y=Math.exp(this.computedRadius[0]);y=Math.max(1e-4,y+n),this.radius.set(t,Math.log(y))},p.translate=function(t,e,r,n){this.center.move(t,e||0,r||0,n||0)},p.setMatrix=function(t,e,r,n){var a=1;"number"==typeof r&&(a=0|r),(a<0||a>3)&&(a=1);var o=(a+2)%3;e||(this.recalcMatrix(t),e=this.computedMatrix);var s=e[a],l=e[a+4],f=e[a+8];if(n){var h=Math.abs(s),p=Math.abs(l),d=Math.abs(f),m=Math.max(h,p,d);h===m?(s=s<0?-1:1,l=f=0):d===m?(f=f<0?-1:1,s=l=0):(l=l<0?-1:1,s=f=0)}else{var g=c(s,l,f);s/=g,l/=g,f/=g}var v,y,x=e[o],b=e[o+4],_=e[o+8],w=x*s+b*l+_*f,T=c(x-=s*w,b-=l*w,_-=f*w),k=l*(_/=T)-f*(b/=T),A=f*(x/=T)-s*_,M=s*b-l*x,S=c(k,A,M);if(k/=S,A/=S,M/=S,this.center.jump(t,q,G,Y),this.radius.idle(t),this.up.jump(t,s,l,f),this.right.jump(t,x,b,_),2===a){var E=e[1],L=e[5],C=e[9],P=E*x+L*b+C*_,I=E*k+L*A+C*M;v=R<0?-Math.PI/2:Math.PI/2,y=Math.atan2(I,P)}else{var O=e[2],z=e[6],D=e[10],R=O*s+z*l+D*f,F=O*x+z*b+D*_,B=O*k+z*A+D*M;v=Math.asin(u(R)),y=Math.atan2(B,F)}this.angle.jump(t,y,v),this.recalcMatrix(t);var N=e[2],j=e[6],U=e[10],V=this.computedMatrix;i(V,e);var H=V[15],q=V[12]/H,G=V[13]/H,Y=V[14]/H,W=Math.exp(this.computedRadius[0]);this.center.jump(t,q-N*W,G-j*W,Y-U*W)},p.lastT=function(){return Math.max(this.center.lastT(),this.up.lastT(),this.right.lastT(),this.radius.lastT(),this.angle.lastT())},p.idle=function(t){this.center.idle(t),this.up.idle(t),this.right.idle(t),this.radius.idle(t),this.angle.idle(t)},p.flush=function(t){this.center.flush(t),this.up.flush(t),this.right.flush(t),this.radius.flush(t),this.angle.flush(t)},p.setDistance=function(t,e){e>0&&this.radius.set(t,Math.log(e))},p.lookAt=function(t,e,r,n){this.recalcMatrix(t),e=e||this.computedEye,r=r||this.computedCenter;var i=(n=n||this.computedUp)[0],a=n[1],o=n[2],s=c(i,a,o);if(!(s<1e-6)){i/=s,a/=s,o/=s;var l=e[0]-r[0],f=e[1]-r[1],h=e[2]-r[2],p=c(l,f,h);if(!(p<1e-6)){l/=p,f/=p,h/=p;var d=this.computedRight,m=d[0],g=d[1],v=d[2],y=i*m+a*g+o*v,x=c(m-=y*i,g-=y*a,v-=y*o);if(!(x<.01&&(x=c(m=a*h-o*f,g=o*l-i*h,v=i*f-a*l))<1e-6)){m/=x,g/=x,v/=x,this.up.set(t,i,a,o),this.right.set(t,m,g,v),this.center.set(t,r[0],r[1],r[2]),this.radius.set(t,Math.log(p));var b=a*v-o*g,_=o*m-i*v,w=i*g-a*m,T=c(b,_,w),k=i*l+a*f+o*h,A=m*l+g*f+v*h,M=(b/=T)*l+(_/=T)*f+(w/=T)*h,S=Math.asin(u(k)),E=Math.atan2(M,A),L=this.angle._state,C=L[L.length-1],P=L[L.length-2];C%=2*Math.PI;var I=Math.abs(C+2*Math.PI-E),O=Math.abs(C-E),z=Math.abs(C-2*Math.PI-E);I0?r.pop():new ArrayBuffer(t)}function d(t){return new Uint8Array(p(t),0,t)}function m(t){return new Uint16Array(p(2*t),0,t)}function g(t){return new Uint32Array(p(4*t),0,t)}function v(t){return new Int8Array(p(t),0,t)}function y(t){return new Int16Array(p(2*t),0,t)}function x(t){return new Int32Array(p(4*t),0,t)}function b(t){return new Float32Array(p(4*t),0,t)}function _(t){return new Float64Array(p(8*t),0,t)}function w(t){return o?new Uint8ClampedArray(p(t),0,t):d(t)}function T(t){return s?new BigUint64Array(p(8*t),0,t):null}function k(t){return l?new BigInt64Array(p(8*t),0,t):null}function A(t){return new DataView(p(t),0,t)}function M(t){t=n.nextPow2(t);var e=n.log2(t),r=f[e];return r.length>0?r.pop():new a(t)}r.free=function(t){if(a.isBuffer(t))f[n.log2(t.length)].push(t);else{if("[object ArrayBuffer]"!==Object.prototype.toString.call(t)&&(t=t.buffer),!t)return;var e=t.length||t.byteLength,r=0|n.log2(e);u[r].push(t)}},r.freeUint8=r.freeUint16=r.freeUint32=r.freeBigUint64=r.freeInt8=r.freeInt16=r.freeInt32=r.freeBigInt64=r.freeFloat32=r.freeFloat=r.freeFloat64=r.freeDouble=r.freeUint8Clamped=r.freeDataView=function(t){h(t.buffer)},r.freeArrayBuffer=h,r.freeBuffer=function(t){f[n.log2(t.length)].push(t)},r.malloc=function(t,e){if(void 0===e||"arraybuffer"===e)return p(t);switch(e){case"uint8":return d(t);case"uint16":return m(t);case"uint32":return g(t);case"int8":return v(t);case"int16":return y(t);case"int32":return x(t);case"float":case"float32":return b(t);case"double":case"float64":return _(t);case"uint8_clamped":return w(t);case"bigint64":return k(t);case"biguint64":return T(t);case"buffer":return M(t);case"data":case"dataview":return A(t);default:return null}return null},r.mallocArrayBuffer=p,r.mallocUint8=d,r.mallocUint16=m,r.mallocUint32=g,r.mallocInt8=v,r.mallocInt16=y,r.mallocInt32=x,r.mallocFloat32=r.mallocFloat=b,r.mallocFloat64=r.mallocDouble=_,r.mallocUint8Clamped=w,r.mallocBigUint64=T,r.mallocBigInt64=k,r.mallocDataView=A,r.mallocBuffer=M,r.clearCache=function(){for(var t=0;t<32;++t)c.UINT8[t].length=0,c.UINT16[t].length=0,c.UINT32[t].length=0,c.INT8[t].length=0,c.INT16[t].length=0,c.INT32[t].length=0,c.FLOAT[t].length=0,c.DOUBLE[t].length=0,c.BIGUINT64[t].length=0,c.BIGINT64[t].length=0,c.UINT8C[t].length=0,u[t].length=0,f[t].length=0}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{"bit-twiddle":32,buffer:3,dup:65}],309:[function(t,e,r){"use strict";function n(t){this.roots=new Array(t),this.ranks=new Array(t);for(var e=0;e0&&(a=n.size),n.lineSpacing&&n.lineSpacing>0&&(o=n.lineSpacing),n.styletags&&n.styletags.breaklines&&(s.breaklines=!!n.styletags.breaklines),n.styletags&&n.styletags.bolds&&(s.bolds=!!n.styletags.bolds),n.styletags&&n.styletags.italics&&(s.italics=!!n.styletags.italics),n.styletags&&n.styletags.subscripts&&(s.subscripts=!!n.styletags.subscripts),n.styletags&&n.styletags.superscripts&&(s.superscripts=!!n.styletags.superscripts));return r.font=[n.fontStyle,n.fontVariant,n.fontWeight,a+"px",n.font].filter((function(t){return t})).join(" "),r.textAlign="start",r.textBaseline="alphabetic",r.direction="ltr",h(function(t,e,r,n,a,o){r=r.replace(/\n/g,""),r=!0===o.breaklines?r.replace(/\/g,"\n"):r.replace(/\/g," ");var s="",l=[];for(p=0;p-1?parseInt(t[1+i]):0,l=a>-1?parseInt(r[1+a]):0;s!==l&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,l-s),n=n.replace("?px ",S())),m+=.25*x*(l-s)}if(!0===o.superscripts){var c=t.indexOf("+"),u=r.indexOf("+"),f=c>-1?parseInt(t[1+c]):0,h=u>-1?parseInt(r[1+u]):0;f!==h&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,h-f),n=n.replace("?px ",S())),m-=.25*x*(h-f)}if(!0===o.bolds){var p=t.indexOf("b|")>-1,d=r.indexOf("b|")>-1;!p&&d&&(n=v?n.replace("italic ","italic bold "):"bold "+n),p&&!d&&(n=n.replace("bold ",""))}if(!0===o.italics){var v=t.indexOf("i|")>-1,y=r.indexOf("i|")>-1;!v&&y&&(n="italic "+n),v&&!y&&(n=n.replace("italic ",""))}e.font=n}for(h=0;h",a="",o=i.length,s=a.length,l="+"===e[0]||"-"===e[0],c=0,u=-s;c>-1&&-1!==(c=r.indexOf(i,c))&&-1!==(u=r.indexOf(a,c+o))&&!(u<=c);){for(var f=c;f=u)n[f]=null,r=r.substr(0,f)+" "+r.substr(f+1);else if(null!==n[f]){var h=n[f].indexOf(e[0]);-1===h?n[f]+=e:l&&(n[f]=n[f].substr(0,h+1)+(1+parseInt(n[f][h+1]))+n[f].substr(h+2))}var p=c+o,d=r.substr(p,u-p).indexOf(i);c=-1!==d?d:u+s}return n}function u(t,e){var r=n(t,128);return e?a(r.cells,r.positions,.25):{edges:r.cells,positions:r.positions}}function f(t,e,r,n){var i=u(t,n),a=function(t,e,r){for(var n=e.textAlign||"start",i=e.textBaseline||"alphabetic",a=[1<<30,1<<30],o=[0,0],s=t.length,l=0;l=0?e[a]:i}))},has___:{value:y((function(e){var n=v(e);return n?r in n:t.indexOf(e)>=0}))},set___:{value:y((function(n,i){var a,o=v(n);return o?o[r]=i:(a=t.indexOf(n))>=0?e[a]=i:(a=t.length,e[a]=i,t[a]=n),this}))},delete___:{value:y((function(n){var i,a,o=v(n);return o?r in o&&delete o[r]:!((i=t.indexOf(n))<0)&&(a=t.length-1,t[i]=void 0,e[i]=e[a],t[i]=t[a],t.length=a,e.length=a,!0)}))}})};d.prototype=Object.create(Object.prototype,{get:{value:function(t,e){return this.get___(t,e)},writable:!0,configurable:!0},has:{value:function(t){return this.has___(t)},writable:!0,configurable:!0},set:{value:function(t,e){return this.set___(t,e)},writable:!0,configurable:!0},delete:{value:function(t){return this.delete___(t)},writable:!0,configurable:!0}}),"function"==typeof r?function(){function n(){this instanceof d||x();var e,n=new r,i=void 0,a=!1;return e=t?function(t,e){return n.set(t,e),n.has(t)||(i||(i=new d),i.set(t,e)),this}:function(t,e){if(a)try{n.set(t,e)}catch(r){i||(i=new d),i.set___(t,e)}else n.set(t,e);return this},Object.create(d.prototype,{get___:{value:y((function(t,e){return i?n.has(t)?n.get(t):i.get___(t,e):n.get(t,e)}))},has___:{value:y((function(t){return n.has(t)||!!i&&i.has___(t)}))},set___:{value:y(e)},delete___:{value:y((function(t){var e=!!n.delete(t);return i&&i.delete___(t)||e}))},permitHostObjects___:{value:y((function(t){if(t!==m)throw new Error("bogus call to permitHostObjects___");a=!0}))}})}t&&"undefined"!=typeof Proxy&&(Proxy=void 0),n.prototype=d.prototype,e.exports=n,Object.defineProperty(WeakMap.prototype,"constructor",{value:WeakMap,enumerable:!1,configurable:!0,writable:!0})}():("undefined"!=typeof Proxy&&(Proxy=void 0),e.exports=d)}function m(t){t.permitHostObjects___&&t.permitHostObjects___(m)}function g(t){return!("weakmap:"==t.substr(0,"weakmap:".length)&&"___"===t.substr(t.length-3))}function v(t){if(t!==Object(t))throw new TypeError("Not an object: "+t);var e=t[l];if(e&&e.key===t)return e;if(s(t)){e={key:t};try{return o(t,l,{value:e,writable:!1,enumerable:!1,configurable:!1}),e}catch(t){return}}}function y(t){return t.prototype=null,Object.freeze(t)}function x(){h||"undefined"==typeof console||(h=!0,console.warn("WeakMap should be invoked as new WeakMap(), not WeakMap(). This will be an error in the future."))}}()},{}],314:[function(t,e,r){var n=t("./hidden-store.js");e.exports=function(){var t={};return function(e){if(("object"!=typeof e||null===e)&&"function"!=typeof e)throw new Error("Weakmap-shim: Key must be object");var r=e.valueOf(t);return r&&r.identity===t?r:n(e,t)}}},{"./hidden-store.js":315}],315:[function(t,e,r){e.exports=function(t,e){var r={identity:e},n=t.valueOf;return Object.defineProperty(t,"valueOf",{value:function(t){return t!==e?n.apply(this,arguments):r},writable:!0}),r}},{}],316:[function(t,e,r){var n=t("./create-store.js");e.exports=function(){var t=n();return{get:function(e,r){var n=t(e);return n.hasOwnProperty("value")?n.value:r},set:function(e,r){return t(e).value=r,this},has:function(e){return"value"in t(e)},delete:function(e){return delete t(e).value}}}},{"./create-store.js":314}],317:[function(t,e,r){"use strict";var n,i=function(){return function(t,e,r,n,i,a){var o=t[0],s=r[0],l=[0],c=s;n|=0;var u=0,f=s;for(u=0;u=0!=p>=0&&i.push(l[0]+.5+.5*(h+p)/(h-p)),n+=f,++l[0]}}};e.exports=(n={funcName:{funcName:"zeroCrossings"}.funcName},function(t){var e={};return function(r,n,i){var a=r.dtype,o=r.order,s=[a,o.join()].join(),l=e[s];return l||(e[s]=l=t([a,o])),l(r.shape.slice(0),r.data,r.stride,0|r.offset,n,i)}}(i.bind(void 0,n)))},{}],318:[function(t,e,r){"use strict";e.exports=function(t,e){var r=[];return e=+e||0,n(t.hi(t.shape[0]-1),r,e),r};var n=t("./lib/zc-core")},{"./lib/zc-core":317}]},{},[6])(6)}))}).call(this)}).call(this,"undefined"!=typeof global?global:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}]},{},[27])(27)}));
\ No newline at end of file diff --git a/docs/source/_static/kitti_2015_epe_outlier-drop_kitti_sintel.html b/docs/source/_static/kitti_2015_epe_outlier-drop_kitti_sintel.html index e51b98a..5985153 100644 --- a/docs/source/_static/kitti_2015_epe_outlier-drop_kitti_sintel.html +++ b/docs/source/_static/kitti_2015_epe_outlier-drop_kitti_sintel.html @@ -3,30 +3,30 @@
+"use strict";var n,i="";e.exports=function(t,e){if("string"!=typeof t)throw new TypeError("expected a string");if(1===e)return t;if(2===e)return t+t;var r=t.length*e;if(n!==t||void 0===n)n=t,i="";else if(i.length>=r)return i.substr(0,r);for(;r>i.length&&e>1;)1&e&&(i+=t),e>>=1,t+=t;return i=(i+=t).substr(0,r)}},{}],278:[function(t,e,r){(function(t){(function(){e.exports=t.performance&&t.performance.now?function(){return performance.now()}:Date.now||function(){return+new Date}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}],279:[function(t,e,r){"use strict";e.exports=function(t){for(var e=t.length,r=t[t.length-1],n=e,i=e-2;i>=0;--i){var a=r,o=t[i];(l=o-((r=a+o)-a))&&(t[--n]=r,r=l)}var s=0;for(i=n;i0){if(a<=0)return o;n=i+a}else{if(!(i<0))return o;if(a>=0)return o;n=-(i+a)}var s=33306690738754716e-32*n;return o>=s||o<=-s?o:f(t,e,r)},function(t,e,r,n){var i=t[0]-n[0],a=e[0]-n[0],o=r[0]-n[0],s=t[1]-n[1],l=e[1]-n[1],c=r[1]-n[1],u=t[2]-n[2],f=e[2]-n[2],p=r[2]-n[2],d=a*c,m=o*l,g=o*s,v=i*c,y=i*l,x=a*s,b=u*(d-m)+f*(g-v)+p*(y-x),_=7771561172376103e-31*((Math.abs(d)+Math.abs(m))*Math.abs(u)+(Math.abs(g)+Math.abs(v))*Math.abs(f)+(Math.abs(y)+Math.abs(x))*Math.abs(p));return b>_||-b>_?b:h(t,e,r,n)}];function d(t){var e=p[t.length];return e||(e=p[t.length]=u(t.length)),e.apply(void 0,t)}function m(t,e,r,n,i,a,o){return function(e,r,s,l,c){switch(arguments.length){case 0:case 1:return 0;case 2:return n(e,r);case 3:return i(e,r,s);case 4:return a(e,r,s,l);case 5:return o(e,r,s,l,c)}for(var u=new Array(arguments.length),f=0;f0&&o>0||a<0&&o<0)return!1;var s=n(r,t,e),l=n(i,t,e);if(s>0&&l>0||s<0&&l<0)return!1;if(0===a&&0===o&&0===s&&0===l)return function(t,e,r,n){for(var i=0;i<2;++i){var a=t[i],o=e[i],s=Math.min(a,o),l=Math.max(a,o),c=r[i],u=n[i],f=Math.min(c,u);if(Math.max(c,u)=n?(i=f,(l+=1)=n?(i=f,(l+=1)>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,s=a(t[o],e);s<=0?(0===s&&(i=o),r=o+1):s>0&&(n=o-1)}return i}function u(t,e){for(var r=new Array(t.length),i=0,o=r.length;i=t.length||0!==a(t[g],s)););}return r}function f(t,e){if(e<0)return[];for(var r=[],i=(1<>>u&1&&c.push(i[u]);e.push(c)}return s(e)},r.skeleton=f,r.boundary=function(t){for(var e=[],r=0,n=t.length;r>1:(t>>1)-1}function x(t){for(var e=v(t);;){var r=e,n=2*t+1,i=2*(t+1),a=t;if(n0;){var r=y(t);if(r>=0)if(e0){var t=k[0];return g(0,M-1),M-=1,x(0),t}return-1}function w(t,e){var r=k[t];return c[r]===e?t:(c[r]=-1/0,b(t),_(),c[r]=e,b((M+=1)-1))}function T(t){if(!u[t]){u[t]=!0;var e=s[t],r=l[t];s[r]>=0&&(s[r]=e),l[e]>=0&&(l[e]=r),A[e]>=0&&w(A[e],m(e)),A[r]>=0&&w(A[r],m(r))}}var k=[],A=new Array(a);for(f=0;f>1;f>=0;--f)x(f);for(;;){var S=_();if(S<0||c[S]>r)break;T(S)}var E=[];for(f=0;f=0&&r>=0&&e!==r){var n=A[e],i=A[r];n!==i&&C.push([n,i])}})),i.unique(i.normalize(C)),{positions:E,edges:C}};var n=t("robust-orientation"),i=t("simplicial-complex")},{"robust-orientation":284,"simplicial-complex":295}],298:[function(t,e,r){"use strict";e.exports=function(t,e){var r,a,o,s;if(e[0][0]e[1][0]))return i(e,t);r=e[1],a=e[0]}if(t[0][0]t[1][0]))return-i(t,e);o=t[1],s=t[0]}var l=n(r,a,s),c=n(r,a,o);if(l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;if(l=n(s,o,a),c=n(s,o,r),l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;return a[0]-s[0]};var n=t("robust-orientation");function i(t,e){var r,i,a,o;if(e[0][0]e[1][0])){var s=Math.min(t[0][1],t[1][1]),l=Math.max(t[0][1],t[1][1]),c=Math.min(e[0][1],e[1][1]),u=Math.max(e[0][1],e[1][1]);return lu?s-u:l-u}r=e[1],i=e[0]}t[0][1]0)if(e[0]!==o[1][0])r=t,t=t.right;else{if(l=c(t.right,e))return l;t=t.left}else{if(e[0]!==o[1][0])return t;var l;if(l=c(t.right,e))return l;t=t.left}}return r}function u(t,e,r,n){this.y=t,this.index=e,this.start=r,this.closed=n}function f(t,e,r,n){this.x=t,this.segment=e,this.create=r,this.index=n}s.prototype.castUp=function(t){var e=n.le(this.coordinates,t[0]);if(e<0)return-1;this.slabs[e];var r=c(this.slabs[e],t),i=-1;if(r&&(i=r.value),this.coordinates[e]===t[0]){var s=null;if(r&&(s=r.key),e>0){var u=c(this.slabs[e-1],t);u&&(s?o(u.key,s)>0&&(s=u.key,i=u.value):(i=u.value,s=u.key))}var f=this.horizontal[e];if(f.length>0){var h=n.ge(f,t[1],l);if(h=f.length)return i;p=f[h]}}if(p.start)if(s){var d=a(s[0],s[1],[t[0],p.y]);s[0][0]>s[1][0]&&(d=-d),d>0&&(i=p.index)}else i=p.index;else p.y!==t[1]&&(i=p.index)}}}return i}},{"./lib/order-segments":298,"binary-search-bounds":31,"functional-red-black-tree":69,"robust-orientation":284}],300:[function(t,e,r){"use strict";var n=t("robust-dot-product"),i=t("robust-sum");function a(t,e){var r=i(n(t,e),[e[e.length-1]]);return r[r.length-1]}function o(t,e,r,n){var i=-e/(n-e);i<0?i=0:i>1&&(i=1);for(var a=1-i,o=t.length,s=new Array(o),l=0;l0||i>0&&u<0){var f=o(s,u,l,i);r.push(f),n.push(f.slice())}u<0?n.push(l.slice()):u>0?r.push(l.slice()):(r.push(l.slice()),n.push(l.slice())),i=u}return{positive:r,negative:n}},e.exports.positive=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c>=0&&r.push(s.slice()),n=c}return r},e.exports.negative=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c<=0&&r.push(s.slice()),n=c}return r}},{"robust-dot-product":281,"robust-sum":289}],301:[function(t,e,r){!function(){"use strict";var t={not_string:/[^s]/,not_bool:/[^t]/,not_type:/[^T]/,not_primitive:/[^v]/,number:/[diefg]/,numeric_arg:/[bcdiefguxX]/,json:/[j]/,not_json:/[^j]/,text:/^[^\x25]+/,modulo:/^\x25{2}/,placeholder:/^\x25(?:([1-9]\d*)\$|\(([^)]+)\))?(\+)?(0|'[^$])?(-)?(\d+)?(?:\.(\d+))?([b-gijostTuvxX])/,key:/^([a-z_][a-z_\d]*)/i,key_access:/^\.([a-z_][a-z_\d]*)/i,index_access:/^\[(\d+)\]/,sign:/^[+-]/};function e(t){return i(o(t),arguments)}function n(t,r){return e.apply(null,[t].concat(r||[]))}function i(r,n){var i,a,o,s,l,c,u,f,h,p=1,d=r.length,m="";for(a=0;a=0),s.type){case"b":i=parseInt(i,10).toString(2);break;case"c":i=String.fromCharCode(parseInt(i,10));break;case"d":case"i":i=parseInt(i,10);break;case"j":i=JSON.stringify(i,null,s.width?parseInt(s.width):0);break;case"e":i=s.precision?parseFloat(i).toExponential(s.precision):parseFloat(i).toExponential();break;case"f":i=s.precision?parseFloat(i).toFixed(s.precision):parseFloat(i);break;case"g":i=s.precision?String(Number(i.toPrecision(s.precision))):parseFloat(i);break;case"o":i=(parseInt(i,10)>>>0).toString(8);break;case"s":i=String(i),i=s.precision?i.substring(0,s.precision):i;break;case"t":i=String(!!i),i=s.precision?i.substring(0,s.precision):i;break;case"T":i=Object.prototype.toString.call(i).slice(8,-1).toLowerCase(),i=s.precision?i.substring(0,s.precision):i;break;case"u":i=parseInt(i,10)>>>0;break;case"v":i=i.valueOf(),i=s.precision?i.substring(0,s.precision):i;break;case"x":i=(parseInt(i,10)>>>0).toString(16);break;case"X":i=(parseInt(i,10)>>>0).toString(16).toUpperCase()}t.json.test(s.type)?m+=i:(!t.number.test(s.type)||f&&!s.sign?h="":(h=f?"+":"-",i=i.toString().replace(t.sign,"")),c=s.pad_char?"0"===s.pad_char?"0":s.pad_char.charAt(1):" ",u=s.width-(h+i).length,l=s.width&&u>0?c.repeat(u):"",m+=s.align?h+i+l:"0"===c?h+l+i:l+h+i)}return m}var a=Object.create(null);function o(e){if(a[e])return a[e];for(var r,n=e,i=[],o=0;n;){if(null!==(r=t.text.exec(n)))i.push(r[0]);else if(null!==(r=t.modulo.exec(n)))i.push("%");else{if(null===(r=t.placeholder.exec(n)))throw new SyntaxError("[sprintf] unexpected placeholder");if(r[2]){o|=1;var s=[],l=r[2],c=[];if(null===(c=t.key.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");for(s.push(c[1]);""!==(l=l.substring(c[0].length));)if(null!==(c=t.key_access.exec(l)))s.push(c[1]);else{if(null===(c=t.index_access.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");s.push(c[1])}r[2]=s}else o|=2;if(3===o)throw new Error("[sprintf] mixing positional and named placeholders is not (yet) supported");i.push({placeholder:r[0],param_no:r[1],keys:r[2],sign:r[3],pad_char:r[4],align:r[5],width:r[6],precision:r[7],type:r[8]})}n=n.substring(r[0].length)}return a[e]=i}void 0!==r&&(r.sprintf=e,r.vsprintf=n),"undefined"!=typeof window&&(window.sprintf=e,window.vsprintf=n)}()},{}],302:[function(t,e,r){"use strict";e.exports=function(t,e){if(t.dimension<=0)return{positions:[],cells:[]};if(1===t.dimension)return function(t,e){for(var r=i(t,e),n=r.length,a=new Array(n),o=new Array(n),s=0;sn|0},vertex:function(t,e,r,n,i,a,o,s,l,c,u,f,h){var p=(o<<0)+(s<<1)+(l<<2)+(c<<3)|0;if(0!==p&&15!==p)switch(p){case 0:u.push([t-.5,e-.5]);break;case 1:u.push([t-.25-.25*(n+r-2*h)/(r-n),e-.25-.25*(i+r-2*h)/(r-i)]);break;case 2:u.push([t-.75-.25*(-n-r+2*h)/(n-r),e-.25-.25*(a+n-2*h)/(n-a)]);break;case 3:u.push([t-.5,e-.5-.5*(i+r+a+n-4*h)/(r-i+n-a)]);break;case 4:u.push([t-.25-.25*(a+i-2*h)/(i-a),e-.75-.25*(-i-r+2*h)/(i-r)]);break;case 5:u.push([t-.5-.5*(n+r+a+i-4*h)/(r-n+i-a),e-.5]);break;case 6:u.push([t-.5-.25*(-n-r+a+i)/(n-r+i-a),e-.5-.25*(-i-r+a+n)/(i-r+n-a)]);break;case 7:u.push([t-.75-.25*(a+i-2*h)/(i-a),e-.75-.25*(a+n-2*h)/(n-a)]);break;case 8:u.push([t-.75-.25*(-a-i+2*h)/(a-i),e-.75-.25*(-a-n+2*h)/(a-n)]);break;case 9:u.push([t-.5-.25*(n+r+-a-i)/(r-n+a-i),e-.5-.25*(i+r+-a-n)/(r-i+a-n)]);break;case 10:u.push([t-.5-.5*(-n-r-a-i+4*h)/(n-r+a-i),e-.5]);break;case 11:u.push([t-.25-.25*(-a-i+2*h)/(a-i),e-.75-.25*(i+r-2*h)/(r-i)]);break;case 12:u.push([t-.5,e-.5-.5*(-i-r-a-n+4*h)/(i-r+a-n)]);break;case 13:u.push([t-.75-.25*(n+r-2*h)/(r-n),e-.25-.25*(-a-n+2*h)/(a-n)]);break;case 14:u.push([t-.25-.25*(-n-r+2*h)/(n-r),e-.25-.25*(-i-r+2*h)/(i-r)]);break;case 15:u.push([t-.5,e-.5])}},cell:function(t,e,r,n,i,a,o,s,l){i?s.push([t,e]):s.push([e,t])}});return function(t,e){var r=[],i=[];return n(t,r,i,e),{positions:r,cells:i}}}};var o={}},{"ndarray-extract-contour":251,"zero-crossings":318}],303:[function(t,e,r){(function(r){(function(){"use strict";e.exports=function t(e,r,i){i=i||{};var o=a[e];o||(o=a[e]={" ":{data:new Float32Array(0),shape:.2}});var s=o[r];if(!s)if(r.length<=1||!/\d/.test(r))s=o[r]=function(t){for(var e=t.cells,r=t.positions,n=new Float32Array(6*e.length),i=0,a=0,o=0;o0&&(f+=.02);var p=new Float32Array(u),d=0,m=-.5*f;for(h=0;hMath.max(r,n)?i[2]=1:r>Math.max(e,n)?i[0]=1:i[1]=1;for(var a=0,o=0,l=0;l<3;++l)a+=t[l]*t[l],o+=i[l]*t[l];for(l=0;l<3;++l)i[l]-=o/a*t[l];return s(i,i),i}function h(t,e,r,i,a,o,s,l){this.center=n(r),this.up=n(i),this.right=n(a),this.radius=n([o]),this.angle=n([s,l]),this.angle.bounds=[[-1/0,-Math.PI/2],[1/0,Math.PI/2]],this.setDistanceLimits(t,e),this.computedCenter=this.center.curve(0),this.computedUp=this.up.curve(0),this.computedRight=this.right.curve(0),this.computedRadius=this.radius.curve(0),this.computedAngle=this.angle.curve(0),this.computedToward=[0,0,0],this.computedEye=[0,0,0],this.computedMatrix=new Array(16);for(var c=0;c<16;++c)this.computedMatrix[c]=.5;this.recalcMatrix(0)}var p=h.prototype;p.setDistanceLimits=function(t,e){t=t>0?Math.log(t):-1/0,e=e>0?Math.log(e):1/0,e=Math.max(e,t),this.radius.bounds[0][0]=t,this.radius.bounds[1][0]=e},p.getDistanceLimits=function(t){var e=this.radius.bounds[0];return t?(t[0]=Math.exp(e[0][0]),t[1]=Math.exp(e[1][0]),t):[Math.exp(e[0][0]),Math.exp(e[1][0])]},p.recalcMatrix=function(t){this.center.curve(t),this.up.curve(t),this.right.curve(t),this.radius.curve(t),this.angle.curve(t);for(var e=this.computedUp,r=this.computedRight,n=0,i=0,a=0;a<3;++a)i+=e[a]*r[a],n+=e[a]*e[a];var l=Math.sqrt(n),u=0;for(a=0;a<3;++a)r[a]-=e[a]*i/n,u+=r[a]*r[a],e[a]/=l;var f=Math.sqrt(u);for(a=0;a<3;++a)r[a]/=f;var h=this.computedToward;o(h,e,r),s(h,h);var p=Math.exp(this.computedRadius[0]),d=this.computedAngle[0],m=this.computedAngle[1],g=Math.cos(d),v=Math.sin(d),y=Math.cos(m),x=Math.sin(m),b=this.computedCenter,_=g*y,w=v*y,T=x,k=-g*x,A=-v*x,M=y,S=this.computedEye,E=this.computedMatrix;for(a=0;a<3;++a){var L=_*r[a]+w*h[a]+T*e[a];E[4*a+1]=k*r[a]+A*h[a]+M*e[a],E[4*a+2]=L,E[4*a+3]=0}var C=E[1],P=E[5],I=E[9],O=E[2],z=E[6],D=E[10],R=P*D-I*z,F=I*O-C*D,B=C*z-P*O,N=c(R,F,B);R/=N,F/=N,B/=N,E[0]=R,E[4]=F,E[8]=B;for(a=0;a<3;++a)S[a]=b[a]+E[2+4*a]*p;for(a=0;a<3;++a){u=0;for(var j=0;j<3;++j)u+=E[a+4*j]*S[j];E[12+a]=-u}E[15]=1},p.getMatrix=function(t,e){this.recalcMatrix(t);var r=this.computedMatrix;if(e){for(var n=0;n<16;++n)e[n]=r[n];return e}return r};var d=[0,0,0];p.rotate=function(t,e,r,n){if(this.angle.move(t,e,r),n){this.recalcMatrix(t);var i=this.computedMatrix;d[0]=i[2],d[1]=i[6],d[2]=i[10];for(var o=this.computedUp,s=this.computedRight,l=this.computedToward,c=0;c<3;++c)i[4*c]=o[c],i[4*c+1]=s[c],i[4*c+2]=l[c];a(i,i,n,d);for(c=0;c<3;++c)o[c]=i[4*c],s[c]=i[4*c+1];this.up.set(t,o[0],o[1],o[2]),this.right.set(t,s[0],s[1],s[2])}},p.pan=function(t,e,r,n){e=e||0,r=r||0,n=n||0,this.recalcMatrix(t);var i=this.computedMatrix,a=(Math.exp(this.computedRadius[0]),i[1]),o=i[5],s=i[9],l=c(a,o,s);a/=l,o/=l,s/=l;var u=i[0],f=i[4],h=i[8],p=u*a+f*o+h*s,d=c(u-=a*p,f-=o*p,h-=s*p),m=(u/=d)*e+a*r,g=(f/=d)*e+o*r,v=(h/=d)*e+s*r;this.center.move(t,m,g,v);var y=Math.exp(this.computedRadius[0]);y=Math.max(1e-4,y+n),this.radius.set(t,Math.log(y))},p.translate=function(t,e,r,n){this.center.move(t,e||0,r||0,n||0)},p.setMatrix=function(t,e,r,n){var a=1;"number"==typeof r&&(a=0|r),(a<0||a>3)&&(a=1);var o=(a+2)%3;e||(this.recalcMatrix(t),e=this.computedMatrix);var s=e[a],l=e[a+4],f=e[a+8];if(n){var h=Math.abs(s),p=Math.abs(l),d=Math.abs(f),m=Math.max(h,p,d);h===m?(s=s<0?-1:1,l=f=0):d===m?(f=f<0?-1:1,s=l=0):(l=l<0?-1:1,s=f=0)}else{var g=c(s,l,f);s/=g,l/=g,f/=g}var v,y,x=e[o],b=e[o+4],_=e[o+8],w=x*s+b*l+_*f,T=c(x-=s*w,b-=l*w,_-=f*w),k=l*(_/=T)-f*(b/=T),A=f*(x/=T)-s*_,M=s*b-l*x,S=c(k,A,M);if(k/=S,A/=S,M/=S,this.center.jump(t,q,G,Y),this.radius.idle(t),this.up.jump(t,s,l,f),this.right.jump(t,x,b,_),2===a){var E=e[1],L=e[5],C=e[9],P=E*x+L*b+C*_,I=E*k+L*A+C*M;v=R<0?-Math.PI/2:Math.PI/2,y=Math.atan2(I,P)}else{var O=e[2],z=e[6],D=e[10],R=O*s+z*l+D*f,F=O*x+z*b+D*_,B=O*k+z*A+D*M;v=Math.asin(u(R)),y=Math.atan2(B,F)}this.angle.jump(t,y,v),this.recalcMatrix(t);var N=e[2],j=e[6],U=e[10],V=this.computedMatrix;i(V,e);var H=V[15],q=V[12]/H,G=V[13]/H,Y=V[14]/H,W=Math.exp(this.computedRadius[0]);this.center.jump(t,q-N*W,G-j*W,Y-U*W)},p.lastT=function(){return Math.max(this.center.lastT(),this.up.lastT(),this.right.lastT(),this.radius.lastT(),this.angle.lastT())},p.idle=function(t){this.center.idle(t),this.up.idle(t),this.right.idle(t),this.radius.idle(t),this.angle.idle(t)},p.flush=function(t){this.center.flush(t),this.up.flush(t),this.right.flush(t),this.radius.flush(t),this.angle.flush(t)},p.setDistance=function(t,e){e>0&&this.radius.set(t,Math.log(e))},p.lookAt=function(t,e,r,n){this.recalcMatrix(t),e=e||this.computedEye,r=r||this.computedCenter;var i=(n=n||this.computedUp)[0],a=n[1],o=n[2],s=c(i,a,o);if(!(s<1e-6)){i/=s,a/=s,o/=s;var l=e[0]-r[0],f=e[1]-r[1],h=e[2]-r[2],p=c(l,f,h);if(!(p<1e-6)){l/=p,f/=p,h/=p;var d=this.computedRight,m=d[0],g=d[1],v=d[2],y=i*m+a*g+o*v,x=c(m-=y*i,g-=y*a,v-=y*o);if(!(x<.01&&(x=c(m=a*h-o*f,g=o*l-i*h,v=i*f-a*l))<1e-6)){m/=x,g/=x,v/=x,this.up.set(t,i,a,o),this.right.set(t,m,g,v),this.center.set(t,r[0],r[1],r[2]),this.radius.set(t,Math.log(p));var b=a*v-o*g,_=o*m-i*v,w=i*g-a*m,T=c(b,_,w),k=i*l+a*f+o*h,A=m*l+g*f+v*h,M=(b/=T)*l+(_/=T)*f+(w/=T)*h,S=Math.asin(u(k)),E=Math.atan2(M,A),L=this.angle._state,C=L[L.length-1],P=L[L.length-2];C%=2*Math.PI;var I=Math.abs(C+2*Math.PI-E),O=Math.abs(C-E),z=Math.abs(C-2*Math.PI-E);I0?r.pop():new ArrayBuffer(t)}function d(t){return new Uint8Array(p(t),0,t)}function m(t){return new Uint16Array(p(2*t),0,t)}function g(t){return new Uint32Array(p(4*t),0,t)}function v(t){return new Int8Array(p(t),0,t)}function y(t){return new Int16Array(p(2*t),0,t)}function x(t){return new Int32Array(p(4*t),0,t)}function b(t){return new Float32Array(p(4*t),0,t)}function _(t){return new Float64Array(p(8*t),0,t)}function w(t){return o?new Uint8ClampedArray(p(t),0,t):d(t)}function T(t){return s?new BigUint64Array(p(8*t),0,t):null}function k(t){return l?new BigInt64Array(p(8*t),0,t):null}function A(t){return new DataView(p(t),0,t)}function M(t){t=n.nextPow2(t);var e=n.log2(t),r=f[e];return r.length>0?r.pop():new a(t)}r.free=function(t){if(a.isBuffer(t))f[n.log2(t.length)].push(t);else{if("[object ArrayBuffer]"!==Object.prototype.toString.call(t)&&(t=t.buffer),!t)return;var e=t.length||t.byteLength,r=0|n.log2(e);u[r].push(t)}},r.freeUint8=r.freeUint16=r.freeUint32=r.freeBigUint64=r.freeInt8=r.freeInt16=r.freeInt32=r.freeBigInt64=r.freeFloat32=r.freeFloat=r.freeFloat64=r.freeDouble=r.freeUint8Clamped=r.freeDataView=function(t){h(t.buffer)},r.freeArrayBuffer=h,r.freeBuffer=function(t){f[n.log2(t.length)].push(t)},r.malloc=function(t,e){if(void 0===e||"arraybuffer"===e)return p(t);switch(e){case"uint8":return d(t);case"uint16":return m(t);case"uint32":return g(t);case"int8":return v(t);case"int16":return y(t);case"int32":return x(t);case"float":case"float32":return b(t);case"double":case"float64":return _(t);case"uint8_clamped":return w(t);case"bigint64":return k(t);case"biguint64":return T(t);case"buffer":return M(t);case"data":case"dataview":return A(t);default:return null}return null},r.mallocArrayBuffer=p,r.mallocUint8=d,r.mallocUint16=m,r.mallocUint32=g,r.mallocInt8=v,r.mallocInt16=y,r.mallocInt32=x,r.mallocFloat32=r.mallocFloat=b,r.mallocFloat64=r.mallocDouble=_,r.mallocUint8Clamped=w,r.mallocBigUint64=T,r.mallocBigInt64=k,r.mallocDataView=A,r.mallocBuffer=M,r.clearCache=function(){for(var t=0;t<32;++t)c.UINT8[t].length=0,c.UINT16[t].length=0,c.UINT32[t].length=0,c.INT8[t].length=0,c.INT16[t].length=0,c.INT32[t].length=0,c.FLOAT[t].length=0,c.DOUBLE[t].length=0,c.BIGUINT64[t].length=0,c.BIGINT64[t].length=0,c.UINT8C[t].length=0,u[t].length=0,f[t].length=0}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{"bit-twiddle":32,buffer:3,dup:65}],309:[function(t,e,r){"use strict";function n(t){this.roots=new Array(t),this.ranks=new Array(t);for(var e=0;e0&&(a=n.size),n.lineSpacing&&n.lineSpacing>0&&(o=n.lineSpacing),n.styletags&&n.styletags.breaklines&&(s.breaklines=!!n.styletags.breaklines),n.styletags&&n.styletags.bolds&&(s.bolds=!!n.styletags.bolds),n.styletags&&n.styletags.italics&&(s.italics=!!n.styletags.italics),n.styletags&&n.styletags.subscripts&&(s.subscripts=!!n.styletags.subscripts),n.styletags&&n.styletags.superscripts&&(s.superscripts=!!n.styletags.superscripts));return r.font=[n.fontStyle,n.fontVariant,n.fontWeight,a+"px",n.font].filter((function(t){return t})).join(" "),r.textAlign="start",r.textBaseline="alphabetic",r.direction="ltr",h(function(t,e,r,n,a,o){r=r.replace(/\n/g,""),r=!0===o.breaklines?r.replace(/\/g,"\n"):r.replace(/\/g," ");var s="",l=[];for(p=0;p-1?parseInt(t[1+i]):0,l=a>-1?parseInt(r[1+a]):0;s!==l&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,l-s),n=n.replace("?px ",S())),m+=.25*x*(l-s)}if(!0===o.superscripts){var c=t.indexOf("+"),u=r.indexOf("+"),f=c>-1?parseInt(t[1+c]):0,h=u>-1?parseInt(r[1+u]):0;f!==h&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,h-f),n=n.replace("?px ",S())),m-=.25*x*(h-f)}if(!0===o.bolds){var p=t.indexOf("b|")>-1,d=r.indexOf("b|")>-1;!p&&d&&(n=v?n.replace("italic ","italic bold "):"bold "+n),p&&!d&&(n=n.replace("bold ",""))}if(!0===o.italics){var v=t.indexOf("i|")>-1,y=r.indexOf("i|")>-1;!v&&y&&(n="italic "+n),v&&!y&&(n=n.replace("italic ",""))}e.font=n}for(h=0;h",a="",o=i.length,s=a.length,l="+"===e[0]||"-"===e[0],c=0,u=-s;c>-1&&-1!==(c=r.indexOf(i,c))&&-1!==(u=r.indexOf(a,c+o))&&!(u<=c);){for(var f=c;f=u)n[f]=null,r=r.substr(0,f)+" "+r.substr(f+1);else if(null!==n[f]){var h=n[f].indexOf(e[0]);-1===h?n[f]+=e:l&&(n[f]=n[f].substr(0,h+1)+(1+parseInt(n[f][h+1]))+n[f].substr(h+2))}var p=c+o,d=r.substr(p,u-p).indexOf(i);c=-1!==d?d:u+s}return n}function u(t,e){var r=n(t,128);return e?a(r.cells,r.positions,.25):{edges:r.cells,positions:r.positions}}function f(t,e,r,n){var i=u(t,n),a=function(t,e,r){for(var n=e.textAlign||"start",i=e.textBaseline||"alphabetic",a=[1<<30,1<<30],o=[0,0],s=t.length,l=0;l=0?e[a]:i}))},has___:{value:y((function(e){var n=v(e);return n?r in n:t.indexOf(e)>=0}))},set___:{value:y((function(n,i){var a,o=v(n);return o?o[r]=i:(a=t.indexOf(n))>=0?e[a]=i:(a=t.length,e[a]=i,t[a]=n),this}))},delete___:{value:y((function(n){var i,a,o=v(n);return o?r in o&&delete o[r]:!((i=t.indexOf(n))<0)&&(a=t.length-1,t[i]=void 0,e[i]=e[a],t[i]=t[a],t.length=a,e.length=a,!0)}))}})};d.prototype=Object.create(Object.prototype,{get:{value:function(t,e){return this.get___(t,e)},writable:!0,configurable:!0},has:{value:function(t){return this.has___(t)},writable:!0,configurable:!0},set:{value:function(t,e){return this.set___(t,e)},writable:!0,configurable:!0},delete:{value:function(t){return this.delete___(t)},writable:!0,configurable:!0}}),"function"==typeof r?function(){function n(){this instanceof d||x();var e,n=new r,i=void 0,a=!1;return e=t?function(t,e){return n.set(t,e),n.has(t)||(i||(i=new d),i.set(t,e)),this}:function(t,e){if(a)try{n.set(t,e)}catch(r){i||(i=new d),i.set___(t,e)}else n.set(t,e);return this},Object.create(d.prototype,{get___:{value:y((function(t,e){return i?n.has(t)?n.get(t):i.get___(t,e):n.get(t,e)}))},has___:{value:y((function(t){return n.has(t)||!!i&&i.has___(t)}))},set___:{value:y(e)},delete___:{value:y((function(t){var e=!!n.delete(t);return i&&i.delete___(t)||e}))},permitHostObjects___:{value:y((function(t){if(t!==m)throw new Error("bogus call to permitHostObjects___");a=!0}))}})}t&&"undefined"!=typeof Proxy&&(Proxy=void 0),n.prototype=d.prototype,e.exports=n,Object.defineProperty(WeakMap.prototype,"constructor",{value:WeakMap,enumerable:!1,configurable:!0,writable:!0})}():("undefined"!=typeof Proxy&&(Proxy=void 0),e.exports=d)}function m(t){t.permitHostObjects___&&t.permitHostObjects___(m)}function g(t){return!("weakmap:"==t.substr(0,"weakmap:".length)&&"___"===t.substr(t.length-3))}function v(t){if(t!==Object(t))throw new TypeError("Not an object: "+t);var e=t[l];if(e&&e.key===t)return e;if(s(t)){e={key:t};try{return o(t,l,{value:e,writable:!1,enumerable:!1,configurable:!1}),e}catch(t){return}}}function y(t){return t.prototype=null,Object.freeze(t)}function x(){h||"undefined"==typeof console||(h=!0,console.warn("WeakMap should be invoked as new WeakMap(), not WeakMap(). This will be an error in the future."))}}()},{}],314:[function(t,e,r){var n=t("./hidden-store.js");e.exports=function(){var t={};return function(e){if(("object"!=typeof e||null===e)&&"function"!=typeof e)throw new Error("Weakmap-shim: Key must be object");var r=e.valueOf(t);return r&&r.identity===t?r:n(e,t)}}},{"./hidden-store.js":315}],315:[function(t,e,r){e.exports=function(t,e){var r={identity:e},n=t.valueOf;return Object.defineProperty(t,"valueOf",{value:function(t){return t!==e?n.apply(this,arguments):r},writable:!0}),r}},{}],316:[function(t,e,r){var n=t("./create-store.js");e.exports=function(){var t=n();return{get:function(e,r){var n=t(e);return n.hasOwnProperty("value")?n.value:r},set:function(e,r){return t(e).value=r,this},has:function(e){return"value"in t(e)},delete:function(e){return delete t(e).value}}}},{"./create-store.js":314}],317:[function(t,e,r){"use strict";var n,i=function(){return function(t,e,r,n,i,a){var o=t[0],s=r[0],l=[0],c=s;n|=0;var u=0,f=s;for(u=0;u=0!=p>=0&&i.push(l[0]+.5+.5*(h+p)/(h-p)),n+=f,++l[0]}}};e.exports=(n={funcName:{funcName:"zeroCrossings"}.funcName},function(t){var e={};return function(r,n,i){var a=r.dtype,o=r.order,s=[a,o.join()].join(),l=e[s];return l||(e[s]=l=t([a,o])),l(r.shape.slice(0),r.data,r.stride,0|r.offset,n,i)}}(i.bind(void 0,n)))},{}],318:[function(t,e,r){"use strict";e.exports=function(t,e){var r=[];return e=+e||0,n(t.hi(t.shape[0]-1),r,e),r};var n=t("./lib/zc-core")},{"./lib/zc-core":317}]},{},[6])(6)}))}).call(this)}).call(this,"undefined"!=typeof global?global:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}]},{},[27])(27)}));
\ No newline at end of file diff --git a/docs/source/_static/sintel_clean_epe_outlier-drop_kitti_sintel.html b/docs/source/_static/sintel_clean_epe_outlier-drop_kitti_sintel.html index 469d6fc..5847c5a 100644 --- a/docs/source/_static/sintel_clean_epe_outlier-drop_kitti_sintel.html +++ b/docs/source/_static/sintel_clean_epe_outlier-drop_kitti_sintel.html @@ -3,30 +3,30 @@
+"use strict";var n,i="";e.exports=function(t,e){if("string"!=typeof t)throw new TypeError("expected a string");if(1===e)return t;if(2===e)return t+t;var r=t.length*e;if(n!==t||void 0===n)n=t,i="";else if(i.length>=r)return i.substr(0,r);for(;r>i.length&&e>1;)1&e&&(i+=t),e>>=1,t+=t;return i=(i+=t).substr(0,r)}},{}],278:[function(t,e,r){(function(t){(function(){e.exports=t.performance&&t.performance.now?function(){return performance.now()}:Date.now||function(){return+new Date}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}],279:[function(t,e,r){"use strict";e.exports=function(t){for(var e=t.length,r=t[t.length-1],n=e,i=e-2;i>=0;--i){var a=r,o=t[i];(l=o-((r=a+o)-a))&&(t[--n]=r,r=l)}var s=0;for(i=n;i0){if(a<=0)return o;n=i+a}else{if(!(i<0))return o;if(a>=0)return o;n=-(i+a)}var s=33306690738754716e-32*n;return o>=s||o<=-s?o:f(t,e,r)},function(t,e,r,n){var i=t[0]-n[0],a=e[0]-n[0],o=r[0]-n[0],s=t[1]-n[1],l=e[1]-n[1],c=r[1]-n[1],u=t[2]-n[2],f=e[2]-n[2],p=r[2]-n[2],d=a*c,m=o*l,g=o*s,v=i*c,y=i*l,x=a*s,b=u*(d-m)+f*(g-v)+p*(y-x),_=7771561172376103e-31*((Math.abs(d)+Math.abs(m))*Math.abs(u)+(Math.abs(g)+Math.abs(v))*Math.abs(f)+(Math.abs(y)+Math.abs(x))*Math.abs(p));return b>_||-b>_?b:h(t,e,r,n)}];function d(t){var e=p[t.length];return e||(e=p[t.length]=u(t.length)),e.apply(void 0,t)}function m(t,e,r,n,i,a,o){return function(e,r,s,l,c){switch(arguments.length){case 0:case 1:return 0;case 2:return n(e,r);case 3:return i(e,r,s);case 4:return a(e,r,s,l);case 5:return o(e,r,s,l,c)}for(var u=new Array(arguments.length),f=0;f0&&o>0||a<0&&o<0)return!1;var s=n(r,t,e),l=n(i,t,e);if(s>0&&l>0||s<0&&l<0)return!1;if(0===a&&0===o&&0===s&&0===l)return function(t,e,r,n){for(var i=0;i<2;++i){var a=t[i],o=e[i],s=Math.min(a,o),l=Math.max(a,o),c=r[i],u=n[i],f=Math.min(c,u);if(Math.max(c,u)=n?(i=f,(l+=1)=n?(i=f,(l+=1)>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,s=a(t[o],e);s<=0?(0===s&&(i=o),r=o+1):s>0&&(n=o-1)}return i}function u(t,e){for(var r=new Array(t.length),i=0,o=r.length;i=t.length||0!==a(t[g],s)););}return r}function f(t,e){if(e<0)return[];for(var r=[],i=(1<>>u&1&&c.push(i[u]);e.push(c)}return s(e)},r.skeleton=f,r.boundary=function(t){for(var e=[],r=0,n=t.length;r>1:(t>>1)-1}function x(t){for(var e=v(t);;){var r=e,n=2*t+1,i=2*(t+1),a=t;if(n0;){var r=y(t);if(r>=0)if(e0){var t=k[0];return g(0,M-1),M-=1,x(0),t}return-1}function w(t,e){var r=k[t];return c[r]===e?t:(c[r]=-1/0,b(t),_(),c[r]=e,b((M+=1)-1))}function T(t){if(!u[t]){u[t]=!0;var e=s[t],r=l[t];s[r]>=0&&(s[r]=e),l[e]>=0&&(l[e]=r),A[e]>=0&&w(A[e],m(e)),A[r]>=0&&w(A[r],m(r))}}var k=[],A=new Array(a);for(f=0;f>1;f>=0;--f)x(f);for(;;){var S=_();if(S<0||c[S]>r)break;T(S)}var E=[];for(f=0;f=0&&r>=0&&e!==r){var n=A[e],i=A[r];n!==i&&C.push([n,i])}})),i.unique(i.normalize(C)),{positions:E,edges:C}};var n=t("robust-orientation"),i=t("simplicial-complex")},{"robust-orientation":284,"simplicial-complex":295}],298:[function(t,e,r){"use strict";e.exports=function(t,e){var r,a,o,s;if(e[0][0]e[1][0]))return i(e,t);r=e[1],a=e[0]}if(t[0][0]t[1][0]))return-i(t,e);o=t[1],s=t[0]}var l=n(r,a,s),c=n(r,a,o);if(l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;if(l=n(s,o,a),c=n(s,o,r),l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;return a[0]-s[0]};var n=t("robust-orientation");function i(t,e){var r,i,a,o;if(e[0][0]e[1][0])){var s=Math.min(t[0][1],t[1][1]),l=Math.max(t[0][1],t[1][1]),c=Math.min(e[0][1],e[1][1]),u=Math.max(e[0][1],e[1][1]);return lu?s-u:l-u}r=e[1],i=e[0]}t[0][1]0)if(e[0]!==o[1][0])r=t,t=t.right;else{if(l=c(t.right,e))return l;t=t.left}else{if(e[0]!==o[1][0])return t;var l;if(l=c(t.right,e))return l;t=t.left}}return r}function u(t,e,r,n){this.y=t,this.index=e,this.start=r,this.closed=n}function f(t,e,r,n){this.x=t,this.segment=e,this.create=r,this.index=n}s.prototype.castUp=function(t){var e=n.le(this.coordinates,t[0]);if(e<0)return-1;this.slabs[e];var r=c(this.slabs[e],t),i=-1;if(r&&(i=r.value),this.coordinates[e]===t[0]){var s=null;if(r&&(s=r.key),e>0){var u=c(this.slabs[e-1],t);u&&(s?o(u.key,s)>0&&(s=u.key,i=u.value):(i=u.value,s=u.key))}var f=this.horizontal[e];if(f.length>0){var h=n.ge(f,t[1],l);if(h=f.length)return i;p=f[h]}}if(p.start)if(s){var d=a(s[0],s[1],[t[0],p.y]);s[0][0]>s[1][0]&&(d=-d),d>0&&(i=p.index)}else i=p.index;else p.y!==t[1]&&(i=p.index)}}}return i}},{"./lib/order-segments":298,"binary-search-bounds":31,"functional-red-black-tree":69,"robust-orientation":284}],300:[function(t,e,r){"use strict";var n=t("robust-dot-product"),i=t("robust-sum");function a(t,e){var r=i(n(t,e),[e[e.length-1]]);return r[r.length-1]}function o(t,e,r,n){var i=-e/(n-e);i<0?i=0:i>1&&(i=1);for(var a=1-i,o=t.length,s=new Array(o),l=0;l0||i>0&&u<0){var f=o(s,u,l,i);r.push(f),n.push(f.slice())}u<0?n.push(l.slice()):u>0?r.push(l.slice()):(r.push(l.slice()),n.push(l.slice())),i=u}return{positive:r,negative:n}},e.exports.positive=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c>=0&&r.push(s.slice()),n=c}return r},e.exports.negative=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c<=0&&r.push(s.slice()),n=c}return r}},{"robust-dot-product":281,"robust-sum":289}],301:[function(t,e,r){!function(){"use strict";var t={not_string:/[^s]/,not_bool:/[^t]/,not_type:/[^T]/,not_primitive:/[^v]/,number:/[diefg]/,numeric_arg:/[bcdiefguxX]/,json:/[j]/,not_json:/[^j]/,text:/^[^\x25]+/,modulo:/^\x25{2}/,placeholder:/^\x25(?:([1-9]\d*)\$|\(([^)]+)\))?(\+)?(0|'[^$])?(-)?(\d+)?(?:\.(\d+))?([b-gijostTuvxX])/,key:/^([a-z_][a-z_\d]*)/i,key_access:/^\.([a-z_][a-z_\d]*)/i,index_access:/^\[(\d+)\]/,sign:/^[+-]/};function e(t){return i(o(t),arguments)}function n(t,r){return e.apply(null,[t].concat(r||[]))}function i(r,n){var i,a,o,s,l,c,u,f,h,p=1,d=r.length,m="";for(a=0;a=0),s.type){case"b":i=parseInt(i,10).toString(2);break;case"c":i=String.fromCharCode(parseInt(i,10));break;case"d":case"i":i=parseInt(i,10);break;case"j":i=JSON.stringify(i,null,s.width?parseInt(s.width):0);break;case"e":i=s.precision?parseFloat(i).toExponential(s.precision):parseFloat(i).toExponential();break;case"f":i=s.precision?parseFloat(i).toFixed(s.precision):parseFloat(i);break;case"g":i=s.precision?String(Number(i.toPrecision(s.precision))):parseFloat(i);break;case"o":i=(parseInt(i,10)>>>0).toString(8);break;case"s":i=String(i),i=s.precision?i.substring(0,s.precision):i;break;case"t":i=String(!!i),i=s.precision?i.substring(0,s.precision):i;break;case"T":i=Object.prototype.toString.call(i).slice(8,-1).toLowerCase(),i=s.precision?i.substring(0,s.precision):i;break;case"u":i=parseInt(i,10)>>>0;break;case"v":i=i.valueOf(),i=s.precision?i.substring(0,s.precision):i;break;case"x":i=(parseInt(i,10)>>>0).toString(16);break;case"X":i=(parseInt(i,10)>>>0).toString(16).toUpperCase()}t.json.test(s.type)?m+=i:(!t.number.test(s.type)||f&&!s.sign?h="":(h=f?"+":"-",i=i.toString().replace(t.sign,"")),c=s.pad_char?"0"===s.pad_char?"0":s.pad_char.charAt(1):" ",u=s.width-(h+i).length,l=s.width&&u>0?c.repeat(u):"",m+=s.align?h+i+l:"0"===c?h+l+i:l+h+i)}return m}var a=Object.create(null);function o(e){if(a[e])return a[e];for(var r,n=e,i=[],o=0;n;){if(null!==(r=t.text.exec(n)))i.push(r[0]);else if(null!==(r=t.modulo.exec(n)))i.push("%");else{if(null===(r=t.placeholder.exec(n)))throw new SyntaxError("[sprintf] unexpected placeholder");if(r[2]){o|=1;var s=[],l=r[2],c=[];if(null===(c=t.key.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");for(s.push(c[1]);""!==(l=l.substring(c[0].length));)if(null!==(c=t.key_access.exec(l)))s.push(c[1]);else{if(null===(c=t.index_access.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");s.push(c[1])}r[2]=s}else o|=2;if(3===o)throw new Error("[sprintf] mixing positional and named placeholders is not (yet) supported");i.push({placeholder:r[0],param_no:r[1],keys:r[2],sign:r[3],pad_char:r[4],align:r[5],width:r[6],precision:r[7],type:r[8]})}n=n.substring(r[0].length)}return a[e]=i}void 0!==r&&(r.sprintf=e,r.vsprintf=n),"undefined"!=typeof window&&(window.sprintf=e,window.vsprintf=n)}()},{}],302:[function(t,e,r){"use strict";e.exports=function(t,e){if(t.dimension<=0)return{positions:[],cells:[]};if(1===t.dimension)return function(t,e){for(var r=i(t,e),n=r.length,a=new Array(n),o=new Array(n),s=0;sn|0},vertex:function(t,e,r,n,i,a,o,s,l,c,u,f,h){var p=(o<<0)+(s<<1)+(l<<2)+(c<<3)|0;if(0!==p&&15!==p)switch(p){case 0:u.push([t-.5,e-.5]);break;case 1:u.push([t-.25-.25*(n+r-2*h)/(r-n),e-.25-.25*(i+r-2*h)/(r-i)]);break;case 2:u.push([t-.75-.25*(-n-r+2*h)/(n-r),e-.25-.25*(a+n-2*h)/(n-a)]);break;case 3:u.push([t-.5,e-.5-.5*(i+r+a+n-4*h)/(r-i+n-a)]);break;case 4:u.push([t-.25-.25*(a+i-2*h)/(i-a),e-.75-.25*(-i-r+2*h)/(i-r)]);break;case 5:u.push([t-.5-.5*(n+r+a+i-4*h)/(r-n+i-a),e-.5]);break;case 6:u.push([t-.5-.25*(-n-r+a+i)/(n-r+i-a),e-.5-.25*(-i-r+a+n)/(i-r+n-a)]);break;case 7:u.push([t-.75-.25*(a+i-2*h)/(i-a),e-.75-.25*(a+n-2*h)/(n-a)]);break;case 8:u.push([t-.75-.25*(-a-i+2*h)/(a-i),e-.75-.25*(-a-n+2*h)/(a-n)]);break;case 9:u.push([t-.5-.25*(n+r+-a-i)/(r-n+a-i),e-.5-.25*(i+r+-a-n)/(r-i+a-n)]);break;case 10:u.push([t-.5-.5*(-n-r-a-i+4*h)/(n-r+a-i),e-.5]);break;case 11:u.push([t-.25-.25*(-a-i+2*h)/(a-i),e-.75-.25*(i+r-2*h)/(r-i)]);break;case 12:u.push([t-.5,e-.5-.5*(-i-r-a-n+4*h)/(i-r+a-n)]);break;case 13:u.push([t-.75-.25*(n+r-2*h)/(r-n),e-.25-.25*(-a-n+2*h)/(a-n)]);break;case 14:u.push([t-.25-.25*(-n-r+2*h)/(n-r),e-.25-.25*(-i-r+2*h)/(i-r)]);break;case 15:u.push([t-.5,e-.5])}},cell:function(t,e,r,n,i,a,o,s,l){i?s.push([t,e]):s.push([e,t])}});return function(t,e){var r=[],i=[];return n(t,r,i,e),{positions:r,cells:i}}}};var o={}},{"ndarray-extract-contour":251,"zero-crossings":318}],303:[function(t,e,r){(function(r){(function(){"use strict";e.exports=function t(e,r,i){i=i||{};var o=a[e];o||(o=a[e]={" ":{data:new Float32Array(0),shape:.2}});var s=o[r];if(!s)if(r.length<=1||!/\d/.test(r))s=o[r]=function(t){for(var e=t.cells,r=t.positions,n=new Float32Array(6*e.length),i=0,a=0,o=0;o0&&(f+=.02);var p=new Float32Array(u),d=0,m=-.5*f;for(h=0;hMath.max(r,n)?i[2]=1:r>Math.max(e,n)?i[0]=1:i[1]=1;for(var a=0,o=0,l=0;l<3;++l)a+=t[l]*t[l],o+=i[l]*t[l];for(l=0;l<3;++l)i[l]-=o/a*t[l];return s(i,i),i}function h(t,e,r,i,a,o,s,l){this.center=n(r),this.up=n(i),this.right=n(a),this.radius=n([o]),this.angle=n([s,l]),this.angle.bounds=[[-1/0,-Math.PI/2],[1/0,Math.PI/2]],this.setDistanceLimits(t,e),this.computedCenter=this.center.curve(0),this.computedUp=this.up.curve(0),this.computedRight=this.right.curve(0),this.computedRadius=this.radius.curve(0),this.computedAngle=this.angle.curve(0),this.computedToward=[0,0,0],this.computedEye=[0,0,0],this.computedMatrix=new Array(16);for(var c=0;c<16;++c)this.computedMatrix[c]=.5;this.recalcMatrix(0)}var p=h.prototype;p.setDistanceLimits=function(t,e){t=t>0?Math.log(t):-1/0,e=e>0?Math.log(e):1/0,e=Math.max(e,t),this.radius.bounds[0][0]=t,this.radius.bounds[1][0]=e},p.getDistanceLimits=function(t){var e=this.radius.bounds[0];return t?(t[0]=Math.exp(e[0][0]),t[1]=Math.exp(e[1][0]),t):[Math.exp(e[0][0]),Math.exp(e[1][0])]},p.recalcMatrix=function(t){this.center.curve(t),this.up.curve(t),this.right.curve(t),this.radius.curve(t),this.angle.curve(t);for(var e=this.computedUp,r=this.computedRight,n=0,i=0,a=0;a<3;++a)i+=e[a]*r[a],n+=e[a]*e[a];var l=Math.sqrt(n),u=0;for(a=0;a<3;++a)r[a]-=e[a]*i/n,u+=r[a]*r[a],e[a]/=l;var f=Math.sqrt(u);for(a=0;a<3;++a)r[a]/=f;var h=this.computedToward;o(h,e,r),s(h,h);var p=Math.exp(this.computedRadius[0]),d=this.computedAngle[0],m=this.computedAngle[1],g=Math.cos(d),v=Math.sin(d),y=Math.cos(m),x=Math.sin(m),b=this.computedCenter,_=g*y,w=v*y,T=x,k=-g*x,A=-v*x,M=y,S=this.computedEye,E=this.computedMatrix;for(a=0;a<3;++a){var L=_*r[a]+w*h[a]+T*e[a];E[4*a+1]=k*r[a]+A*h[a]+M*e[a],E[4*a+2]=L,E[4*a+3]=0}var C=E[1],P=E[5],I=E[9],O=E[2],z=E[6],D=E[10],R=P*D-I*z,F=I*O-C*D,B=C*z-P*O,N=c(R,F,B);R/=N,F/=N,B/=N,E[0]=R,E[4]=F,E[8]=B;for(a=0;a<3;++a)S[a]=b[a]+E[2+4*a]*p;for(a=0;a<3;++a){u=0;for(var j=0;j<3;++j)u+=E[a+4*j]*S[j];E[12+a]=-u}E[15]=1},p.getMatrix=function(t,e){this.recalcMatrix(t);var r=this.computedMatrix;if(e){for(var n=0;n<16;++n)e[n]=r[n];return e}return r};var d=[0,0,0];p.rotate=function(t,e,r,n){if(this.angle.move(t,e,r),n){this.recalcMatrix(t);var i=this.computedMatrix;d[0]=i[2],d[1]=i[6],d[2]=i[10];for(var o=this.computedUp,s=this.computedRight,l=this.computedToward,c=0;c<3;++c)i[4*c]=o[c],i[4*c+1]=s[c],i[4*c+2]=l[c];a(i,i,n,d);for(c=0;c<3;++c)o[c]=i[4*c],s[c]=i[4*c+1];this.up.set(t,o[0],o[1],o[2]),this.right.set(t,s[0],s[1],s[2])}},p.pan=function(t,e,r,n){e=e||0,r=r||0,n=n||0,this.recalcMatrix(t);var i=this.computedMatrix,a=(Math.exp(this.computedRadius[0]),i[1]),o=i[5],s=i[9],l=c(a,o,s);a/=l,o/=l,s/=l;var u=i[0],f=i[4],h=i[8],p=u*a+f*o+h*s,d=c(u-=a*p,f-=o*p,h-=s*p),m=(u/=d)*e+a*r,g=(f/=d)*e+o*r,v=(h/=d)*e+s*r;this.center.move(t,m,g,v);var y=Math.exp(this.computedRadius[0]);y=Math.max(1e-4,y+n),this.radius.set(t,Math.log(y))},p.translate=function(t,e,r,n){this.center.move(t,e||0,r||0,n||0)},p.setMatrix=function(t,e,r,n){var a=1;"number"==typeof r&&(a=0|r),(a<0||a>3)&&(a=1);var o=(a+2)%3;e||(this.recalcMatrix(t),e=this.computedMatrix);var s=e[a],l=e[a+4],f=e[a+8];if(n){var h=Math.abs(s),p=Math.abs(l),d=Math.abs(f),m=Math.max(h,p,d);h===m?(s=s<0?-1:1,l=f=0):d===m?(f=f<0?-1:1,s=l=0):(l=l<0?-1:1,s=f=0)}else{var g=c(s,l,f);s/=g,l/=g,f/=g}var v,y,x=e[o],b=e[o+4],_=e[o+8],w=x*s+b*l+_*f,T=c(x-=s*w,b-=l*w,_-=f*w),k=l*(_/=T)-f*(b/=T),A=f*(x/=T)-s*_,M=s*b-l*x,S=c(k,A,M);if(k/=S,A/=S,M/=S,this.center.jump(t,q,G,Y),this.radius.idle(t),this.up.jump(t,s,l,f),this.right.jump(t,x,b,_),2===a){var E=e[1],L=e[5],C=e[9],P=E*x+L*b+C*_,I=E*k+L*A+C*M;v=R<0?-Math.PI/2:Math.PI/2,y=Math.atan2(I,P)}else{var O=e[2],z=e[6],D=e[10],R=O*s+z*l+D*f,F=O*x+z*b+D*_,B=O*k+z*A+D*M;v=Math.asin(u(R)),y=Math.atan2(B,F)}this.angle.jump(t,y,v),this.recalcMatrix(t);var N=e[2],j=e[6],U=e[10],V=this.computedMatrix;i(V,e);var H=V[15],q=V[12]/H,G=V[13]/H,Y=V[14]/H,W=Math.exp(this.computedRadius[0]);this.center.jump(t,q-N*W,G-j*W,Y-U*W)},p.lastT=function(){return Math.max(this.center.lastT(),this.up.lastT(),this.right.lastT(),this.radius.lastT(),this.angle.lastT())},p.idle=function(t){this.center.idle(t),this.up.idle(t),this.right.idle(t),this.radius.idle(t),this.angle.idle(t)},p.flush=function(t){this.center.flush(t),this.up.flush(t),this.right.flush(t),this.radius.flush(t),this.angle.flush(t)},p.setDistance=function(t,e){e>0&&this.radius.set(t,Math.log(e))},p.lookAt=function(t,e,r,n){this.recalcMatrix(t),e=e||this.computedEye,r=r||this.computedCenter;var i=(n=n||this.computedUp)[0],a=n[1],o=n[2],s=c(i,a,o);if(!(s<1e-6)){i/=s,a/=s,o/=s;var l=e[0]-r[0],f=e[1]-r[1],h=e[2]-r[2],p=c(l,f,h);if(!(p<1e-6)){l/=p,f/=p,h/=p;var d=this.computedRight,m=d[0],g=d[1],v=d[2],y=i*m+a*g+o*v,x=c(m-=y*i,g-=y*a,v-=y*o);if(!(x<.01&&(x=c(m=a*h-o*f,g=o*l-i*h,v=i*f-a*l))<1e-6)){m/=x,g/=x,v/=x,this.up.set(t,i,a,o),this.right.set(t,m,g,v),this.center.set(t,r[0],r[1],r[2]),this.radius.set(t,Math.log(p));var b=a*v-o*g,_=o*m-i*v,w=i*g-a*m,T=c(b,_,w),k=i*l+a*f+o*h,A=m*l+g*f+v*h,M=(b/=T)*l+(_/=T)*f+(w/=T)*h,S=Math.asin(u(k)),E=Math.atan2(M,A),L=this.angle._state,C=L[L.length-1],P=L[L.length-2];C%=2*Math.PI;var I=Math.abs(C+2*Math.PI-E),O=Math.abs(C-E),z=Math.abs(C-2*Math.PI-E);I0?r.pop():new ArrayBuffer(t)}function d(t){return new Uint8Array(p(t),0,t)}function m(t){return new Uint16Array(p(2*t),0,t)}function g(t){return new Uint32Array(p(4*t),0,t)}function v(t){return new Int8Array(p(t),0,t)}function y(t){return new Int16Array(p(2*t),0,t)}function x(t){return new Int32Array(p(4*t),0,t)}function b(t){return new Float32Array(p(4*t),0,t)}function _(t){return new Float64Array(p(8*t),0,t)}function w(t){return o?new Uint8ClampedArray(p(t),0,t):d(t)}function T(t){return s?new BigUint64Array(p(8*t),0,t):null}function k(t){return l?new BigInt64Array(p(8*t),0,t):null}function A(t){return new DataView(p(t),0,t)}function M(t){t=n.nextPow2(t);var e=n.log2(t),r=f[e];return r.length>0?r.pop():new a(t)}r.free=function(t){if(a.isBuffer(t))f[n.log2(t.length)].push(t);else{if("[object ArrayBuffer]"!==Object.prototype.toString.call(t)&&(t=t.buffer),!t)return;var e=t.length||t.byteLength,r=0|n.log2(e);u[r].push(t)}},r.freeUint8=r.freeUint16=r.freeUint32=r.freeBigUint64=r.freeInt8=r.freeInt16=r.freeInt32=r.freeBigInt64=r.freeFloat32=r.freeFloat=r.freeFloat64=r.freeDouble=r.freeUint8Clamped=r.freeDataView=function(t){h(t.buffer)},r.freeArrayBuffer=h,r.freeBuffer=function(t){f[n.log2(t.length)].push(t)},r.malloc=function(t,e){if(void 0===e||"arraybuffer"===e)return p(t);switch(e){case"uint8":return d(t);case"uint16":return m(t);case"uint32":return g(t);case"int8":return v(t);case"int16":return y(t);case"int32":return x(t);case"float":case"float32":return b(t);case"double":case"float64":return _(t);case"uint8_clamped":return w(t);case"bigint64":return k(t);case"biguint64":return T(t);case"buffer":return M(t);case"data":case"dataview":return A(t);default:return null}return null},r.mallocArrayBuffer=p,r.mallocUint8=d,r.mallocUint16=m,r.mallocUint32=g,r.mallocInt8=v,r.mallocInt16=y,r.mallocInt32=x,r.mallocFloat32=r.mallocFloat=b,r.mallocFloat64=r.mallocDouble=_,r.mallocUint8Clamped=w,r.mallocBigUint64=T,r.mallocBigInt64=k,r.mallocDataView=A,r.mallocBuffer=M,r.clearCache=function(){for(var t=0;t<32;++t)c.UINT8[t].length=0,c.UINT16[t].length=0,c.UINT32[t].length=0,c.INT8[t].length=0,c.INT16[t].length=0,c.INT32[t].length=0,c.FLOAT[t].length=0,c.DOUBLE[t].length=0,c.BIGUINT64[t].length=0,c.BIGINT64[t].length=0,c.UINT8C[t].length=0,u[t].length=0,f[t].length=0}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{"bit-twiddle":32,buffer:3,dup:65}],309:[function(t,e,r){"use strict";function n(t){this.roots=new Array(t),this.ranks=new Array(t);for(var e=0;e0&&(a=n.size),n.lineSpacing&&n.lineSpacing>0&&(o=n.lineSpacing),n.styletags&&n.styletags.breaklines&&(s.breaklines=!!n.styletags.breaklines),n.styletags&&n.styletags.bolds&&(s.bolds=!!n.styletags.bolds),n.styletags&&n.styletags.italics&&(s.italics=!!n.styletags.italics),n.styletags&&n.styletags.subscripts&&(s.subscripts=!!n.styletags.subscripts),n.styletags&&n.styletags.superscripts&&(s.superscripts=!!n.styletags.superscripts));return r.font=[n.fontStyle,n.fontVariant,n.fontWeight,a+"px",n.font].filter((function(t){return t})).join(" "),r.textAlign="start",r.textBaseline="alphabetic",r.direction="ltr",h(function(t,e,r,n,a,o){r=r.replace(/\n/g,""),r=!0===o.breaklines?r.replace(/\/g,"\n"):r.replace(/\/g," ");var s="",l=[];for(p=0;p-1?parseInt(t[1+i]):0,l=a>-1?parseInt(r[1+a]):0;s!==l&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,l-s),n=n.replace("?px ",S())),m+=.25*x*(l-s)}if(!0===o.superscripts){var c=t.indexOf("+"),u=r.indexOf("+"),f=c>-1?parseInt(t[1+c]):0,h=u>-1?parseInt(r[1+u]):0;f!==h&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,h-f),n=n.replace("?px ",S())),m-=.25*x*(h-f)}if(!0===o.bolds){var p=t.indexOf("b|")>-1,d=r.indexOf("b|")>-1;!p&&d&&(n=v?n.replace("italic ","italic bold "):"bold "+n),p&&!d&&(n=n.replace("bold ",""))}if(!0===o.italics){var v=t.indexOf("i|")>-1,y=r.indexOf("i|")>-1;!v&&y&&(n="italic "+n),v&&!y&&(n=n.replace("italic ",""))}e.font=n}for(h=0;h",a="",o=i.length,s=a.length,l="+"===e[0]||"-"===e[0],c=0,u=-s;c>-1&&-1!==(c=r.indexOf(i,c))&&-1!==(u=r.indexOf(a,c+o))&&!(u<=c);){for(var f=c;f=u)n[f]=null,r=r.substr(0,f)+" "+r.substr(f+1);else if(null!==n[f]){var h=n[f].indexOf(e[0]);-1===h?n[f]+=e:l&&(n[f]=n[f].substr(0,h+1)+(1+parseInt(n[f][h+1]))+n[f].substr(h+2))}var p=c+o,d=r.substr(p,u-p).indexOf(i);c=-1!==d?d:u+s}return n}function u(t,e){var r=n(t,128);return e?a(r.cells,r.positions,.25):{edges:r.cells,positions:r.positions}}function f(t,e,r,n){var i=u(t,n),a=function(t,e,r){for(var n=e.textAlign||"start",i=e.textBaseline||"alphabetic",a=[1<<30,1<<30],o=[0,0],s=t.length,l=0;l=0?e[a]:i}))},has___:{value:y((function(e){var n=v(e);return n?r in n:t.indexOf(e)>=0}))},set___:{value:y((function(n,i){var a,o=v(n);return o?o[r]=i:(a=t.indexOf(n))>=0?e[a]=i:(a=t.length,e[a]=i,t[a]=n),this}))},delete___:{value:y((function(n){var i,a,o=v(n);return o?r in o&&delete o[r]:!((i=t.indexOf(n))<0)&&(a=t.length-1,t[i]=void 0,e[i]=e[a],t[i]=t[a],t.length=a,e.length=a,!0)}))}})};d.prototype=Object.create(Object.prototype,{get:{value:function(t,e){return this.get___(t,e)},writable:!0,configurable:!0},has:{value:function(t){return this.has___(t)},writable:!0,configurable:!0},set:{value:function(t,e){return this.set___(t,e)},writable:!0,configurable:!0},delete:{value:function(t){return this.delete___(t)},writable:!0,configurable:!0}}),"function"==typeof r?function(){function n(){this instanceof d||x();var e,n=new r,i=void 0,a=!1;return e=t?function(t,e){return n.set(t,e),n.has(t)||(i||(i=new d),i.set(t,e)),this}:function(t,e){if(a)try{n.set(t,e)}catch(r){i||(i=new d),i.set___(t,e)}else n.set(t,e);return this},Object.create(d.prototype,{get___:{value:y((function(t,e){return i?n.has(t)?n.get(t):i.get___(t,e):n.get(t,e)}))},has___:{value:y((function(t){return n.has(t)||!!i&&i.has___(t)}))},set___:{value:y(e)},delete___:{value:y((function(t){var e=!!n.delete(t);return i&&i.delete___(t)||e}))},permitHostObjects___:{value:y((function(t){if(t!==m)throw new Error("bogus call to permitHostObjects___");a=!0}))}})}t&&"undefined"!=typeof Proxy&&(Proxy=void 0),n.prototype=d.prototype,e.exports=n,Object.defineProperty(WeakMap.prototype,"constructor",{value:WeakMap,enumerable:!1,configurable:!0,writable:!0})}():("undefined"!=typeof Proxy&&(Proxy=void 0),e.exports=d)}function m(t){t.permitHostObjects___&&t.permitHostObjects___(m)}function g(t){return!("weakmap:"==t.substr(0,"weakmap:".length)&&"___"===t.substr(t.length-3))}function v(t){if(t!==Object(t))throw new TypeError("Not an object: "+t);var e=t[l];if(e&&e.key===t)return e;if(s(t)){e={key:t};try{return o(t,l,{value:e,writable:!1,enumerable:!1,configurable:!1}),e}catch(t){return}}}function y(t){return t.prototype=null,Object.freeze(t)}function x(){h||"undefined"==typeof console||(h=!0,console.warn("WeakMap should be invoked as new WeakMap(), not WeakMap(). This will be an error in the future."))}}()},{}],314:[function(t,e,r){var n=t("./hidden-store.js");e.exports=function(){var t={};return function(e){if(("object"!=typeof e||null===e)&&"function"!=typeof e)throw new Error("Weakmap-shim: Key must be object");var r=e.valueOf(t);return r&&r.identity===t?r:n(e,t)}}},{"./hidden-store.js":315}],315:[function(t,e,r){e.exports=function(t,e){var r={identity:e},n=t.valueOf;return Object.defineProperty(t,"valueOf",{value:function(t){return t!==e?n.apply(this,arguments):r},writable:!0}),r}},{}],316:[function(t,e,r){var n=t("./create-store.js");e.exports=function(){var t=n();return{get:function(e,r){var n=t(e);return n.hasOwnProperty("value")?n.value:r},set:function(e,r){return t(e).value=r,this},has:function(e){return"value"in t(e)},delete:function(e){return delete t(e).value}}}},{"./create-store.js":314}],317:[function(t,e,r){"use strict";var n,i=function(){return function(t,e,r,n,i,a){var o=t[0],s=r[0],l=[0],c=s;n|=0;var u=0,f=s;for(u=0;u=0!=p>=0&&i.push(l[0]+.5+.5*(h+p)/(h-p)),n+=f,++l[0]}}};e.exports=(n={funcName:{funcName:"zeroCrossings"}.funcName},function(t){var e={};return function(r,n,i){var a=r.dtype,o=r.order,s=[a,o.join()].join(),l=e[s];return l||(e[s]=l=t([a,o])),l(r.shape.slice(0),r.data,r.stride,0|r.offset,n,i)}}(i.bind(void 0,n)))},{}],318:[function(t,e,r){"use strict";e.exports=function(t,e){var r=[];return e=+e||0,n(t.hi(t.shape[0]-1),r,e),r};var n=t("./lib/zc-core")},{"./lib/zc-core":317}]},{},[6])(6)}))}).call(this)}).call(this,"undefined"!=typeof global?global:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}]},{},[27])(27)}));
\ No newline at end of file diff --git a/docs/source/_static/sintel_final_epe_outlier-drop_kitti_sintel.html b/docs/source/_static/sintel_final_epe_outlier-drop_kitti_sintel.html index c409abe..fd9394b 100644 --- a/docs/source/_static/sintel_final_epe_outlier-drop_kitti_sintel.html +++ b/docs/source/_static/sintel_final_epe_outlier-drop_kitti_sintel.html @@ -3,30 +3,30 @@
+"use strict";var n,i="";e.exports=function(t,e){if("string"!=typeof t)throw new TypeError("expected a string");if(1===e)return t;if(2===e)return t+t;var r=t.length*e;if(n!==t||void 0===n)n=t,i="";else if(i.length>=r)return i.substr(0,r);for(;r>i.length&&e>1;)1&e&&(i+=t),e>>=1,t+=t;return i=(i+=t).substr(0,r)}},{}],278:[function(t,e,r){(function(t){(function(){e.exports=t.performance&&t.performance.now?function(){return performance.now()}:Date.now||function(){return+new Date}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}],279:[function(t,e,r){"use strict";e.exports=function(t){for(var e=t.length,r=t[t.length-1],n=e,i=e-2;i>=0;--i){var a=r,o=t[i];(l=o-((r=a+o)-a))&&(t[--n]=r,r=l)}var s=0;for(i=n;i0){if(a<=0)return o;n=i+a}else{if(!(i<0))return o;if(a>=0)return o;n=-(i+a)}var s=33306690738754716e-32*n;return o>=s||o<=-s?o:f(t,e,r)},function(t,e,r,n){var i=t[0]-n[0],a=e[0]-n[0],o=r[0]-n[0],s=t[1]-n[1],l=e[1]-n[1],c=r[1]-n[1],u=t[2]-n[2],f=e[2]-n[2],p=r[2]-n[2],d=a*c,m=o*l,g=o*s,v=i*c,y=i*l,x=a*s,b=u*(d-m)+f*(g-v)+p*(y-x),_=7771561172376103e-31*((Math.abs(d)+Math.abs(m))*Math.abs(u)+(Math.abs(g)+Math.abs(v))*Math.abs(f)+(Math.abs(y)+Math.abs(x))*Math.abs(p));return b>_||-b>_?b:h(t,e,r,n)}];function d(t){var e=p[t.length];return e||(e=p[t.length]=u(t.length)),e.apply(void 0,t)}function m(t,e,r,n,i,a,o){return function(e,r,s,l,c){switch(arguments.length){case 0:case 1:return 0;case 2:return n(e,r);case 3:return i(e,r,s);case 4:return a(e,r,s,l);case 5:return o(e,r,s,l,c)}for(var u=new Array(arguments.length),f=0;f0&&o>0||a<0&&o<0)return!1;var s=n(r,t,e),l=n(i,t,e);if(s>0&&l>0||s<0&&l<0)return!1;if(0===a&&0===o&&0===s&&0===l)return function(t,e,r,n){for(var i=0;i<2;++i){var a=t[i],o=e[i],s=Math.min(a,o),l=Math.max(a,o),c=r[i],u=n[i],f=Math.min(c,u);if(Math.max(c,u)=n?(i=f,(l+=1)=n?(i=f,(l+=1)>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,s=a(t[o],e);s<=0?(0===s&&(i=o),r=o+1):s>0&&(n=o-1)}return i}function u(t,e){for(var r=new Array(t.length),i=0,o=r.length;i=t.length||0!==a(t[g],s)););}return r}function f(t,e){if(e<0)return[];for(var r=[],i=(1<>>u&1&&c.push(i[u]);e.push(c)}return s(e)},r.skeleton=f,r.boundary=function(t){for(var e=[],r=0,n=t.length;r>1:(t>>1)-1}function x(t){for(var e=v(t);;){var r=e,n=2*t+1,i=2*(t+1),a=t;if(n0;){var r=y(t);if(r>=0)if(e0){var t=k[0];return g(0,M-1),M-=1,x(0),t}return-1}function w(t,e){var r=k[t];return c[r]===e?t:(c[r]=-1/0,b(t),_(),c[r]=e,b((M+=1)-1))}function T(t){if(!u[t]){u[t]=!0;var e=s[t],r=l[t];s[r]>=0&&(s[r]=e),l[e]>=0&&(l[e]=r),A[e]>=0&&w(A[e],m(e)),A[r]>=0&&w(A[r],m(r))}}var k=[],A=new Array(a);for(f=0;f>1;f>=0;--f)x(f);for(;;){var S=_();if(S<0||c[S]>r)break;T(S)}var E=[];for(f=0;f=0&&r>=0&&e!==r){var n=A[e],i=A[r];n!==i&&C.push([n,i])}})),i.unique(i.normalize(C)),{positions:E,edges:C}};var n=t("robust-orientation"),i=t("simplicial-complex")},{"robust-orientation":284,"simplicial-complex":295}],298:[function(t,e,r){"use strict";e.exports=function(t,e){var r,a,o,s;if(e[0][0]e[1][0]))return i(e,t);r=e[1],a=e[0]}if(t[0][0]t[1][0]))return-i(t,e);o=t[1],s=t[0]}var l=n(r,a,s),c=n(r,a,o);if(l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;if(l=n(s,o,a),c=n(s,o,r),l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;return a[0]-s[0]};var n=t("robust-orientation");function i(t,e){var r,i,a,o;if(e[0][0]e[1][0])){var s=Math.min(t[0][1],t[1][1]),l=Math.max(t[0][1],t[1][1]),c=Math.min(e[0][1],e[1][1]),u=Math.max(e[0][1],e[1][1]);return lu?s-u:l-u}r=e[1],i=e[0]}t[0][1]0)if(e[0]!==o[1][0])r=t,t=t.right;else{if(l=c(t.right,e))return l;t=t.left}else{if(e[0]!==o[1][0])return t;var l;if(l=c(t.right,e))return l;t=t.left}}return r}function u(t,e,r,n){this.y=t,this.index=e,this.start=r,this.closed=n}function f(t,e,r,n){this.x=t,this.segment=e,this.create=r,this.index=n}s.prototype.castUp=function(t){var e=n.le(this.coordinates,t[0]);if(e<0)return-1;this.slabs[e];var r=c(this.slabs[e],t),i=-1;if(r&&(i=r.value),this.coordinates[e]===t[0]){var s=null;if(r&&(s=r.key),e>0){var u=c(this.slabs[e-1],t);u&&(s?o(u.key,s)>0&&(s=u.key,i=u.value):(i=u.value,s=u.key))}var f=this.horizontal[e];if(f.length>0){var h=n.ge(f,t[1],l);if(h=f.length)return i;p=f[h]}}if(p.start)if(s){var d=a(s[0],s[1],[t[0],p.y]);s[0][0]>s[1][0]&&(d=-d),d>0&&(i=p.index)}else i=p.index;else p.y!==t[1]&&(i=p.index)}}}return i}},{"./lib/order-segments":298,"binary-search-bounds":31,"functional-red-black-tree":69,"robust-orientation":284}],300:[function(t,e,r){"use strict";var n=t("robust-dot-product"),i=t("robust-sum");function a(t,e){var r=i(n(t,e),[e[e.length-1]]);return r[r.length-1]}function o(t,e,r,n){var i=-e/(n-e);i<0?i=0:i>1&&(i=1);for(var a=1-i,o=t.length,s=new Array(o),l=0;l0||i>0&&u<0){var f=o(s,u,l,i);r.push(f),n.push(f.slice())}u<0?n.push(l.slice()):u>0?r.push(l.slice()):(r.push(l.slice()),n.push(l.slice())),i=u}return{positive:r,negative:n}},e.exports.positive=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c>=0&&r.push(s.slice()),n=c}return r},e.exports.negative=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c<=0&&r.push(s.slice()),n=c}return r}},{"robust-dot-product":281,"robust-sum":289}],301:[function(t,e,r){!function(){"use strict";var t={not_string:/[^s]/,not_bool:/[^t]/,not_type:/[^T]/,not_primitive:/[^v]/,number:/[diefg]/,numeric_arg:/[bcdiefguxX]/,json:/[j]/,not_json:/[^j]/,text:/^[^\x25]+/,modulo:/^\x25{2}/,placeholder:/^\x25(?:([1-9]\d*)\$|\(([^)]+)\))?(\+)?(0|'[^$])?(-)?(\d+)?(?:\.(\d+))?([b-gijostTuvxX])/,key:/^([a-z_][a-z_\d]*)/i,key_access:/^\.([a-z_][a-z_\d]*)/i,index_access:/^\[(\d+)\]/,sign:/^[+-]/};function e(t){return i(o(t),arguments)}function n(t,r){return e.apply(null,[t].concat(r||[]))}function i(r,n){var i,a,o,s,l,c,u,f,h,p=1,d=r.length,m="";for(a=0;a=0),s.type){case"b":i=parseInt(i,10).toString(2);break;case"c":i=String.fromCharCode(parseInt(i,10));break;case"d":case"i":i=parseInt(i,10);break;case"j":i=JSON.stringify(i,null,s.width?parseInt(s.width):0);break;case"e":i=s.precision?parseFloat(i).toExponential(s.precision):parseFloat(i).toExponential();break;case"f":i=s.precision?parseFloat(i).toFixed(s.precision):parseFloat(i);break;case"g":i=s.precision?String(Number(i.toPrecision(s.precision))):parseFloat(i);break;case"o":i=(parseInt(i,10)>>>0).toString(8);break;case"s":i=String(i),i=s.precision?i.substring(0,s.precision):i;break;case"t":i=String(!!i),i=s.precision?i.substring(0,s.precision):i;break;case"T":i=Object.prototype.toString.call(i).slice(8,-1).toLowerCase(),i=s.precision?i.substring(0,s.precision):i;break;case"u":i=parseInt(i,10)>>>0;break;case"v":i=i.valueOf(),i=s.precision?i.substring(0,s.precision):i;break;case"x":i=(parseInt(i,10)>>>0).toString(16);break;case"X":i=(parseInt(i,10)>>>0).toString(16).toUpperCase()}t.json.test(s.type)?m+=i:(!t.number.test(s.type)||f&&!s.sign?h="":(h=f?"+":"-",i=i.toString().replace(t.sign,"")),c=s.pad_char?"0"===s.pad_char?"0":s.pad_char.charAt(1):" ",u=s.width-(h+i).length,l=s.width&&u>0?c.repeat(u):"",m+=s.align?h+i+l:"0"===c?h+l+i:l+h+i)}return m}var a=Object.create(null);function o(e){if(a[e])return a[e];for(var r,n=e,i=[],o=0;n;){if(null!==(r=t.text.exec(n)))i.push(r[0]);else if(null!==(r=t.modulo.exec(n)))i.push("%");else{if(null===(r=t.placeholder.exec(n)))throw new SyntaxError("[sprintf] unexpected placeholder");if(r[2]){o|=1;var s=[],l=r[2],c=[];if(null===(c=t.key.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");for(s.push(c[1]);""!==(l=l.substring(c[0].length));)if(null!==(c=t.key_access.exec(l)))s.push(c[1]);else{if(null===(c=t.index_access.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");s.push(c[1])}r[2]=s}else o|=2;if(3===o)throw new Error("[sprintf] mixing positional and named placeholders is not (yet) supported");i.push({placeholder:r[0],param_no:r[1],keys:r[2],sign:r[3],pad_char:r[4],align:r[5],width:r[6],precision:r[7],type:r[8]})}n=n.substring(r[0].length)}return a[e]=i}void 0!==r&&(r.sprintf=e,r.vsprintf=n),"undefined"!=typeof window&&(window.sprintf=e,window.vsprintf=n)}()},{}],302:[function(t,e,r){"use strict";e.exports=function(t,e){if(t.dimension<=0)return{positions:[],cells:[]};if(1===t.dimension)return function(t,e){for(var r=i(t,e),n=r.length,a=new Array(n),o=new Array(n),s=0;sn|0},vertex:function(t,e,r,n,i,a,o,s,l,c,u,f,h){var p=(o<<0)+(s<<1)+(l<<2)+(c<<3)|0;if(0!==p&&15!==p)switch(p){case 0:u.push([t-.5,e-.5]);break;case 1:u.push([t-.25-.25*(n+r-2*h)/(r-n),e-.25-.25*(i+r-2*h)/(r-i)]);break;case 2:u.push([t-.75-.25*(-n-r+2*h)/(n-r),e-.25-.25*(a+n-2*h)/(n-a)]);break;case 3:u.push([t-.5,e-.5-.5*(i+r+a+n-4*h)/(r-i+n-a)]);break;case 4:u.push([t-.25-.25*(a+i-2*h)/(i-a),e-.75-.25*(-i-r+2*h)/(i-r)]);break;case 5:u.push([t-.5-.5*(n+r+a+i-4*h)/(r-n+i-a),e-.5]);break;case 6:u.push([t-.5-.25*(-n-r+a+i)/(n-r+i-a),e-.5-.25*(-i-r+a+n)/(i-r+n-a)]);break;case 7:u.push([t-.75-.25*(a+i-2*h)/(i-a),e-.75-.25*(a+n-2*h)/(n-a)]);break;case 8:u.push([t-.75-.25*(-a-i+2*h)/(a-i),e-.75-.25*(-a-n+2*h)/(a-n)]);break;case 9:u.push([t-.5-.25*(n+r+-a-i)/(r-n+a-i),e-.5-.25*(i+r+-a-n)/(r-i+a-n)]);break;case 10:u.push([t-.5-.5*(-n-r-a-i+4*h)/(n-r+a-i),e-.5]);break;case 11:u.push([t-.25-.25*(-a-i+2*h)/(a-i),e-.75-.25*(i+r-2*h)/(r-i)]);break;case 12:u.push([t-.5,e-.5-.5*(-i-r-a-n+4*h)/(i-r+a-n)]);break;case 13:u.push([t-.75-.25*(n+r-2*h)/(r-n),e-.25-.25*(-a-n+2*h)/(a-n)]);break;case 14:u.push([t-.25-.25*(-n-r+2*h)/(n-r),e-.25-.25*(-i-r+2*h)/(i-r)]);break;case 15:u.push([t-.5,e-.5])}},cell:function(t,e,r,n,i,a,o,s,l){i?s.push([t,e]):s.push([e,t])}});return function(t,e){var r=[],i=[];return n(t,r,i,e),{positions:r,cells:i}}}};var o={}},{"ndarray-extract-contour":251,"zero-crossings":318}],303:[function(t,e,r){(function(r){(function(){"use strict";e.exports=function t(e,r,i){i=i||{};var o=a[e];o||(o=a[e]={" ":{data:new Float32Array(0),shape:.2}});var s=o[r];if(!s)if(r.length<=1||!/\d/.test(r))s=o[r]=function(t){for(var e=t.cells,r=t.positions,n=new Float32Array(6*e.length),i=0,a=0,o=0;o0&&(f+=.02);var p=new Float32Array(u),d=0,m=-.5*f;for(h=0;hMath.max(r,n)?i[2]=1:r>Math.max(e,n)?i[0]=1:i[1]=1;for(var a=0,o=0,l=0;l<3;++l)a+=t[l]*t[l],o+=i[l]*t[l];for(l=0;l<3;++l)i[l]-=o/a*t[l];return s(i,i),i}function h(t,e,r,i,a,o,s,l){this.center=n(r),this.up=n(i),this.right=n(a),this.radius=n([o]),this.angle=n([s,l]),this.angle.bounds=[[-1/0,-Math.PI/2],[1/0,Math.PI/2]],this.setDistanceLimits(t,e),this.computedCenter=this.center.curve(0),this.computedUp=this.up.curve(0),this.computedRight=this.right.curve(0),this.computedRadius=this.radius.curve(0),this.computedAngle=this.angle.curve(0),this.computedToward=[0,0,0],this.computedEye=[0,0,0],this.computedMatrix=new Array(16);for(var c=0;c<16;++c)this.computedMatrix[c]=.5;this.recalcMatrix(0)}var p=h.prototype;p.setDistanceLimits=function(t,e){t=t>0?Math.log(t):-1/0,e=e>0?Math.log(e):1/0,e=Math.max(e,t),this.radius.bounds[0][0]=t,this.radius.bounds[1][0]=e},p.getDistanceLimits=function(t){var e=this.radius.bounds[0];return t?(t[0]=Math.exp(e[0][0]),t[1]=Math.exp(e[1][0]),t):[Math.exp(e[0][0]),Math.exp(e[1][0])]},p.recalcMatrix=function(t){this.center.curve(t),this.up.curve(t),this.right.curve(t),this.radius.curve(t),this.angle.curve(t);for(var e=this.computedUp,r=this.computedRight,n=0,i=0,a=0;a<3;++a)i+=e[a]*r[a],n+=e[a]*e[a];var l=Math.sqrt(n),u=0;for(a=0;a<3;++a)r[a]-=e[a]*i/n,u+=r[a]*r[a],e[a]/=l;var f=Math.sqrt(u);for(a=0;a<3;++a)r[a]/=f;var h=this.computedToward;o(h,e,r),s(h,h);var p=Math.exp(this.computedRadius[0]),d=this.computedAngle[0],m=this.computedAngle[1],g=Math.cos(d),v=Math.sin(d),y=Math.cos(m),x=Math.sin(m),b=this.computedCenter,_=g*y,w=v*y,T=x,k=-g*x,A=-v*x,M=y,S=this.computedEye,E=this.computedMatrix;for(a=0;a<3;++a){var L=_*r[a]+w*h[a]+T*e[a];E[4*a+1]=k*r[a]+A*h[a]+M*e[a],E[4*a+2]=L,E[4*a+3]=0}var C=E[1],P=E[5],I=E[9],O=E[2],z=E[6],D=E[10],R=P*D-I*z,F=I*O-C*D,B=C*z-P*O,N=c(R,F,B);R/=N,F/=N,B/=N,E[0]=R,E[4]=F,E[8]=B;for(a=0;a<3;++a)S[a]=b[a]+E[2+4*a]*p;for(a=0;a<3;++a){u=0;for(var j=0;j<3;++j)u+=E[a+4*j]*S[j];E[12+a]=-u}E[15]=1},p.getMatrix=function(t,e){this.recalcMatrix(t);var r=this.computedMatrix;if(e){for(var n=0;n<16;++n)e[n]=r[n];return e}return r};var d=[0,0,0];p.rotate=function(t,e,r,n){if(this.angle.move(t,e,r),n){this.recalcMatrix(t);var i=this.computedMatrix;d[0]=i[2],d[1]=i[6],d[2]=i[10];for(var o=this.computedUp,s=this.computedRight,l=this.computedToward,c=0;c<3;++c)i[4*c]=o[c],i[4*c+1]=s[c],i[4*c+2]=l[c];a(i,i,n,d);for(c=0;c<3;++c)o[c]=i[4*c],s[c]=i[4*c+1];this.up.set(t,o[0],o[1],o[2]),this.right.set(t,s[0],s[1],s[2])}},p.pan=function(t,e,r,n){e=e||0,r=r||0,n=n||0,this.recalcMatrix(t);var i=this.computedMatrix,a=(Math.exp(this.computedRadius[0]),i[1]),o=i[5],s=i[9],l=c(a,o,s);a/=l,o/=l,s/=l;var u=i[0],f=i[4],h=i[8],p=u*a+f*o+h*s,d=c(u-=a*p,f-=o*p,h-=s*p),m=(u/=d)*e+a*r,g=(f/=d)*e+o*r,v=(h/=d)*e+s*r;this.center.move(t,m,g,v);var y=Math.exp(this.computedRadius[0]);y=Math.max(1e-4,y+n),this.radius.set(t,Math.log(y))},p.translate=function(t,e,r,n){this.center.move(t,e||0,r||0,n||0)},p.setMatrix=function(t,e,r,n){var a=1;"number"==typeof r&&(a=0|r),(a<0||a>3)&&(a=1);var o=(a+2)%3;e||(this.recalcMatrix(t),e=this.computedMatrix);var s=e[a],l=e[a+4],f=e[a+8];if(n){var h=Math.abs(s),p=Math.abs(l),d=Math.abs(f),m=Math.max(h,p,d);h===m?(s=s<0?-1:1,l=f=0):d===m?(f=f<0?-1:1,s=l=0):(l=l<0?-1:1,s=f=0)}else{var g=c(s,l,f);s/=g,l/=g,f/=g}var v,y,x=e[o],b=e[o+4],_=e[o+8],w=x*s+b*l+_*f,T=c(x-=s*w,b-=l*w,_-=f*w),k=l*(_/=T)-f*(b/=T),A=f*(x/=T)-s*_,M=s*b-l*x,S=c(k,A,M);if(k/=S,A/=S,M/=S,this.center.jump(t,q,G,Y),this.radius.idle(t),this.up.jump(t,s,l,f),this.right.jump(t,x,b,_),2===a){var E=e[1],L=e[5],C=e[9],P=E*x+L*b+C*_,I=E*k+L*A+C*M;v=R<0?-Math.PI/2:Math.PI/2,y=Math.atan2(I,P)}else{var O=e[2],z=e[6],D=e[10],R=O*s+z*l+D*f,F=O*x+z*b+D*_,B=O*k+z*A+D*M;v=Math.asin(u(R)),y=Math.atan2(B,F)}this.angle.jump(t,y,v),this.recalcMatrix(t);var N=e[2],j=e[6],U=e[10],V=this.computedMatrix;i(V,e);var H=V[15],q=V[12]/H,G=V[13]/H,Y=V[14]/H,W=Math.exp(this.computedRadius[0]);this.center.jump(t,q-N*W,G-j*W,Y-U*W)},p.lastT=function(){return Math.max(this.center.lastT(),this.up.lastT(),this.right.lastT(),this.radius.lastT(),this.angle.lastT())},p.idle=function(t){this.center.idle(t),this.up.idle(t),this.right.idle(t),this.radius.idle(t),this.angle.idle(t)},p.flush=function(t){this.center.flush(t),this.up.flush(t),this.right.flush(t),this.radius.flush(t),this.angle.flush(t)},p.setDistance=function(t,e){e>0&&this.radius.set(t,Math.log(e))},p.lookAt=function(t,e,r,n){this.recalcMatrix(t),e=e||this.computedEye,r=r||this.computedCenter;var i=(n=n||this.computedUp)[0],a=n[1],o=n[2],s=c(i,a,o);if(!(s<1e-6)){i/=s,a/=s,o/=s;var l=e[0]-r[0],f=e[1]-r[1],h=e[2]-r[2],p=c(l,f,h);if(!(p<1e-6)){l/=p,f/=p,h/=p;var d=this.computedRight,m=d[0],g=d[1],v=d[2],y=i*m+a*g+o*v,x=c(m-=y*i,g-=y*a,v-=y*o);if(!(x<.01&&(x=c(m=a*h-o*f,g=o*l-i*h,v=i*f-a*l))<1e-6)){m/=x,g/=x,v/=x,this.up.set(t,i,a,o),this.right.set(t,m,g,v),this.center.set(t,r[0],r[1],r[2]),this.radius.set(t,Math.log(p));var b=a*v-o*g,_=o*m-i*v,w=i*g-a*m,T=c(b,_,w),k=i*l+a*f+o*h,A=m*l+g*f+v*h,M=(b/=T)*l+(_/=T)*f+(w/=T)*h,S=Math.asin(u(k)),E=Math.atan2(M,A),L=this.angle._state,C=L[L.length-1],P=L[L.length-2];C%=2*Math.PI;var I=Math.abs(C+2*Math.PI-E),O=Math.abs(C-E),z=Math.abs(C-2*Math.PI-E);I0?r.pop():new ArrayBuffer(t)}function d(t){return new Uint8Array(p(t),0,t)}function m(t){return new Uint16Array(p(2*t),0,t)}function g(t){return new Uint32Array(p(4*t),0,t)}function v(t){return new Int8Array(p(t),0,t)}function y(t){return new Int16Array(p(2*t),0,t)}function x(t){return new Int32Array(p(4*t),0,t)}function b(t){return new Float32Array(p(4*t),0,t)}function _(t){return new Float64Array(p(8*t),0,t)}function w(t){return o?new Uint8ClampedArray(p(t),0,t):d(t)}function T(t){return s?new BigUint64Array(p(8*t),0,t):null}function k(t){return l?new BigInt64Array(p(8*t),0,t):null}function A(t){return new DataView(p(t),0,t)}function M(t){t=n.nextPow2(t);var e=n.log2(t),r=f[e];return r.length>0?r.pop():new a(t)}r.free=function(t){if(a.isBuffer(t))f[n.log2(t.length)].push(t);else{if("[object ArrayBuffer]"!==Object.prototype.toString.call(t)&&(t=t.buffer),!t)return;var e=t.length||t.byteLength,r=0|n.log2(e);u[r].push(t)}},r.freeUint8=r.freeUint16=r.freeUint32=r.freeBigUint64=r.freeInt8=r.freeInt16=r.freeInt32=r.freeBigInt64=r.freeFloat32=r.freeFloat=r.freeFloat64=r.freeDouble=r.freeUint8Clamped=r.freeDataView=function(t){h(t.buffer)},r.freeArrayBuffer=h,r.freeBuffer=function(t){f[n.log2(t.length)].push(t)},r.malloc=function(t,e){if(void 0===e||"arraybuffer"===e)return p(t);switch(e){case"uint8":return d(t);case"uint16":return m(t);case"uint32":return g(t);case"int8":return v(t);case"int16":return y(t);case"int32":return x(t);case"float":case"float32":return b(t);case"double":case"float64":return _(t);case"uint8_clamped":return w(t);case"bigint64":return k(t);case"biguint64":return T(t);case"buffer":return M(t);case"data":case"dataview":return A(t);default:return null}return null},r.mallocArrayBuffer=p,r.mallocUint8=d,r.mallocUint16=m,r.mallocUint32=g,r.mallocInt8=v,r.mallocInt16=y,r.mallocInt32=x,r.mallocFloat32=r.mallocFloat=b,r.mallocFloat64=r.mallocDouble=_,r.mallocUint8Clamped=w,r.mallocBigUint64=T,r.mallocBigInt64=k,r.mallocDataView=A,r.mallocBuffer=M,r.clearCache=function(){for(var t=0;t<32;++t)c.UINT8[t].length=0,c.UINT16[t].length=0,c.UINT32[t].length=0,c.INT8[t].length=0,c.INT16[t].length=0,c.INT32[t].length=0,c.FLOAT[t].length=0,c.DOUBLE[t].length=0,c.BIGUINT64[t].length=0,c.BIGINT64[t].length=0,c.UINT8C[t].length=0,u[t].length=0,f[t].length=0}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{"bit-twiddle":32,buffer:3,dup:65}],309:[function(t,e,r){"use strict";function n(t){this.roots=new Array(t),this.ranks=new Array(t);for(var e=0;e0&&(a=n.size),n.lineSpacing&&n.lineSpacing>0&&(o=n.lineSpacing),n.styletags&&n.styletags.breaklines&&(s.breaklines=!!n.styletags.breaklines),n.styletags&&n.styletags.bolds&&(s.bolds=!!n.styletags.bolds),n.styletags&&n.styletags.italics&&(s.italics=!!n.styletags.italics),n.styletags&&n.styletags.subscripts&&(s.subscripts=!!n.styletags.subscripts),n.styletags&&n.styletags.superscripts&&(s.superscripts=!!n.styletags.superscripts));return r.font=[n.fontStyle,n.fontVariant,n.fontWeight,a+"px",n.font].filter((function(t){return t})).join(" "),r.textAlign="start",r.textBaseline="alphabetic",r.direction="ltr",h(function(t,e,r,n,a,o){r=r.replace(/\n/g,""),r=!0===o.breaklines?r.replace(/\/g,"\n"):r.replace(/\/g," ");var s="",l=[];for(p=0;p-1?parseInt(t[1+i]):0,l=a>-1?parseInt(r[1+a]):0;s!==l&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,l-s),n=n.replace("?px ",S())),m+=.25*x*(l-s)}if(!0===o.superscripts){var c=t.indexOf("+"),u=r.indexOf("+"),f=c>-1?parseInt(t[1+c]):0,h=u>-1?parseInt(r[1+u]):0;f!==h&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,h-f),n=n.replace("?px ",S())),m-=.25*x*(h-f)}if(!0===o.bolds){var p=t.indexOf("b|")>-1,d=r.indexOf("b|")>-1;!p&&d&&(n=v?n.replace("italic ","italic bold "):"bold "+n),p&&!d&&(n=n.replace("bold ",""))}if(!0===o.italics){var v=t.indexOf("i|")>-1,y=r.indexOf("i|")>-1;!v&&y&&(n="italic "+n),v&&!y&&(n=n.replace("italic ",""))}e.font=n}for(h=0;h",a="",o=i.length,s=a.length,l="+"===e[0]||"-"===e[0],c=0,u=-s;c>-1&&-1!==(c=r.indexOf(i,c))&&-1!==(u=r.indexOf(a,c+o))&&!(u<=c);){for(var f=c;f=u)n[f]=null,r=r.substr(0,f)+" "+r.substr(f+1);else if(null!==n[f]){var h=n[f].indexOf(e[0]);-1===h?n[f]+=e:l&&(n[f]=n[f].substr(0,h+1)+(1+parseInt(n[f][h+1]))+n[f].substr(h+2))}var p=c+o,d=r.substr(p,u-p).indexOf(i);c=-1!==d?d:u+s}return n}function u(t,e){var r=n(t,128);return e?a(r.cells,r.positions,.25):{edges:r.cells,positions:r.positions}}function f(t,e,r,n){var i=u(t,n),a=function(t,e,r){for(var n=e.textAlign||"start",i=e.textBaseline||"alphabetic",a=[1<<30,1<<30],o=[0,0],s=t.length,l=0;l=0?e[a]:i}))},has___:{value:y((function(e){var n=v(e);return n?r in n:t.indexOf(e)>=0}))},set___:{value:y((function(n,i){var a,o=v(n);return o?o[r]=i:(a=t.indexOf(n))>=0?e[a]=i:(a=t.length,e[a]=i,t[a]=n),this}))},delete___:{value:y((function(n){var i,a,o=v(n);return o?r in o&&delete o[r]:!((i=t.indexOf(n))<0)&&(a=t.length-1,t[i]=void 0,e[i]=e[a],t[i]=t[a],t.length=a,e.length=a,!0)}))}})};d.prototype=Object.create(Object.prototype,{get:{value:function(t,e){return this.get___(t,e)},writable:!0,configurable:!0},has:{value:function(t){return this.has___(t)},writable:!0,configurable:!0},set:{value:function(t,e){return this.set___(t,e)},writable:!0,configurable:!0},delete:{value:function(t){return this.delete___(t)},writable:!0,configurable:!0}}),"function"==typeof r?function(){function n(){this instanceof d||x();var e,n=new r,i=void 0,a=!1;return e=t?function(t,e){return n.set(t,e),n.has(t)||(i||(i=new d),i.set(t,e)),this}:function(t,e){if(a)try{n.set(t,e)}catch(r){i||(i=new d),i.set___(t,e)}else n.set(t,e);return this},Object.create(d.prototype,{get___:{value:y((function(t,e){return i?n.has(t)?n.get(t):i.get___(t,e):n.get(t,e)}))},has___:{value:y((function(t){return n.has(t)||!!i&&i.has___(t)}))},set___:{value:y(e)},delete___:{value:y((function(t){var e=!!n.delete(t);return i&&i.delete___(t)||e}))},permitHostObjects___:{value:y((function(t){if(t!==m)throw new Error("bogus call to permitHostObjects___");a=!0}))}})}t&&"undefined"!=typeof Proxy&&(Proxy=void 0),n.prototype=d.prototype,e.exports=n,Object.defineProperty(WeakMap.prototype,"constructor",{value:WeakMap,enumerable:!1,configurable:!0,writable:!0})}():("undefined"!=typeof Proxy&&(Proxy=void 0),e.exports=d)}function m(t){t.permitHostObjects___&&t.permitHostObjects___(m)}function g(t){return!("weakmap:"==t.substr(0,"weakmap:".length)&&"___"===t.substr(t.length-3))}function v(t){if(t!==Object(t))throw new TypeError("Not an object: "+t);var e=t[l];if(e&&e.key===t)return e;if(s(t)){e={key:t};try{return o(t,l,{value:e,writable:!1,enumerable:!1,configurable:!1}),e}catch(t){return}}}function y(t){return t.prototype=null,Object.freeze(t)}function x(){h||"undefined"==typeof console||(h=!0,console.warn("WeakMap should be invoked as new WeakMap(), not WeakMap(). This will be an error in the future."))}}()},{}],314:[function(t,e,r){var n=t("./hidden-store.js");e.exports=function(){var t={};return function(e){if(("object"!=typeof e||null===e)&&"function"!=typeof e)throw new Error("Weakmap-shim: Key must be object");var r=e.valueOf(t);return r&&r.identity===t?r:n(e,t)}}},{"./hidden-store.js":315}],315:[function(t,e,r){e.exports=function(t,e){var r={identity:e},n=t.valueOf;return Object.defineProperty(t,"valueOf",{value:function(t){return t!==e?n.apply(this,arguments):r},writable:!0}),r}},{}],316:[function(t,e,r){var n=t("./create-store.js");e.exports=function(){var t=n();return{get:function(e,r){var n=t(e);return n.hasOwnProperty("value")?n.value:r},set:function(e,r){return t(e).value=r,this},has:function(e){return"value"in t(e)},delete:function(e){return delete t(e).value}}}},{"./create-store.js":314}],317:[function(t,e,r){"use strict";var n,i=function(){return function(t,e,r,n,i,a){var o=t[0],s=r[0],l=[0],c=s;n|=0;var u=0,f=s;for(u=0;u=0!=p>=0&&i.push(l[0]+.5+.5*(h+p)/(h-p)),n+=f,++l[0]}}};e.exports=(n={funcName:{funcName:"zeroCrossings"}.funcName},function(t){var e={};return function(r,n,i){var a=r.dtype,o=r.order,s=[a,o.join()].join(),l=e[s];return l||(e[s]=l=t([a,o])),l(r.shape.slice(0),r.data,r.stride,0|r.offset,n,i)}}(i.bind(void 0,n)))},{}],318:[function(t,e,r){"use strict";e.exports=function(t,e){var r=[];return e=+e||0,n(t.hi(t.shape[0]-1),r,e),r};var n=t("./lib/zc-core")},{"./lib/zc-core":317}]},{},[6])(6)}))}).call(this)}).call(this,"undefined"!=typeof global?global:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}]},{},[27])(27)}));
\ No newline at end of file diff --git a/docs/source/_static/speed_plot-all.html b/docs/source/_static/speed_plot-all.html index 92a62e3..900e24f 100644 --- a/docs/source/_static/speed_plot-all.html +++ b/docs/source/_static/speed_plot-all.html @@ -3,30 +3,30 @@
+"use strict";var n,i="";e.exports=function(t,e){if("string"!=typeof t)throw new TypeError("expected a string");if(1===e)return t;if(2===e)return t+t;var r=t.length*e;if(n!==t||void 0===n)n=t,i="";else if(i.length>=r)return i.substr(0,r);for(;r>i.length&&e>1;)1&e&&(i+=t),e>>=1,t+=t;return i=(i+=t).substr(0,r)}},{}],278:[function(t,e,r){(function(t){(function(){e.exports=t.performance&&t.performance.now?function(){return performance.now()}:Date.now||function(){return+new Date}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}],279:[function(t,e,r){"use strict";e.exports=function(t){for(var e=t.length,r=t[t.length-1],n=e,i=e-2;i>=0;--i){var a=r,o=t[i];(l=o-((r=a+o)-a))&&(t[--n]=r,r=l)}var s=0;for(i=n;i0){if(a<=0)return o;n=i+a}else{if(!(i<0))return o;if(a>=0)return o;n=-(i+a)}var s=33306690738754716e-32*n;return o>=s||o<=-s?o:f(t,e,r)},function(t,e,r,n){var i=t[0]-n[0],a=e[0]-n[0],o=r[0]-n[0],s=t[1]-n[1],l=e[1]-n[1],c=r[1]-n[1],u=t[2]-n[2],f=e[2]-n[2],p=r[2]-n[2],d=a*c,m=o*l,g=o*s,v=i*c,y=i*l,x=a*s,b=u*(d-m)+f*(g-v)+p*(y-x),_=7771561172376103e-31*((Math.abs(d)+Math.abs(m))*Math.abs(u)+(Math.abs(g)+Math.abs(v))*Math.abs(f)+(Math.abs(y)+Math.abs(x))*Math.abs(p));return b>_||-b>_?b:h(t,e,r,n)}];function d(t){var e=p[t.length];return e||(e=p[t.length]=u(t.length)),e.apply(void 0,t)}function m(t,e,r,n,i,a,o){return function(e,r,s,l,c){switch(arguments.length){case 0:case 1:return 0;case 2:return n(e,r);case 3:return i(e,r,s);case 4:return a(e,r,s,l);case 5:return o(e,r,s,l,c)}for(var u=new Array(arguments.length),f=0;f0&&o>0||a<0&&o<0)return!1;var s=n(r,t,e),l=n(i,t,e);if(s>0&&l>0||s<0&&l<0)return!1;if(0===a&&0===o&&0===s&&0===l)return function(t,e,r,n){for(var i=0;i<2;++i){var a=t[i],o=e[i],s=Math.min(a,o),l=Math.max(a,o),c=r[i],u=n[i],f=Math.min(c,u);if(Math.max(c,u)=n?(i=f,(l+=1)=n?(i=f,(l+=1)>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,c=e[2*l+1];if(c===a)return l;a>1,s=a(t[o],e);s<=0?(0===s&&(i=o),r=o+1):s>0&&(n=o-1)}return i}function u(t,e){for(var r=new Array(t.length),i=0,o=r.length;i=t.length||0!==a(t[g],s)););}return r}function f(t,e){if(e<0)return[];for(var r=[],i=(1<>>u&1&&c.push(i[u]);e.push(c)}return s(e)},r.skeleton=f,r.boundary=function(t){for(var e=[],r=0,n=t.length;r>1:(t>>1)-1}function x(t){for(var e=v(t);;){var r=e,n=2*t+1,i=2*(t+1),a=t;if(n0;){var r=y(t);if(r>=0)if(e0){var t=k[0];return g(0,M-1),M-=1,x(0),t}return-1}function w(t,e){var r=k[t];return c[r]===e?t:(c[r]=-1/0,b(t),_(),c[r]=e,b((M+=1)-1))}function T(t){if(!u[t]){u[t]=!0;var e=s[t],r=l[t];s[r]>=0&&(s[r]=e),l[e]>=0&&(l[e]=r),A[e]>=0&&w(A[e],m(e)),A[r]>=0&&w(A[r],m(r))}}var k=[],A=new Array(a);for(f=0;f>1;f>=0;--f)x(f);for(;;){var S=_();if(S<0||c[S]>r)break;T(S)}var E=[];for(f=0;f=0&&r>=0&&e!==r){var n=A[e],i=A[r];n!==i&&C.push([n,i])}})),i.unique(i.normalize(C)),{positions:E,edges:C}};var n=t("robust-orientation"),i=t("simplicial-complex")},{"robust-orientation":284,"simplicial-complex":295}],298:[function(t,e,r){"use strict";e.exports=function(t,e){var r,a,o,s;if(e[0][0]e[1][0]))return i(e,t);r=e[1],a=e[0]}if(t[0][0]t[1][0]))return-i(t,e);o=t[1],s=t[0]}var l=n(r,a,s),c=n(r,a,o);if(l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;if(l=n(s,o,a),c=n(s,o,r),l<0){if(c<=0)return l}else if(l>0){if(c>=0)return l}else if(c)return c;return a[0]-s[0]};var n=t("robust-orientation");function i(t,e){var r,i,a,o;if(e[0][0]e[1][0])){var s=Math.min(t[0][1],t[1][1]),l=Math.max(t[0][1],t[1][1]),c=Math.min(e[0][1],e[1][1]),u=Math.max(e[0][1],e[1][1]);return lu?s-u:l-u}r=e[1],i=e[0]}t[0][1]0)if(e[0]!==o[1][0])r=t,t=t.right;else{if(l=c(t.right,e))return l;t=t.left}else{if(e[0]!==o[1][0])return t;var l;if(l=c(t.right,e))return l;t=t.left}}return r}function u(t,e,r,n){this.y=t,this.index=e,this.start=r,this.closed=n}function f(t,e,r,n){this.x=t,this.segment=e,this.create=r,this.index=n}s.prototype.castUp=function(t){var e=n.le(this.coordinates,t[0]);if(e<0)return-1;this.slabs[e];var r=c(this.slabs[e],t),i=-1;if(r&&(i=r.value),this.coordinates[e]===t[0]){var s=null;if(r&&(s=r.key),e>0){var u=c(this.slabs[e-1],t);u&&(s?o(u.key,s)>0&&(s=u.key,i=u.value):(i=u.value,s=u.key))}var f=this.horizontal[e];if(f.length>0){var h=n.ge(f,t[1],l);if(h=f.length)return i;p=f[h]}}if(p.start)if(s){var d=a(s[0],s[1],[t[0],p.y]);s[0][0]>s[1][0]&&(d=-d),d>0&&(i=p.index)}else i=p.index;else p.y!==t[1]&&(i=p.index)}}}return i}},{"./lib/order-segments":298,"binary-search-bounds":31,"functional-red-black-tree":69,"robust-orientation":284}],300:[function(t,e,r){"use strict";var n=t("robust-dot-product"),i=t("robust-sum");function a(t,e){var r=i(n(t,e),[e[e.length-1]]);return r[r.length-1]}function o(t,e,r,n){var i=-e/(n-e);i<0?i=0:i>1&&(i=1);for(var a=1-i,o=t.length,s=new Array(o),l=0;l0||i>0&&u<0){var f=o(s,u,l,i);r.push(f),n.push(f.slice())}u<0?n.push(l.slice()):u>0?r.push(l.slice()):(r.push(l.slice()),n.push(l.slice())),i=u}return{positive:r,negative:n}},e.exports.positive=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c>=0&&r.push(s.slice()),n=c}return r},e.exports.negative=function(t,e){for(var r=[],n=a(t[t.length-1],e),i=t[t.length-1],s=t[0],l=0;l0||n>0&&c<0)&&r.push(o(i,c,s,n)),c<=0&&r.push(s.slice()),n=c}return r}},{"robust-dot-product":281,"robust-sum":289}],301:[function(t,e,r){!function(){"use strict";var t={not_string:/[^s]/,not_bool:/[^t]/,not_type:/[^T]/,not_primitive:/[^v]/,number:/[diefg]/,numeric_arg:/[bcdiefguxX]/,json:/[j]/,not_json:/[^j]/,text:/^[^\x25]+/,modulo:/^\x25{2}/,placeholder:/^\x25(?:([1-9]\d*)\$|\(([^)]+)\))?(\+)?(0|'[^$])?(-)?(\d+)?(?:\.(\d+))?([b-gijostTuvxX])/,key:/^([a-z_][a-z_\d]*)/i,key_access:/^\.([a-z_][a-z_\d]*)/i,index_access:/^\[(\d+)\]/,sign:/^[+-]/};function e(t){return i(o(t),arguments)}function n(t,r){return e.apply(null,[t].concat(r||[]))}function i(r,n){var i,a,o,s,l,c,u,f,h,p=1,d=r.length,m="";for(a=0;a=0),s.type){case"b":i=parseInt(i,10).toString(2);break;case"c":i=String.fromCharCode(parseInt(i,10));break;case"d":case"i":i=parseInt(i,10);break;case"j":i=JSON.stringify(i,null,s.width?parseInt(s.width):0);break;case"e":i=s.precision?parseFloat(i).toExponential(s.precision):parseFloat(i).toExponential();break;case"f":i=s.precision?parseFloat(i).toFixed(s.precision):parseFloat(i);break;case"g":i=s.precision?String(Number(i.toPrecision(s.precision))):parseFloat(i);break;case"o":i=(parseInt(i,10)>>>0).toString(8);break;case"s":i=String(i),i=s.precision?i.substring(0,s.precision):i;break;case"t":i=String(!!i),i=s.precision?i.substring(0,s.precision):i;break;case"T":i=Object.prototype.toString.call(i).slice(8,-1).toLowerCase(),i=s.precision?i.substring(0,s.precision):i;break;case"u":i=parseInt(i,10)>>>0;break;case"v":i=i.valueOf(),i=s.precision?i.substring(0,s.precision):i;break;case"x":i=(parseInt(i,10)>>>0).toString(16);break;case"X":i=(parseInt(i,10)>>>0).toString(16).toUpperCase()}t.json.test(s.type)?m+=i:(!t.number.test(s.type)||f&&!s.sign?h="":(h=f?"+":"-",i=i.toString().replace(t.sign,"")),c=s.pad_char?"0"===s.pad_char?"0":s.pad_char.charAt(1):" ",u=s.width-(h+i).length,l=s.width&&u>0?c.repeat(u):"",m+=s.align?h+i+l:"0"===c?h+l+i:l+h+i)}return m}var a=Object.create(null);function o(e){if(a[e])return a[e];for(var r,n=e,i=[],o=0;n;){if(null!==(r=t.text.exec(n)))i.push(r[0]);else if(null!==(r=t.modulo.exec(n)))i.push("%");else{if(null===(r=t.placeholder.exec(n)))throw new SyntaxError("[sprintf] unexpected placeholder");if(r[2]){o|=1;var s=[],l=r[2],c=[];if(null===(c=t.key.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");for(s.push(c[1]);""!==(l=l.substring(c[0].length));)if(null!==(c=t.key_access.exec(l)))s.push(c[1]);else{if(null===(c=t.index_access.exec(l)))throw new SyntaxError("[sprintf] failed to parse named argument key");s.push(c[1])}r[2]=s}else o|=2;if(3===o)throw new Error("[sprintf] mixing positional and named placeholders is not (yet) supported");i.push({placeholder:r[0],param_no:r[1],keys:r[2],sign:r[3],pad_char:r[4],align:r[5],width:r[6],precision:r[7],type:r[8]})}n=n.substring(r[0].length)}return a[e]=i}void 0!==r&&(r.sprintf=e,r.vsprintf=n),"undefined"!=typeof window&&(window.sprintf=e,window.vsprintf=n)}()},{}],302:[function(t,e,r){"use strict";e.exports=function(t,e){if(t.dimension<=0)return{positions:[],cells:[]};if(1===t.dimension)return function(t,e){for(var r=i(t,e),n=r.length,a=new Array(n),o=new Array(n),s=0;sn|0},vertex:function(t,e,r,n,i,a,o,s,l,c,u,f,h){var p=(o<<0)+(s<<1)+(l<<2)+(c<<3)|0;if(0!==p&&15!==p)switch(p){case 0:u.push([t-.5,e-.5]);break;case 1:u.push([t-.25-.25*(n+r-2*h)/(r-n),e-.25-.25*(i+r-2*h)/(r-i)]);break;case 2:u.push([t-.75-.25*(-n-r+2*h)/(n-r),e-.25-.25*(a+n-2*h)/(n-a)]);break;case 3:u.push([t-.5,e-.5-.5*(i+r+a+n-4*h)/(r-i+n-a)]);break;case 4:u.push([t-.25-.25*(a+i-2*h)/(i-a),e-.75-.25*(-i-r+2*h)/(i-r)]);break;case 5:u.push([t-.5-.5*(n+r+a+i-4*h)/(r-n+i-a),e-.5]);break;case 6:u.push([t-.5-.25*(-n-r+a+i)/(n-r+i-a),e-.5-.25*(-i-r+a+n)/(i-r+n-a)]);break;case 7:u.push([t-.75-.25*(a+i-2*h)/(i-a),e-.75-.25*(a+n-2*h)/(n-a)]);break;case 8:u.push([t-.75-.25*(-a-i+2*h)/(a-i),e-.75-.25*(-a-n+2*h)/(a-n)]);break;case 9:u.push([t-.5-.25*(n+r+-a-i)/(r-n+a-i),e-.5-.25*(i+r+-a-n)/(r-i+a-n)]);break;case 10:u.push([t-.5-.5*(-n-r-a-i+4*h)/(n-r+a-i),e-.5]);break;case 11:u.push([t-.25-.25*(-a-i+2*h)/(a-i),e-.75-.25*(i+r-2*h)/(r-i)]);break;case 12:u.push([t-.5,e-.5-.5*(-i-r-a-n+4*h)/(i-r+a-n)]);break;case 13:u.push([t-.75-.25*(n+r-2*h)/(r-n),e-.25-.25*(-a-n+2*h)/(a-n)]);break;case 14:u.push([t-.25-.25*(-n-r+2*h)/(n-r),e-.25-.25*(-i-r+2*h)/(i-r)]);break;case 15:u.push([t-.5,e-.5])}},cell:function(t,e,r,n,i,a,o,s,l){i?s.push([t,e]):s.push([e,t])}});return function(t,e){var r=[],i=[];return n(t,r,i,e),{positions:r,cells:i}}}};var o={}},{"ndarray-extract-contour":251,"zero-crossings":318}],303:[function(t,e,r){(function(r){(function(){"use strict";e.exports=function t(e,r,i){i=i||{};var o=a[e];o||(o=a[e]={" ":{data:new Float32Array(0),shape:.2}});var s=o[r];if(!s)if(r.length<=1||!/\d/.test(r))s=o[r]=function(t){for(var e=t.cells,r=t.positions,n=new Float32Array(6*e.length),i=0,a=0,o=0;o0&&(f+=.02);var p=new Float32Array(u),d=0,m=-.5*f;for(h=0;hMath.max(r,n)?i[2]=1:r>Math.max(e,n)?i[0]=1:i[1]=1;for(var a=0,o=0,l=0;l<3;++l)a+=t[l]*t[l],o+=i[l]*t[l];for(l=0;l<3;++l)i[l]-=o/a*t[l];return s(i,i),i}function h(t,e,r,i,a,o,s,l){this.center=n(r),this.up=n(i),this.right=n(a),this.radius=n([o]),this.angle=n([s,l]),this.angle.bounds=[[-1/0,-Math.PI/2],[1/0,Math.PI/2]],this.setDistanceLimits(t,e),this.computedCenter=this.center.curve(0),this.computedUp=this.up.curve(0),this.computedRight=this.right.curve(0),this.computedRadius=this.radius.curve(0),this.computedAngle=this.angle.curve(0),this.computedToward=[0,0,0],this.computedEye=[0,0,0],this.computedMatrix=new Array(16);for(var c=0;c<16;++c)this.computedMatrix[c]=.5;this.recalcMatrix(0)}var p=h.prototype;p.setDistanceLimits=function(t,e){t=t>0?Math.log(t):-1/0,e=e>0?Math.log(e):1/0,e=Math.max(e,t),this.radius.bounds[0][0]=t,this.radius.bounds[1][0]=e},p.getDistanceLimits=function(t){var e=this.radius.bounds[0];return t?(t[0]=Math.exp(e[0][0]),t[1]=Math.exp(e[1][0]),t):[Math.exp(e[0][0]),Math.exp(e[1][0])]},p.recalcMatrix=function(t){this.center.curve(t),this.up.curve(t),this.right.curve(t),this.radius.curve(t),this.angle.curve(t);for(var e=this.computedUp,r=this.computedRight,n=0,i=0,a=0;a<3;++a)i+=e[a]*r[a],n+=e[a]*e[a];var l=Math.sqrt(n),u=0;for(a=0;a<3;++a)r[a]-=e[a]*i/n,u+=r[a]*r[a],e[a]/=l;var f=Math.sqrt(u);for(a=0;a<3;++a)r[a]/=f;var h=this.computedToward;o(h,e,r),s(h,h);var p=Math.exp(this.computedRadius[0]),d=this.computedAngle[0],m=this.computedAngle[1],g=Math.cos(d),v=Math.sin(d),y=Math.cos(m),x=Math.sin(m),b=this.computedCenter,_=g*y,w=v*y,T=x,k=-g*x,A=-v*x,M=y,S=this.computedEye,E=this.computedMatrix;for(a=0;a<3;++a){var L=_*r[a]+w*h[a]+T*e[a];E[4*a+1]=k*r[a]+A*h[a]+M*e[a],E[4*a+2]=L,E[4*a+3]=0}var C=E[1],P=E[5],I=E[9],O=E[2],z=E[6],D=E[10],R=P*D-I*z,F=I*O-C*D,B=C*z-P*O,N=c(R,F,B);R/=N,F/=N,B/=N,E[0]=R,E[4]=F,E[8]=B;for(a=0;a<3;++a)S[a]=b[a]+E[2+4*a]*p;for(a=0;a<3;++a){u=0;for(var j=0;j<3;++j)u+=E[a+4*j]*S[j];E[12+a]=-u}E[15]=1},p.getMatrix=function(t,e){this.recalcMatrix(t);var r=this.computedMatrix;if(e){for(var n=0;n<16;++n)e[n]=r[n];return e}return r};var d=[0,0,0];p.rotate=function(t,e,r,n){if(this.angle.move(t,e,r),n){this.recalcMatrix(t);var i=this.computedMatrix;d[0]=i[2],d[1]=i[6],d[2]=i[10];for(var o=this.computedUp,s=this.computedRight,l=this.computedToward,c=0;c<3;++c)i[4*c]=o[c],i[4*c+1]=s[c],i[4*c+2]=l[c];a(i,i,n,d);for(c=0;c<3;++c)o[c]=i[4*c],s[c]=i[4*c+1];this.up.set(t,o[0],o[1],o[2]),this.right.set(t,s[0],s[1],s[2])}},p.pan=function(t,e,r,n){e=e||0,r=r||0,n=n||0,this.recalcMatrix(t);var i=this.computedMatrix,a=(Math.exp(this.computedRadius[0]),i[1]),o=i[5],s=i[9],l=c(a,o,s);a/=l,o/=l,s/=l;var u=i[0],f=i[4],h=i[8],p=u*a+f*o+h*s,d=c(u-=a*p,f-=o*p,h-=s*p),m=(u/=d)*e+a*r,g=(f/=d)*e+o*r,v=(h/=d)*e+s*r;this.center.move(t,m,g,v);var y=Math.exp(this.computedRadius[0]);y=Math.max(1e-4,y+n),this.radius.set(t,Math.log(y))},p.translate=function(t,e,r,n){this.center.move(t,e||0,r||0,n||0)},p.setMatrix=function(t,e,r,n){var a=1;"number"==typeof r&&(a=0|r),(a<0||a>3)&&(a=1);var o=(a+2)%3;e||(this.recalcMatrix(t),e=this.computedMatrix);var s=e[a],l=e[a+4],f=e[a+8];if(n){var h=Math.abs(s),p=Math.abs(l),d=Math.abs(f),m=Math.max(h,p,d);h===m?(s=s<0?-1:1,l=f=0):d===m?(f=f<0?-1:1,s=l=0):(l=l<0?-1:1,s=f=0)}else{var g=c(s,l,f);s/=g,l/=g,f/=g}var v,y,x=e[o],b=e[o+4],_=e[o+8],w=x*s+b*l+_*f,T=c(x-=s*w,b-=l*w,_-=f*w),k=l*(_/=T)-f*(b/=T),A=f*(x/=T)-s*_,M=s*b-l*x,S=c(k,A,M);if(k/=S,A/=S,M/=S,this.center.jump(t,q,G,Y),this.radius.idle(t),this.up.jump(t,s,l,f),this.right.jump(t,x,b,_),2===a){var E=e[1],L=e[5],C=e[9],P=E*x+L*b+C*_,I=E*k+L*A+C*M;v=R<0?-Math.PI/2:Math.PI/2,y=Math.atan2(I,P)}else{var O=e[2],z=e[6],D=e[10],R=O*s+z*l+D*f,F=O*x+z*b+D*_,B=O*k+z*A+D*M;v=Math.asin(u(R)),y=Math.atan2(B,F)}this.angle.jump(t,y,v),this.recalcMatrix(t);var N=e[2],j=e[6],U=e[10],V=this.computedMatrix;i(V,e);var H=V[15],q=V[12]/H,G=V[13]/H,Y=V[14]/H,W=Math.exp(this.computedRadius[0]);this.center.jump(t,q-N*W,G-j*W,Y-U*W)},p.lastT=function(){return Math.max(this.center.lastT(),this.up.lastT(),this.right.lastT(),this.radius.lastT(),this.angle.lastT())},p.idle=function(t){this.center.idle(t),this.up.idle(t),this.right.idle(t),this.radius.idle(t),this.angle.idle(t)},p.flush=function(t){this.center.flush(t),this.up.flush(t),this.right.flush(t),this.radius.flush(t),this.angle.flush(t)},p.setDistance=function(t,e){e>0&&this.radius.set(t,Math.log(e))},p.lookAt=function(t,e,r,n){this.recalcMatrix(t),e=e||this.computedEye,r=r||this.computedCenter;var i=(n=n||this.computedUp)[0],a=n[1],o=n[2],s=c(i,a,o);if(!(s<1e-6)){i/=s,a/=s,o/=s;var l=e[0]-r[0],f=e[1]-r[1],h=e[2]-r[2],p=c(l,f,h);if(!(p<1e-6)){l/=p,f/=p,h/=p;var d=this.computedRight,m=d[0],g=d[1],v=d[2],y=i*m+a*g+o*v,x=c(m-=y*i,g-=y*a,v-=y*o);if(!(x<.01&&(x=c(m=a*h-o*f,g=o*l-i*h,v=i*f-a*l))<1e-6)){m/=x,g/=x,v/=x,this.up.set(t,i,a,o),this.right.set(t,m,g,v),this.center.set(t,r[0],r[1],r[2]),this.radius.set(t,Math.log(p));var b=a*v-o*g,_=o*m-i*v,w=i*g-a*m,T=c(b,_,w),k=i*l+a*f+o*h,A=m*l+g*f+v*h,M=(b/=T)*l+(_/=T)*f+(w/=T)*h,S=Math.asin(u(k)),E=Math.atan2(M,A),L=this.angle._state,C=L[L.length-1],P=L[L.length-2];C%=2*Math.PI;var I=Math.abs(C+2*Math.PI-E),O=Math.abs(C-E),z=Math.abs(C-2*Math.PI-E);I0?r.pop():new ArrayBuffer(t)}function d(t){return new Uint8Array(p(t),0,t)}function m(t){return new Uint16Array(p(2*t),0,t)}function g(t){return new Uint32Array(p(4*t),0,t)}function v(t){return new Int8Array(p(t),0,t)}function y(t){return new Int16Array(p(2*t),0,t)}function x(t){return new Int32Array(p(4*t),0,t)}function b(t){return new Float32Array(p(4*t),0,t)}function _(t){return new Float64Array(p(8*t),0,t)}function w(t){return o?new Uint8ClampedArray(p(t),0,t):d(t)}function T(t){return s?new BigUint64Array(p(8*t),0,t):null}function k(t){return l?new BigInt64Array(p(8*t),0,t):null}function A(t){return new DataView(p(t),0,t)}function M(t){t=n.nextPow2(t);var e=n.log2(t),r=f[e];return r.length>0?r.pop():new a(t)}r.free=function(t){if(a.isBuffer(t))f[n.log2(t.length)].push(t);else{if("[object ArrayBuffer]"!==Object.prototype.toString.call(t)&&(t=t.buffer),!t)return;var e=t.length||t.byteLength,r=0|n.log2(e);u[r].push(t)}},r.freeUint8=r.freeUint16=r.freeUint32=r.freeBigUint64=r.freeInt8=r.freeInt16=r.freeInt32=r.freeBigInt64=r.freeFloat32=r.freeFloat=r.freeFloat64=r.freeDouble=r.freeUint8Clamped=r.freeDataView=function(t){h(t.buffer)},r.freeArrayBuffer=h,r.freeBuffer=function(t){f[n.log2(t.length)].push(t)},r.malloc=function(t,e){if(void 0===e||"arraybuffer"===e)return p(t);switch(e){case"uint8":return d(t);case"uint16":return m(t);case"uint32":return g(t);case"int8":return v(t);case"int16":return y(t);case"int32":return x(t);case"float":case"float32":return b(t);case"double":case"float64":return _(t);case"uint8_clamped":return w(t);case"bigint64":return k(t);case"biguint64":return T(t);case"buffer":return M(t);case"data":case"dataview":return A(t);default:return null}return null},r.mallocArrayBuffer=p,r.mallocUint8=d,r.mallocUint16=m,r.mallocUint32=g,r.mallocInt8=v,r.mallocInt16=y,r.mallocInt32=x,r.mallocFloat32=r.mallocFloat=b,r.mallocFloat64=r.mallocDouble=_,r.mallocUint8Clamped=w,r.mallocBigUint64=T,r.mallocBigInt64=k,r.mallocDataView=A,r.mallocBuffer=M,r.clearCache=function(){for(var t=0;t<32;++t)c.UINT8[t].length=0,c.UINT16[t].length=0,c.UINT32[t].length=0,c.INT8[t].length=0,c.INT16[t].length=0,c.INT32[t].length=0,c.FLOAT[t].length=0,c.DOUBLE[t].length=0,c.BIGUINT64[t].length=0,c.BIGINT64[t].length=0,c.UINT8C[t].length=0,u[t].length=0,f[t].length=0}}).call(this)}).call(this,void 0!==n?n:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{"bit-twiddle":32,buffer:3,dup:65}],309:[function(t,e,r){"use strict";function n(t){this.roots=new Array(t),this.ranks=new Array(t);for(var e=0;e0&&(a=n.size),n.lineSpacing&&n.lineSpacing>0&&(o=n.lineSpacing),n.styletags&&n.styletags.breaklines&&(s.breaklines=!!n.styletags.breaklines),n.styletags&&n.styletags.bolds&&(s.bolds=!!n.styletags.bolds),n.styletags&&n.styletags.italics&&(s.italics=!!n.styletags.italics),n.styletags&&n.styletags.subscripts&&(s.subscripts=!!n.styletags.subscripts),n.styletags&&n.styletags.superscripts&&(s.superscripts=!!n.styletags.superscripts));return r.font=[n.fontStyle,n.fontVariant,n.fontWeight,a+"px",n.font].filter((function(t){return t})).join(" "),r.textAlign="start",r.textBaseline="alphabetic",r.direction="ltr",h(function(t,e,r,n,a,o){r=r.replace(/\n/g,""),r=!0===o.breaklines?r.replace(/\/g,"\n"):r.replace(/\/g," ");var s="",l=[];for(p=0;p-1?parseInt(t[1+i]):0,l=a>-1?parseInt(r[1+a]):0;s!==l&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,l-s),n=n.replace("?px ",S())),m+=.25*x*(l-s)}if(!0===o.superscripts){var c=t.indexOf("+"),u=r.indexOf("+"),f=c>-1?parseInt(t[1+c]):0,h=u>-1?parseInt(r[1+u]):0;f!==h&&(n=n.replace(S(),"?px "),g*=Math.pow(.75,h-f),n=n.replace("?px ",S())),m-=.25*x*(h-f)}if(!0===o.bolds){var p=t.indexOf("b|")>-1,d=r.indexOf("b|")>-1;!p&&d&&(n=v?n.replace("italic ","italic bold "):"bold "+n),p&&!d&&(n=n.replace("bold ",""))}if(!0===o.italics){var v=t.indexOf("i|")>-1,y=r.indexOf("i|")>-1;!v&&y&&(n="italic "+n),v&&!y&&(n=n.replace("italic ",""))}e.font=n}for(h=0;h",a="",o=i.length,s=a.length,l="+"===e[0]||"-"===e[0],c=0,u=-s;c>-1&&-1!==(c=r.indexOf(i,c))&&-1!==(u=r.indexOf(a,c+o))&&!(u<=c);){for(var f=c;f=u)n[f]=null,r=r.substr(0,f)+" "+r.substr(f+1);else if(null!==n[f]){var h=n[f].indexOf(e[0]);-1===h?n[f]+=e:l&&(n[f]=n[f].substr(0,h+1)+(1+parseInt(n[f][h+1]))+n[f].substr(h+2))}var p=c+o,d=r.substr(p,u-p).indexOf(i);c=-1!==d?d:u+s}return n}function u(t,e){var r=n(t,128);return e?a(r.cells,r.positions,.25):{edges:r.cells,positions:r.positions}}function f(t,e,r,n){var i=u(t,n),a=function(t,e,r){for(var n=e.textAlign||"start",i=e.textBaseline||"alphabetic",a=[1<<30,1<<30],o=[0,0],s=t.length,l=0;l=0?e[a]:i}))},has___:{value:y((function(e){var n=v(e);return n?r in n:t.indexOf(e)>=0}))},set___:{value:y((function(n,i){var a,o=v(n);return o?o[r]=i:(a=t.indexOf(n))>=0?e[a]=i:(a=t.length,e[a]=i,t[a]=n),this}))},delete___:{value:y((function(n){var i,a,o=v(n);return o?r in o&&delete o[r]:!((i=t.indexOf(n))<0)&&(a=t.length-1,t[i]=void 0,e[i]=e[a],t[i]=t[a],t.length=a,e.length=a,!0)}))}})};d.prototype=Object.create(Object.prototype,{get:{value:function(t,e){return this.get___(t,e)},writable:!0,configurable:!0},has:{value:function(t){return this.has___(t)},writable:!0,configurable:!0},set:{value:function(t,e){return this.set___(t,e)},writable:!0,configurable:!0},delete:{value:function(t){return this.delete___(t)},writable:!0,configurable:!0}}),"function"==typeof r?function(){function n(){this instanceof d||x();var e,n=new r,i=void 0,a=!1;return e=t?function(t,e){return n.set(t,e),n.has(t)||(i||(i=new d),i.set(t,e)),this}:function(t,e){if(a)try{n.set(t,e)}catch(r){i||(i=new d),i.set___(t,e)}else n.set(t,e);return this},Object.create(d.prototype,{get___:{value:y((function(t,e){return i?n.has(t)?n.get(t):i.get___(t,e):n.get(t,e)}))},has___:{value:y((function(t){return n.has(t)||!!i&&i.has___(t)}))},set___:{value:y(e)},delete___:{value:y((function(t){var e=!!n.delete(t);return i&&i.delete___(t)||e}))},permitHostObjects___:{value:y((function(t){if(t!==m)throw new Error("bogus call to permitHostObjects___");a=!0}))}})}t&&"undefined"!=typeof Proxy&&(Proxy=void 0),n.prototype=d.prototype,e.exports=n,Object.defineProperty(WeakMap.prototype,"constructor",{value:WeakMap,enumerable:!1,configurable:!0,writable:!0})}():("undefined"!=typeof Proxy&&(Proxy=void 0),e.exports=d)}function m(t){t.permitHostObjects___&&t.permitHostObjects___(m)}function g(t){return!("weakmap:"==t.substr(0,"weakmap:".length)&&"___"===t.substr(t.length-3))}function v(t){if(t!==Object(t))throw new TypeError("Not an object: "+t);var e=t[l];if(e&&e.key===t)return e;if(s(t)){e={key:t};try{return o(t,l,{value:e,writable:!1,enumerable:!1,configurable:!1}),e}catch(t){return}}}function y(t){return t.prototype=null,Object.freeze(t)}function x(){h||"undefined"==typeof console||(h=!0,console.warn("WeakMap should be invoked as new WeakMap(), not WeakMap(). This will be an error in the future."))}}()},{}],314:[function(t,e,r){var n=t("./hidden-store.js");e.exports=function(){var t={};return function(e){if(("object"!=typeof e||null===e)&&"function"!=typeof e)throw new Error("Weakmap-shim: Key must be object");var r=e.valueOf(t);return r&&r.identity===t?r:n(e,t)}}},{"./hidden-store.js":315}],315:[function(t,e,r){e.exports=function(t,e){var r={identity:e},n=t.valueOf;return Object.defineProperty(t,"valueOf",{value:function(t){return t!==e?n.apply(this,arguments):r},writable:!0}),r}},{}],316:[function(t,e,r){var n=t("./create-store.js");e.exports=function(){var t=n();return{get:function(e,r){var n=t(e);return n.hasOwnProperty("value")?n.value:r},set:function(e,r){return t(e).value=r,this},has:function(e){return"value"in t(e)},delete:function(e){return delete t(e).value}}}},{"./create-store.js":314}],317:[function(t,e,r){"use strict";var n,i=function(){return function(t,e,r,n,i,a){var o=t[0],s=r[0],l=[0],c=s;n|=0;var u=0,f=s;for(u=0;u=0!=p>=0&&i.push(l[0]+.5+.5*(h+p)/(h-p)),n+=f,++l[0]}}};e.exports=(n={funcName:{funcName:"zeroCrossings"}.funcName},function(t){var e={};return function(r,n,i){var a=r.dtype,o=r.order,s=[a,o.join()].join(),l=e[s];return l||(e[s]=l=t([a,o])),l(r.shape.slice(0),r.data,r.stride,0|r.offset,n,i)}}(i.bind(void 0,n)))},{}],318:[function(t,e,r){"use strict";e.exports=function(t,e){var r=[];return e=+e||0,n(t.hi(t.shape[0]-1),r,e),r};var n=t("./lib/zc-core")},{"./lib/zc-core":317}]},{},[6])(6)}))}).call(this)}).call(this,"undefined"!=typeof global?global:"undefined"!=typeof self?self:"undefined"!=typeof window?window:{})},{}]},{},[27])(27)}));
\ No newline at end of file diff --git a/docs/source/models/checkpoint_list.csv b/docs/source/models/checkpoint_list.csv index df934fa..77cf863 100644 --- a/docs/source/models/checkpoint_list.csv +++ b/docs/source/models/checkpoint_list.csv @@ -1,4 +1,10 @@ PTLFlow name,Original name,Source +craft-kitti-4d99b0c1.ckpt,craft-kitti.pth,https://github.com/askerlee/craft +craft-sintel-ff8e6563.ckpt,craft-sintel.pth, +craft-things-5a41930c.ckpt,craft-things.pth, +csflow-chairs-458a9436.ckpt,CSFlow-chairs.pth,https://github.com/MasterHow/CSFlow +csflow-kitti-dc66357a.ckpt,CSFlow-kitti.pth, +csflow-things-ebdd403b.ckpt,CSFlow-things.pth, dicl-chairs-fdc24e2f.ckpt,ckpt_chairs.pth.tar,https://github.com/jytime/DICL-Flow dicl-kitti-4813ccab.ckpt,ckpt_kitti.pth.tar, dicl-sintel-fa9fc259.ckpt,ckpt_sintel.pth.tar, @@ -8,6 +14,10 @@ fastflownet-kitti-6d3526a8.ckpt,fastflownet_ft_kitti.pth, fastflownet-mix-fd9b8c0d.ckpt,fastflownet_ft_mix.pth, fastflownet-sintel-6475ea96.ckpt,fastflownet_ft_sintel.pth, fastflownet-things3d-fc093d29.ckpt,fastflownet_things.pth, +flowformer-chairs-2b34ea4b.ckpt,chairs.pth,https://github.com/drinkingcoder/FlowFormer-Official +flowformer-kitti-1e45a6c8.ckpt,kitti.pth, +flowformer-sintel-27cc959a.ckpt,sintel.pth, +flowformer-things-ab5f3255.ckpt,things.pth, flownet2-things-d63b53a7.ckpt,FlowNet2_checkpoint.pth.tar,https://github.com/NVIDIA/flownet2-pytorch flownetc-things-cc8ac7fd.ckpt,FlowNet2-C_checkpoint, flownetcs-things-4bdecffa.ckpt,FlowNet2-CS_checkpoint, @@ -18,6 +28,18 @@ gma-chairs-d4ec321d.ckpt,gma-chairs.pth,https://github.com/zacjiang/GMA gma-things-90aafb63.ckpt,gma-things.pth, gma-sintel-98d6f3d0.ckpt,gma-sintel.pth, gma-kitti-8ca3ec80.ckpt,gma-kitti.pth, +gmflow-chairs-4922131e.ckpt,gmflow_chairs-1d776046.pth,https://github.com/haofeixu/gmflow +gmflow-kitti-af50eb2e.ckpt,gmflow_kitti-285701a8.pth, +gmflow-sintel-d6f83ccd.ckpt,gmflow_sintel-0c07dcb3.pth, +gmflow-things-5a18a9e8.ckpt,gmflow_things-e9887eda.pth, +gmflow_refine-chairs-88cdc009.ckpt,gmflow_with_refine_chairs-020cc9be.pth, +gmflow_refine-kitti-b7bf2fda.ckpt,gmflow_with_refine_kitti-8d3b9786.pth, +gmflow_refine-sintel-ee46a2c4.ckpt,gmflow_with_refine_sintel-3ed1cf48.pth, +gmflow_refine-things-e40899f5.ckpt,gmflow_with_refine_things-36579974.pth, +gmflownet-kitti-712b4660.ckpt,gmflownet-kitti.pth,https://github.com/xiaofeng94/GMFlowNet +gmflownet-things-9f061ac7.ckpt,gmflownet-things.pth, +gmflownet_mix-sintel-33492618.ckpt,gmflownet_mix-sintel.pth, +gmflownet_mix-things-8396f0a1.ckpt,gmflownet_mix-things.pth, hd3-chairs-0d46c9fd.ckpt,hd3f_chairs-04bf114d.pth,https://github.com/ucbdrive/hd3 hd3-kitti-6eb77dd3.ckpt,hd3f_chairs_things-462a3896.pth, hd3-sintel-10689995.ckpt,hd3f_chairs_things_kitti-41b15827.pth, diff --git a/docs/source/models/models_list.rst b/docs/source/models/models_list.rst index 2b08556..93cb189 100644 --- a/docs/source/models/models_list.rst +++ b/docs/source/models/models_list.rst @@ -7,6 +7,28 @@ Below is a list and a brief explanation about the models currently available on List of models ============== +CRAFT +----- + +`[source code] `__ + +- Paper: **CRAFT: Cross-Attentional Flow Transformers for Robust Optical Flow** - `https://arxiv.org/abs/2203.16896 `_ + +- Reference code: `https://github.com/askerlee/craft `_ + +- Model names: ``craft`` + +CSFlow +------ + +`[source code] `__ + +- Paper: **CSFlow: Learning optical flow via cross strip correlation for autonomous driving* - `https://arxiv.org/abs/2202.00909 `_ + +- Reference code: `https://github.com/MasterHow/CSFlow `_ + +- Model names: ``csflow`` + DICL-Flow --------- @@ -29,6 +51,17 @@ FastFlownet - Model names: ``fastflownet`` +FlowFormer +---------- + +`[source code] `__ + +- Paper: **FlowFormer: A Transformer Architecture for Optical Flow* - `https://arxiv.org/abs/2203.16194 `_ + +- Reference code: `https://github.com/drinkingcoder/FlowFormer-Official `_ + +- Model names: ``flowformer`` + Flownet ------- @@ -55,6 +88,28 @@ GMA - Model names: ``gma`` +GMFlow +------ + +`[source code] `__ + +- Paper: **GMFlow: Learning Optical Flow via Global Matching** - `https://arxiv.org/abs/2111.13680 `_ + +- Reference code: `https://github.com/haofeixu/gmflow `_ + +- Model names: ``gmflow``, ``gmflow_refine`` + +GMFlowNet +--------- + +`[source code] `__ + +- Paper: **Global Matching with Overlapping Attention for Optical Flow Estimation** - `https://arxiv.org/abs/2203.11335 `_ + +- Reference code: `https://github.com/xiaofeng94/GMFlowNet `_ + +- Model names: ``gmflownet``, ``gmflownet_mix`` + HD3 --- diff --git a/docs/source/results/metrics_all.csv b/docs/source/results/metrics_all.csv index f96c891..a1c6c3f 100644 --- a/docs/source/results/metrics_all.csv +++ b/docs/source/results/metrics_all.csv @@ -1,75 +1,97 @@ -model,checkpoint,sintel-final-trainval-occ-val/epe,sintel-final-trainval-occ-val/px1,sintel-final-trainval-occ-val/px3,sintel-final-trainval-occ-val/px5,sintel-final-trainval-occ-val/outlier,sintel-final-trainval-occ-val/epe_occ,sintel-final-trainval-occ-val/epe_non_occ,sintel-final-trainval-occ-val/px1_occ,sintel-final-trainval-occ-val/px1_non_occ,sintel-final-trainval-occ-val/px3_occ,sintel-final-trainval-occ-val/px3_non_occ,sintel-final-trainval-occ-val/px5_occ,sintel-final-trainval-occ-val/px5_non_occ,sintel-final-trainval-occ-val/outlier_occ,sintel-final-trainval-occ-val/outlier_non_occ,sintel-final-trainval-occ-val/occ_f1,sintel-final-trainval-occ-val/conf_f1,sintel-clean-trainval-occ-val/epe,sintel-clean-trainval-occ-val/px1,sintel-clean-trainval-occ-val/px3,sintel-clean-trainval-occ-val/px5,sintel-clean-trainval-occ-val/outlier,sintel-clean-trainval-occ-val/epe_occ,sintel-clean-trainval-occ-val/epe_non_occ,sintel-clean-trainval-occ-val/px1_occ,sintel-clean-trainval-occ-val/px1_non_occ,sintel-clean-trainval-occ-val/px3_occ,sintel-clean-trainval-occ-val/px3_non_occ,sintel-clean-trainval-occ-val/px5_occ,sintel-clean-trainval-occ-val/px5_non_occ,sintel-clean-trainval-occ-val/outlier_occ,sintel-clean-trainval-occ-val/outlier_non_occ,sintel-clean-trainval-occ-val/occ_f1,sintel-clean-trainval-occ-val/conf_f1,kitti-2012-trainval-val/epe,kitti-2012-trainval-val/px1,kitti-2012-trainval-val/px3,kitti-2012-trainval-val/px5,kitti-2012-trainval-val/outlier,kitti-2012-trainval-val/conf_f1,kitti-2015-trainval-val/epe,kitti-2015-trainval-val/px1,kitti-2015-trainval-val/px3,kitti-2015-trainval-val/px5,kitti-2015-trainval-val/outlier,kitti-2015-trainval-val/conf_f1 -dicl,chairs,4.9470000000000001,0.751,0.86499999999999999,0.89100000000000001,0.13200000000000001,11.704000000000001,3.6989999999999998,0.372,0.78000000000000003,0.61599999999999999,0.89100000000000001,0.69699999999999995,0.91500000000000004,0.373,0.107,,,3.661,0.81999999999999995,0.90900000000000003,0.92700000000000005,0.088999999999999996,10.273999999999999,2.3889999999999998,0.42699999999999999,0.85299999999999998,0.67000000000000004,0.93600000000000005,0.74199999999999999,0.95099999999999996,0.31900000000000001,0.063,,,5.968,0.44,0.71199999999999997,0.79300000000000004,0.28100000000000003,,18.445,0.376,0.54100000000000004,0.60199999999999998,0.45600000000000002, -dicl,kitti,9.5980000000000008,0.63400000000000001,0.76500000000000001,0.80400000000000005,0.23400000000000001,16.164999999999999,8.4700000000000006,0.29899999999999999,0.65300000000000002,0.51000000000000001,0.78500000000000003,0.59999999999999998,0.82199999999999995,0.48199999999999998,0.214,,,7.798,0.68999999999999995,0.80400000000000005,0.83799999999999997,0.19500000000000001,14.635,6.5579999999999998,0.318,0.71399999999999997,0.54000000000000004,0.82699999999999996,0.63,0.85899999999999999,0.45100000000000001,0.17199999999999999,,,1.161,0.82899999999999996,0.93799999999999994,0.96099999999999997,0.045999999999999999,,1.393,0.77900000000000003,0.91300000000000003,0.94599999999999995,0.058000000000000003, -dicl,sintel,2.0499999999999998,0.85599999999999998,0.92200000000000004,0.94099999999999995,0.074999999999999997,6.0309999999999997,1.397,0.496,0.88600000000000001,0.71899999999999997,0.94499999999999995,0.79300000000000004,0.95999999999999996,0.26800000000000002,0.052999999999999999,,,1.2969999999999999,0.90200000000000002,0.95099999999999996,0.96399999999999997,0.045999999999999999,4.8940000000000001,0.67800000000000005,0.54600000000000004,0.93600000000000005,0.76000000000000001,0.97499999999999998,0.82799999999999996,0.98199999999999998,0.22600000000000001,0.024,,,2.161,0.75900000000000001,0.90400000000000003,0.93799999999999994,0.081000000000000003,,5.6559999999999997,0.66300000000000003,0.81599999999999995,0.85899999999999999,0.16400000000000001, -dicl,things,3.831,0.81100000000000005,0.88900000000000001,0.91300000000000003,0.107,9.9030000000000005,2.7090000000000001,0.41799999999999998,0.84299999999999997,0.64800000000000002,0.91500000000000004,0.72699999999999998,0.93500000000000005,0.34100000000000003,0.082000000000000003,,,2.0059999999999998,0.878,0.93899999999999995,0.95399999999999996,0.058999999999999997,6.6909999999999998,1.121,0.48799999999999999,0.91400000000000003,0.72999999999999998,0.96399999999999997,0.80100000000000005,0.97399999999999998,0.25800000000000001,0.034000000000000002,,,3.7450000000000001,0.69499999999999995,0.84899999999999998,0.88900000000000001,0.14000000000000001,,9.9060000000000006,0.59899999999999998,0.75,0.79100000000000004,0.23999999999999999, -fastflownet,chairs,4.3799999999999999,0.73399999999999999,0.85699999999999998,0.89100000000000001,0.14000000000000001,10.960000000000001,3.1619999999999999,0.32300000000000001,0.76400000000000001,0.57099999999999995,0.88600000000000001,0.66800000000000004,0.91700000000000004,0.41899999999999998,0.111,,,3.2519999999999998,0.79000000000000004,0.89600000000000002,0.92200000000000004,0.10100000000000001,9.7189999999999994,2.0110000000000001,0.35399999999999998,0.82599999999999996,0.62,0.92800000000000005,0.71299999999999997,0.94899999999999995,0.36899999999999999,0.070000000000000007,,,5.7599999999999998,0.51100000000000001,0.75700000000000001,0.82399999999999995,0.23599999999999999,,14.374000000000001,0.42799999999999999,0.628,0.69499999999999995,0.36799999999999999, -fastflownet,kitti,5.1470000000000002,0.73899999999999999,0.84499999999999997,0.876,0.152,12.287000000000001,3.903,0.32900000000000001,0.76800000000000002,0.54900000000000004,0.874,0.63700000000000001,0.90200000000000002,0.44,0.124,,,3.952,0.78800000000000003,0.88200000000000001,0.90600000000000003,0.11600000000000001,10.859,2.7080000000000002,0.35599999999999998,0.82299999999999995,0.58599999999999997,0.91300000000000003,0.67400000000000004,0.93400000000000005,0.40300000000000002,0.085000000000000006,,,1.534,0.77500000000000002,0.91800000000000004,0.94799999999999995,0.065000000000000002,,2.198,0.72299999999999998,0.88500000000000001,0.92400000000000004,0.086999999999999994, -fastflownet,mix,2.754,0.79700000000000004,0.89100000000000001,0.91800000000000004,0.105,7.8140000000000001,1.9219999999999999,0.38100000000000001,0.83099999999999996,0.63100000000000001,0.91900000000000004,0.72199999999999998,0.94199999999999995,0.35599999999999998,0.076999999999999999,,,2.641,0.83699999999999997,0.91800000000000004,0.93799999999999994,0.080000000000000002,8.2889999999999997,1.617,0.41299999999999998,0.875,0.66500000000000004,0.94799999999999995,0.748,0.96299999999999997,0.32300000000000001,0.050999999999999997,,,2.6019999999999999,0.69699999999999995,0.88,0.92000000000000004,0.105,,5.5099999999999998,0.60599999999999998,0.79300000000000004,0.84499999999999997,0.187, -fastflownet,sintel,2.79,0.79600000000000004,0.89100000000000001,0.91800000000000004,0.105,7.8559999999999999,1.952,0.38100000000000001,0.82999999999999996,0.63200000000000001,0.91900000000000004,0.72199999999999998,0.94199999999999995,0.35599999999999998,0.078,,,2.7290000000000001,0.83599999999999997,0.91700000000000004,0.93700000000000006,0.081000000000000003,8.5570000000000004,1.6599999999999999,0.41299999999999998,0.874,0.66400000000000003,0.94699999999999995,0.746,0.96199999999999997,0.32500000000000001,0.051999999999999998,,,5.2270000000000003,0.61599999999999999,0.81499999999999995,0.86099999999999999,0.17899999999999999,,13.731,0.50700000000000001,0.68200000000000005,0.73099999999999998,0.314, -fastflownet,things,4.2699999999999996,0.73899999999999999,0.85999999999999999,0.89300000000000002,0.13600000000000001,10.553000000000001,3.1160000000000001,0.317,0.77000000000000002,0.57299999999999995,0.88900000000000001,0.67100000000000004,0.91900000000000004,0.41599999999999998,0.108,,,2.931,0.79600000000000004,0.90100000000000002,0.92700000000000005,0.096000000000000002,8.8599999999999994,1.8109999999999999,0.35099999999999998,0.83299999999999996,0.626,0.93200000000000005,0.72199999999999998,0.95299999999999996,0.36199999999999999,0.066000000000000003,,,5.5279999999999996,0.55700000000000005,0.77500000000000002,0.83299999999999996,0.218,,13.134,0.46500000000000002,0.65700000000000003,0.71799999999999997,0.33800000000000002, -flownet2,things,3.9700000000000002,0.77200000000000002,0.872,0.90100000000000002,0.124,9.3379999999999992,2.9780000000000002,0.41299999999999998,0.79900000000000004,0.63300000000000001,0.89600000000000002,0.71699999999999997,0.92200000000000004,0.35499999999999998,0.10100000000000001,,,3.0049999999999999,0.82099999999999995,0.90700000000000003,0.92900000000000005,0.089999999999999997,7.867,2.101,0.45500000000000002,0.85099999999999998,0.68500000000000005,0.93100000000000005,0.76400000000000001,0.94999999999999996,0.30399999999999999,0.066000000000000003,,,5.468,0.59899999999999998,0.77500000000000002,0.82299999999999995,0.219,,13.143000000000001,0.499,0.65000000000000002,0.69899999999999995,0.34599999999999997, -flownetc,things,5.6470000000000002,0.46600000000000003,0.76000000000000001,0.83199999999999996,0.23799999999999999,11.412000000000001,4.5789999999999997,0.189,0.48099999999999998,0.49099999999999999,0.78200000000000003,0.61399999999999999,0.85399999999999998,0.501,0.216,,,4.5449999999999999,0.47899999999999998,0.79700000000000004,0.86499999999999999,0.20000000000000001,9.9149999999999991,3.5569999999999999,0.19500000000000001,0.49399999999999999,0.52600000000000002,0.82199999999999995,0.65800000000000003,0.88900000000000001,0.46600000000000003,0.17499999999999999,,,8.1039999999999992,0.19800000000000001,0.56999999999999995,0.68899999999999995,0.42599999999999999,,15.984999999999999,0.17899999999999999,0.46200000000000002,0.56899999999999995,0.53600000000000003, -flownetcs,things,4.1639999999999997,0.73499999999999999,0.85699999999999998,0.89200000000000002,0.14000000000000001,9.4979999999999993,3.21,0.33700000000000002,0.76300000000000001,0.59999999999999998,0.88200000000000001,0.69899999999999995,0.91500000000000004,0.38900000000000001,0.115,,,3.1309999999999998,0.78700000000000003,0.89300000000000002,0.92100000000000004,0.10299999999999999,8.0890000000000004,2.226,0.375,0.81799999999999995,0.65200000000000002,0.91900000000000004,0.747,0.94299999999999995,0.33600000000000002,0.078,,,5.0279999999999996,0.55800000000000005,0.77400000000000002,0.83099999999999996,0.218,,12.516999999999999,0.45200000000000001,0.63900000000000001,0.69899999999999995,0.35599999999999998, -flownetcss,things,4.0209999999999999,0.754,0.86399999999999999,0.89800000000000002,0.13200000000000001,9.2989999999999995,3.0609999999999999,0.36599999999999999,0.78200000000000003,0.61699999999999999,0.88900000000000001,0.71099999999999997,0.91900000000000004,0.372,0.108,,,2.9889999999999999,0.80900000000000005,0.90100000000000002,0.92500000000000004,0.096000000000000002,7.8630000000000004,2.0960000000000001,0.41199999999999998,0.83999999999999997,0.67100000000000004,0.92600000000000005,0.75800000000000001,0.94599999999999995,0.318,0.071999999999999995,,,4.657,0.60099999999999998,0.79000000000000004,0.84099999999999997,0.20200000000000001,,11.901,0.49099999999999999,0.65700000000000003,0.71099999999999997,0.33800000000000002, -flownets,things,5.2439999999999998,0.56699999999999995,0.79300000000000004,0.85099999999999998,0.20499999999999999,11.214,4.1529999999999996,0.23000000000000001,0.58399999999999996,0.51700000000000002,0.81599999999999995,0.63300000000000001,0.874,0.47499999999999998,0.182,,,3.9580000000000002,0.58799999999999997,0.81999999999999995,0.876,0.17799999999999999,9.4640000000000004,2.9430000000000001,0.24199999999999999,0.60699999999999998,0.54600000000000004,0.84499999999999997,0.66500000000000004,0.90000000000000002,0.44600000000000001,0.152,,,7.7229999999999999,0.21199999999999999,0.56100000000000005,0.68300000000000005,0.435,,14.811999999999999,0.23599999999999999,0.47899999999999998,0.57999999999999996,0.51900000000000002, -flownetsd,things,7.8209999999999997,0.65600000000000003,0.78100000000000003,0.82099999999999995,0.219,14.407999999999999,6.6870000000000003,0.313,0.67400000000000004,0.52300000000000002,0.80200000000000005,0.61099999999999999,0.84099999999999997,0.46999999999999997,0.19800000000000001,,,7.6159999999999997,0.68700000000000006,0.80000000000000004,0.83599999999999997,0.20000000000000001,13.936999999999999,6.4829999999999997,0.33300000000000002,0.70699999999999996,0.55200000000000005,0.82099999999999995,0.64200000000000002,0.85599999999999998,0.441,0.17899999999999999,,,17.257999999999999,0.20699999999999999,0.41399999999999998,0.495,0.58599999999999997,,24.210000000000001,0.27000000000000002,0.42399999999999999,0.48899999999999999,0.57599999999999996, -gma,chairs,4.1349999999999998,0.72099999999999997,0.88600000000000001,0.91400000000000003,0.11,10.129,3.0019999999999998,0.38100000000000001,0.747,0.65700000000000003,0.91200000000000003,0.73399999999999999,0.93600000000000005,0.33100000000000002,0.085000000000000006,,,2.371,0.80000000000000004,0.92700000000000005,0.94899999999999995,0.069000000000000006,7.0220000000000002,1.5289999999999999,0.45600000000000002,0.82799999999999996,0.71199999999999997,0.95399999999999996,0.79100000000000004,0.96999999999999997,0.27300000000000002,0.043999999999999997,,,4.5999999999999996,0.36399999999999999,0.72199999999999998,0.82099999999999995,0.26800000000000002,,9.9830000000000005,0.38,0.63900000000000001,0.72499999999999998,0.35199999999999998, -gma,things,2.867,0.84599999999999997,0.91200000000000003,0.93200000000000005,0.084000000000000005,7.3970000000000002,2.0310000000000001,0.52900000000000003,0.874,0.72799999999999998,0.93400000000000005,0.79000000000000004,0.95099999999999996,0.26000000000000001,0.063,,,1.4099999999999999,0.90200000000000002,0.95599999999999996,0.96899999999999997,0.041000000000000002,4.6790000000000003,0.80800000000000005,0.59899999999999998,0.93300000000000005,0.79500000000000004,0.97699999999999998,0.85299999999999998,0.98499999999999999,0.191,0.021999999999999999,,,2.0699999999999998,0.70299999999999996,0.88900000000000001,0.93000000000000005,0.097000000000000003,,4.7539999999999996,0.63,0.81499999999999995,0.86199999999999999,0.16600000000000001, -gma,sintel,1.3839999999999999,0.88500000000000001,0.94399999999999995,0.96099999999999997,0.051999999999999998,4.1079999999999997,0.91000000000000003,0.60599999999999998,0.91000000000000003,0.79400000000000004,0.96199999999999997,0.85199999999999998,0.97499999999999998,0.191,0.035000000000000003,,,0.72299999999999998,0.92500000000000004,0.96899999999999997,0.97899999999999998,0.028000000000000001,2.8889999999999998,0.36299999999999999,0.65900000000000003,0.95299999999999996,0.83499999999999996,0.98599999999999999,0.88600000000000001,0.99199999999999999,0.14999999999999999,0.012,,,1.3129999999999999,0.80300000000000005,0.93799999999999994,0.96399999999999997,0.044999999999999998,,1.5489999999999999,0.76900000000000002,0.92000000000000004,0.95099999999999996,0.052999999999999999, -gma,kitti,6.577,0.76200000000000001,0.85799999999999998,0.88500000000000001,0.14000000000000001,13.73,5.2149999999999999,0.40000000000000002,0.79000000000000004,0.59999999999999998,0.88400000000000001,0.67700000000000005,0.90800000000000003,0.38900000000000001,0.115,,,4.75,0.80900000000000005,0.89600000000000002,0.91700000000000004,0.10199999999999999,11.076000000000001,3.5590000000000002,0.438,0.84099999999999997,0.64600000000000002,0.92300000000000004,0.72099999999999997,0.94099999999999995,0.34200000000000003,0.075999999999999998,,,1.526,0.81200000000000006,0.93700000000000006,0.95999999999999996,0.049000000000000002,,0.76700000000000002,0.85899999999999999,0.95799999999999996,0.97699999999999998,0.021999999999999999, -hd3,chairs,9.7219999999999995,0.60499999999999998,0.83399999999999996,0.85699999999999998,0.16300000000000001,18.102,8.202,0.28100000000000003,0.628,0.55800000000000005,0.86299999999999999,0.63700000000000001,0.88200000000000001,0.432,0.13500000000000001,,,4.9180000000000001,0.76800000000000002,0.90200000000000002,0.91800000000000004,0.096000000000000002,11.970000000000001,3.5699999999999998,0.34100000000000003,0.80300000000000005,0.63100000000000001,0.93400000000000005,0.71299999999999997,0.94499999999999995,0.35799999999999998,0.065000000000000002,,,12.218999999999999,0.312,0.57399999999999995,0.63500000000000001,0.42299999999999999,,21.716000000000001,0.26800000000000002,0.48299999999999998,0.53300000000000003,0.51500000000000001, -hd3,kitti,44.813000000000002,0.59099999999999997,0.69799999999999995,0.71999999999999997,0.30099999999999999,57.817999999999998,42.808999999999997,0.27800000000000002,0.60699999999999998,0.46999999999999997,0.71499999999999997,0.53700000000000003,0.73599999999999999,0.52200000000000002,0.28399999999999997,,,37.061999999999998,0.64000000000000001,0.73199999999999998,0.75,0.26700000000000002,49.043999999999997,35.033999999999999,0.29599999999999999,0.66200000000000003,0.501,0.752,0.56799999999999995,0.76800000000000002,0.48999999999999999,0.247,,,1.262,0.82699999999999996,0.93600000000000005,0.95899999999999996,0.048000000000000001,,1.9430000000000001,0.72599999999999998,0.90200000000000002,0.93700000000000006,0.067000000000000004, -hd3,sintel,1.603,0.85999999999999999,0.93100000000000005,0.95099999999999996,0.065000000000000002,5.0069999999999997,1.0329999999999999,0.48799999999999999,0.89400000000000002,0.73199999999999998,0.95599999999999996,0.81100000000000005,0.96999999999999997,0.254,0.041000000000000002,,,2.3109999999999999,0.89200000000000002,0.94099999999999995,0.95299999999999996,0.058000000000000003,7.3529999999999998,1.2989999999999999,0.50800000000000001,0.93000000000000005,0.73699999999999999,0.96799999999999997,0.80400000000000005,0.97499999999999998,0.253,0.032000000000000001,,,6.3010000000000002,0.64700000000000002,0.81399999999999995,0.84899999999999998,0.18099999999999999,,15.291,0.56599999999999995,0.70699999999999996,0.73899999999999999,0.28999999999999998, -hd3,things,6.5010000000000003,0.77900000000000003,0.85999999999999999,0.88300000000000001,0.13700000000000001,15.18,4.8380000000000001,0.35499999999999998,0.81200000000000006,0.57599999999999996,0.89000000000000001,0.65400000000000003,0.91100000000000003,0.41399999999999998,0.107,,,3.214,0.86899999999999999,0.92800000000000005,0.94199999999999995,0.070000000000000007,9.4830000000000005,1.948,0.45200000000000001,0.91000000000000003,0.69699999999999995,0.95799999999999996,0.77100000000000002,0.96699999999999997,0.28999999999999998,0.041000000000000002,,,6.8540000000000001,0.626,0.79100000000000004,0.82999999999999996,0.20300000000000001,,14.505000000000001,0.57299999999999995,0.72099999999999997,0.75700000000000001,0.27200000000000002, -hd3_ctxt,chairs,5.7919999999999998,0.627,0.85899999999999999,0.88400000000000001,0.13900000000000001,12.77,4.4740000000000002,0.29699999999999999,0.64800000000000002,0.60299999999999998,0.88600000000000001,0.68500000000000005,0.90700000000000003,0.38700000000000001,0.113,,,3.7269999999999999,0.753,0.91200000000000003,0.92900000000000005,0.085999999999999993,10.167,2.4630000000000001,0.36499999999999999,0.78200000000000003,0.66400000000000003,0.93999999999999995,0.746,0.95299999999999996,0.32400000000000001,0.058999999999999997,,,13.695,0.20200000000000001,0.49299999999999999,0.56299999999999994,0.505,,22.971,0.22600000000000001,0.42899999999999999,0.48199999999999998,0.56999999999999995, -hd3_ctxt,kitti,7.8579999999999997,0.70799999999999996,0.82799999999999996,0.85999999999999999,0.16900000000000001,15.295999999999999,6.5030000000000001,0.35399999999999998,0.73199999999999998,0.56899999999999995,0.85299999999999998,0.65000000000000002,0.88200000000000001,0.42099999999999999,0.14499999999999999,,,6.016,0.76900000000000002,0.872,0.89500000000000002,0.126,13.042,4.7149999999999999,0.38800000000000001,0.79800000000000004,0.61199999999999999,0.89900000000000002,0.68899999999999995,0.91800000000000004,0.377,0.099000000000000005,,,0.999,0.85499999999999998,0.94699999999999995,0.96599999999999997,0.037999999999999999,,1.536,0.75600000000000001,0.92200000000000004,0.95099999999999996,0.051999999999999998, -hd3_ctxt,sintel,1.7370000000000001,0.86799999999999999,0.93400000000000005,0.95199999999999996,0.062,5.2359999999999998,1.1479999999999999,0.51200000000000001,0.89900000000000002,0.74399999999999999,0.95599999999999996,0.81799999999999995,0.96999999999999997,0.24199999999999999,0.040000000000000001,,,2.1030000000000002,0.89600000000000002,0.94499999999999995,0.95699999999999996,0.053999999999999999,6.7320000000000002,1.1830000000000001,0.53300000000000003,0.93100000000000005,0.755,0.96899999999999997,0.82099999999999995,0.97599999999999998,0.23400000000000001,0.029999999999999999,,,5.0579999999999998,0.64800000000000002,0.84999999999999998,0.88900000000000001,0.14099999999999999,,13.449999999999999,0.55700000000000005,0.72499999999999998,0.76600000000000001,0.26900000000000002, -hd3_ctxt,things,4.4210000000000003,0.80600000000000005,0.88200000000000001,0.90300000000000002,0.11600000000000001,11.506,3.0710000000000002,0.40600000000000003,0.83899999999999997,0.622,0.90900000000000003,0.69699999999999995,0.92800000000000005,0.36799999999999999,0.087999999999999995,,,2.0720000000000001,0.88400000000000001,0.93999999999999995,0.95299999999999996,0.058000000000000003,6.8099999999999996,1.1619999999999999,0.51000000000000001,0.92000000000000004,0.74099999999999999,0.96499999999999997,0.80700000000000005,0.97399999999999998,0.247,0.034000000000000002,,,4.6449999999999996,0.65800000000000003,0.82999999999999996,0.86799999999999999,0.159,,9.9589999999999996,0.59099999999999997,0.75,0.78900000000000003,0.23999999999999999, -irr_pwc,chairs_occ,3.9470000000000001,0.80200000000000005,0.88300000000000001,0.90900000000000003,0.113,10.004,2.8279999999999998,0.41099999999999998,0.83199999999999996,0.63,0.90900000000000003,0.70999999999999996,0.93100000000000005,0.35799999999999998,0.087999999999999995,0.70699999999999996,,2.3149999999999999,0.85599999999999998,0.92600000000000005,0.94399999999999995,0.070999999999999994,7.5190000000000001,1.323,0.45800000000000002,0.89000000000000001,0.69299999999999995,0.95299999999999996,0.77100000000000002,0.96599999999999997,0.29499999999999998,0.045999999999999999,0.73899999999999999,,3.887,0.56599999999999995,0.81000000000000005,0.87,0.17999999999999999,,10.672000000000001,0.49099999999999999,0.68799999999999994,0.751,0.30399999999999999, -irr_pwc,kitti,8.1709999999999994,0.71599999999999997,0.81999999999999995,0.84899999999999998,0.17899999999999999,15.455,6.7329999999999997,0.373,0.73899999999999999,0.58199999999999996,0.84099999999999997,0.66200000000000003,0.86799999999999999,0.40799999999999997,0.157,0.68700000000000006,,7.4470000000000001,0.76500000000000001,0.85599999999999998,0.88,0.14299999999999999,15.721,5.5999999999999996,0.39800000000000002,0.79200000000000004,0.61399999999999999,0.88,0.68999999999999995,0.90100000000000002,0.376,0.11899999999999999,0.71099999999999997,,1.1279999999999999,0.84399999999999997,0.94599999999999995,0.96599999999999997,0.039,,1.522,0.79800000000000004,0.92300000000000004,0.94999999999999996,0.051999999999999998, -irr_pwc,sintel,2.4430000000000001,0.84899999999999998,0.91600000000000004,0.93700000000000006,0.080000000000000002,6.6660000000000004,1.7150000000000001,0.502,0.878,0.71799999999999997,0.93700000000000006,0.78900000000000003,0.95499999999999996,0.26900000000000002,0.058999999999999997,0.76200000000000001,,1.8500000000000001,0.89000000000000001,0.94399999999999995,0.95699999999999996,0.053999999999999999,6.194,1.0569999999999999,0.53700000000000003,0.92300000000000004,0.748,0.96699999999999997,0.81399999999999995,0.97499999999999998,0.23899999999999999,0.032000000000000001,0.77300000000000002,,2.581,0.73799999999999999,0.89500000000000002,0.92800000000000005,0.090999999999999998,,7.968,0.625,0.78400000000000003,0.82299999999999995,0.20200000000000001, -irr_pwc,things,3.4039999999999999,0.81000000000000005,0.89300000000000002,0.91800000000000004,0.104,8.5660000000000007,2.46,0.435,0.83899999999999997,0.66100000000000003,0.91700000000000004,0.73999999999999999,0.93899999999999995,0.32700000000000001,0.080000000000000002,0.72499999999999998,,1.8560000000000001,0.87,0.93600000000000005,0.95199999999999996,0.062,6.3019999999999996,1.0349999999999999,0.49199999999999999,0.90400000000000003,0.72199999999999998,0.95999999999999996,0.79600000000000004,0.97199999999999998,0.26500000000000001,0.037999999999999999,0.752,,3.5499999999999998,0.66400000000000003,0.84199999999999997,0.88600000000000001,0.14599999999999999,,9.5079999999999991,0.55600000000000005,0.72599999999999998,0.77700000000000002,0.26400000000000001, -irr_pwcnet,things,4.4109999999999996,0.77000000000000002,0.86799999999999999,0.89700000000000002,0.129,10.869,3.2109999999999999,0.35799999999999998,0.80200000000000005,0.59699999999999998,0.89600000000000002,0.68700000000000006,0.92200000000000004,0.39200000000000002,0.10199999999999999,,,3.0899999999999999,0.81999999999999995,0.90700000000000003,0.92900000000000005,0.090999999999999998,9.1140000000000008,1.964,0.39500000000000002,0.85599999999999998,0.64600000000000002,0.93600000000000005,0.73399999999999999,0.95299999999999996,0.34200000000000003,0.063,,,5.9669999999999996,0.53800000000000003,0.77400000000000002,0.83399999999999996,0.219,,14.707000000000001,0.45200000000000001,0.63500000000000001,0.69499999999999995,0.36099999999999999, -irr_pwcnet_irr,things,4.0519999999999996,0.755,0.86699999999999999,0.90000000000000002,0.129,10.112,2.968,0.35699999999999998,0.78400000000000003,0.59899999999999998,0.89400000000000002,0.69099999999999995,0.92400000000000004,0.39000000000000001,0.10299999999999999,,,2.734,0.80900000000000005,0.90800000000000003,0.93200000000000005,0.088999999999999996,8.4469999999999992,1.6919999999999999,0.39400000000000002,0.84199999999999997,0.64700000000000002,0.93600000000000005,0.73799999999999999,0.95599999999999996,0.34000000000000002,0.060999999999999999,,,5.1449999999999996,0.50900000000000001,0.76000000000000001,0.82999999999999996,0.23200000000000001,,12.983000000000001,0.441,0.63600000000000001,0.69999999999999996,0.35799999999999998, -lcv_raft,chairs,4.0339999999999998,0.753,0.88200000000000001,0.91100000000000003,0.115,9.9369999999999994,2.9300000000000002,0.40899999999999997,0.77900000000000003,0.65500000000000003,0.90700000000000003,0.73199999999999998,0.93300000000000005,0.33400000000000002,0.089999999999999997,,,2.282,0.81000000000000005,0.92400000000000004,0.94699999999999995,0.072999999999999995,7.0090000000000003,1.3879999999999999,0.46899999999999997,0.83899999999999997,0.71399999999999997,0.94899999999999995,0.79000000000000004,0.96799999999999997,0.27200000000000002,0.048000000000000001,,,4.3710000000000004,0.35999999999999999,0.71299999999999997,0.82399999999999995,0.27700000000000002,,9.2129999999999992,0.35899999999999999,0.625,0.72299999999999998,0.36699999999999999, -lcv_raft,things,2.9809999999999999,0.82799999999999996,0.90900000000000003,0.93100000000000005,0.086999999999999994,7.7919999999999998,2.0990000000000002,0.47999999999999998,0.85799999999999998,0.70199999999999996,0.93200000000000005,0.77500000000000002,0.95099999999999996,0.28499999999999998,0.065000000000000002,,,1.7609999999999999,0.874,0.94499999999999995,0.96099999999999997,0.051999999999999998,5.8239999999999998,0.97499999999999998,0.53100000000000003,0.90700000000000003,0.75700000000000001,0.96899999999999997,0.82599999999999996,0.97999999999999998,0.22900000000000001,0.029999999999999999,,,2.5089999999999999,0.67800000000000005,0.874,0.91700000000000004,0.114,,6.1299999999999999,0.59499999999999997,0.78900000000000003,0.83699999999999997,0.19700000000000001, -liteflownet,kitti,5.9169999999999998,0.73799999999999999,0.84199999999999997,0.872,0.156,13.327,4.5700000000000003,0.33100000000000002,0.76700000000000002,0.55100000000000005,0.86899999999999999,0.63900000000000001,0.89700000000000002,0.439,0.129,,,4.5529999999999999,0.79600000000000004,0.88,0.90300000000000002,0.11799999999999999,11.640000000000001,3.234,0.35599999999999998,0.83299999999999996,0.58999999999999997,0.91100000000000003,0.67700000000000005,0.93000000000000005,0.39900000000000002,0.087999999999999995,,,1.1639999999999999,0.82899999999999996,0.94199999999999995,0.96499999999999997,0.041000000000000002,,1.7829999999999999,0.77900000000000003,0.90700000000000003,0.93799999999999994,0.065000000000000002, -liteflownet,sintel,1.845,0.85599999999999998,0.92300000000000004,0.94299999999999995,0.072999999999999995,5.5490000000000004,1.2290000000000001,0.49199999999999999,0.88900000000000001,0.72399999999999998,0.94699999999999995,0.79900000000000004,0.96299999999999997,0.26100000000000001,0.049000000000000002,,,1.419,0.90500000000000003,0.94799999999999995,0.96099999999999997,0.049000000000000002,5.2389999999999999,0.75600000000000001,0.53500000000000003,0.94399999999999995,0.754,0.97399999999999998,0.82099999999999995,0.97999999999999998,0.23100000000000001,0.025000000000000001,,,3.661,0.73599999999999999,0.86599999999999999,0.90000000000000002,0.125,,10.353999999999999,0.63100000000000001,0.75800000000000001,0.79400000000000004,0.23499999999999999, -liteflownet,things,4.024,0.78600000000000003,0.879,0.90600000000000003,0.11799999999999999,10.305999999999999,2.8399999999999999,0.37,0.81899999999999995,0.61899999999999999,0.90700000000000003,0.70499999999999996,0.93000000000000005,0.37,0.089999999999999997,,,2.504,0.85099999999999998,0.92400000000000004,0.94199999999999995,0.073999999999999996,7.9180000000000001,1.45,0.41699999999999998,0.89100000000000001,0.68200000000000005,0.95399999999999996,0.76400000000000001,0.96699999999999997,0.30599999999999999,0.044999999999999998,,,4.532,0.627,0.80200000000000005,0.84899999999999998,0.191,,11.477,0.52200000000000002,0.68300000000000005,0.73399999999999999,0.311, -liteflownet2,sintel,1.9199999999999999,0.84399999999999997,0.91900000000000004,0.94199999999999995,0.076999999999999999,5.7709999999999999,1.274,0.45900000000000002,0.877,0.70699999999999996,0.94299999999999995,0.79100000000000004,0.96099999999999997,0.27800000000000002,0.053999999999999999,,,1.4950000000000001,0.88700000000000001,0.94399999999999995,0.95899999999999996,0.052999999999999999,5.298,0.83999999999999997,0.495,0.92500000000000004,0.73699999999999999,0.96899999999999997,0.81299999999999994,0.97799999999999998,0.249,0.029999999999999999,,,1.6479999999999999,0.78000000000000003,0.92100000000000004,0.94999999999999996,0.062,,3.2330000000000001,0.70999999999999996,0.873,0.91100000000000003,0.10000000000000001, -liteflownet2_pseudoreg,kitti,5.6680000000000001,0.74299999999999999,0.84199999999999997,0.872,0.155,12.760999999999999,4.3680000000000003,0.34100000000000003,0.77200000000000002,0.56200000000000006,0.87,0.64600000000000002,0.89600000000000002,0.42799999999999999,0.129,,,4.5090000000000003,0.79700000000000004,0.88100000000000001,0.90200000000000002,0.11799999999999999,11.509,3.194,0.37,0.83199999999999996,0.60299999999999998,0.91000000000000003,0.68400000000000005,0.92900000000000005,0.38700000000000001,0.088999999999999996,,,1.0249999999999999,0.84399999999999997,0.94999999999999996,0.96999999999999997,0.035000000000000003,,1.389,0.79700000000000004,0.92700000000000005,0.95499999999999996,0.045999999999999999, -liteflownet3,sintel,1.8839999999999999,0.84699999999999998,0.92300000000000004,0.94399999999999995,0.072999999999999995,5.6180000000000003,1.24,0.46899999999999997,0.879,0.71599999999999997,0.94699999999999995,0.79800000000000004,0.96299999999999997,0.27000000000000002,0.050000000000000003,,0.67900000000000005,1.4099999999999999,0.88900000000000001,0.94699999999999995,0.96099999999999997,0.050999999999999997,5.0469999999999997,0.78400000000000003,0.503,0.92500000000000004,0.745,0.97099999999999997,0.82099999999999995,0.97999999999999998,0.23999999999999999,0.027,,0.66200000000000003,1.6739999999999999,0.76600000000000001,0.91500000000000004,0.94799999999999995,0.068000000000000005,0.53500000000000003,3.4100000000000001,0.69599999999999995,0.85899999999999999,0.90200000000000002,0.114,0.51300000000000001 -liteflownet3_pseudoreg,kitti,5.577,0.745,0.84599999999999997,0.874,0.152,12.622,4.3250000000000002,0.33900000000000002,0.77400000000000002,0.56399999999999995,0.873,0.64600000000000002,0.89900000000000002,0.42599999999999999,0.125,,0.61899999999999999,4.3979999999999997,0.80100000000000005,0.88500000000000001,0.90600000000000003,0.113,11.292999999999999,3.1230000000000002,0.376,0.83699999999999997,0.60599999999999998,0.91500000000000004,0.68600000000000005,0.93200000000000005,0.38300000000000001,0.084000000000000005,,0.63100000000000001,1.0209999999999999,0.84599999999999997,0.94799999999999995,0.96899999999999997,0.035999999999999997,0.55200000000000005,1.421,0.79400000000000004,0.92500000000000004,0.95399999999999996,0.048000000000000001,0.52300000000000002 -liteflownet3s,sintel,2.0190000000000001,0.84399999999999997,0.92100000000000004,0.94199999999999995,0.074999999999999997,6.0199999999999996,1.3360000000000001,0.46100000000000002,0.876,0.70899999999999996,0.94399999999999995,0.79100000000000004,0.96199999999999997,0.27700000000000002,0.051999999999999998,,0.68899999999999995,1.5289999999999999,0.88700000000000001,0.94499999999999995,0.95899999999999996,0.052999999999999999,5.4420000000000002,0.85899999999999999,0.496,0.92400000000000004,0.73799999999999999,0.96899999999999997,0.81399999999999995,0.97899999999999998,0.248,0.029000000000000001,,0.68500000000000005,1.788,0.755,0.91000000000000003,0.94399999999999995,0.072999999999999995,0.50600000000000001,3.6859999999999999,0.68600000000000005,0.85199999999999998,0.89400000000000002,0.124,0.497 -liteflownet3s_pseudoreg,kitti,5.4820000000000002,0.747,0.84999999999999998,0.879,0.14799999999999999,12.689,4.1769999999999996,0.34699999999999998,0.77700000000000002,0.56999999999999995,0.878,0.65400000000000003,0.90400000000000003,0.41899999999999998,0.12,,0.64400000000000002,4.2140000000000004,0.80800000000000005,0.89000000000000001,0.91000000000000003,0.108,11.004,2.9529999999999998,0.38800000000000001,0.84399999999999997,0.61399999999999999,0.92000000000000004,0.69399999999999995,0.93600000000000005,0.375,0.079000000000000001,,0.65400000000000003,1.04,0.83999999999999997,0.94799999999999995,0.96899999999999997,0.035999999999999997,0.53600000000000003,1.538,0.78400000000000003,0.91800000000000004,0.94899999999999995,0.052999999999999999,0.51200000000000001 -maskflownet,kitti,5.907,0.75900000000000001,0.85299999999999998,0.88,0.14399999999999999,12.561,4.6609999999999996,0.40100000000000002,0.78600000000000003,0.60899999999999999,0.877,0.68500000000000005,0.90200000000000002,0.38100000000000001,0.121,,,4.3639999999999999,0.81399999999999995,0.89000000000000001,0.91000000000000003,0.108,10.962,3.1379999999999999,0.437,0.84499999999999997,0.64200000000000002,0.91600000000000004,0.71499999999999997,0.93400000000000005,0.34699999999999998,0.082000000000000003,,,1.3340000000000001,0.77100000000000002,0.93700000000000006,0.96299999999999997,0.044999999999999998,,2.8450000000000002,0.66100000000000003,0.85799999999999998,0.90800000000000003,0.106, -maskflownet,sintel,2.6920000000000002,0.85699999999999998,0.91700000000000004,0.93500000000000005,0.080000000000000002,7.2729999999999997,1.883,0.52800000000000002,0.88600000000000001,0.71499999999999997,0.93899999999999995,0.77900000000000003,0.95399999999999996,0.27300000000000002,0.058000000000000003,,,1.786,0.90100000000000002,0.94499999999999995,0.95699999999999996,0.052999999999999999,6.2309999999999999,1.0029999999999999,0.57099999999999995,0.93200000000000005,0.749,0.96699999999999997,0.80700000000000005,0.97599999999999998,0.23899999999999999,0.032000000000000001,,,1.7030000000000001,0.72399999999999998,0.92100000000000004,0.95199999999999996,0.059999999999999998,,3.1259999999999999,0.63800000000000001,0.85799999999999998,0.90900000000000003,0.10100000000000001, -maskflownet_s,sintel,2.8069999999999999,0.84099999999999997,0.91100000000000003,0.93100000000000005,0.085999999999999993,7.5090000000000003,1.984,0.495,0.87,0.69899999999999995,0.93400000000000005,0.76800000000000002,0.95099999999999996,0.28899999999999998,0.063,0.63,,1.9379999999999999,0.88100000000000001,0.93799999999999994,0.95299999999999996,0.058999999999999997,6.5250000000000004,1.1279999999999999,0.53100000000000003,0.91300000000000003,0.73099999999999998,0.96199999999999997,0.79600000000000004,0.97299999999999998,0.25600000000000001,0.035999999999999997,0.63300000000000001,,1.861,0.68999999999999995,0.91000000000000003,0.94599999999999995,0.070000000000000007,,3.4569999999999999,0.60999999999999999,0.84299999999999997,0.89900000000000002,0.11899999999999999, -maskflownet_s,things,4.2999999999999998,0.747,0.86299999999999999,0.89500000000000002,0.13300000000000001,10.385,3.1640000000000001,0.371,0.77500000000000002,0.60699999999999998,0.89000000000000001,0.69299999999999995,0.91900000000000004,0.38200000000000001,0.107,0.44700000000000001,,3.0019999999999998,0.81000000000000005,0.90400000000000003,0.92700000000000005,0.094,8.6769999999999996,1.929,0.41199999999999998,0.84399999999999997,0.65300000000000002,0.93200000000000005,0.73799999999999999,0.95099999999999996,0.33400000000000002,0.066000000000000003,0.45700000000000002,,4.7350000000000003,0.505,0.76300000000000001,0.83099999999999996,0.22900000000000001,,11.388,0.438,0.64400000000000002,0.71299999999999997,0.34899999999999998, -pwcdcnet,sintel,2.3279999999999998,0.83499999999999996,0.91200000000000003,0.93500000000000005,0.085000000000000006,6.6340000000000003,1.601,0.47099999999999997,0.86499999999999999,0.69699999999999995,0.93500000000000005,0.77300000000000002,0.95399999999999996,0.28999999999999998,0.062,,,1.8080000000000001,0.872,0.93600000000000005,0.95199999999999996,0.060999999999999999,6.0579999999999998,1.097,0.505,0.90500000000000003,0.72699999999999998,0.95999999999999996,0.79800000000000004,0.97199999999999998,0.26000000000000001,0.039,,,2.0720000000000001,0.753,0.90800000000000003,0.94099999999999995,0.074999999999999997,,3.1589999999999998,0.69499999999999995,0.86699999999999999,0.90900000000000003,0.104, -pwcdcnet,things,4.2130000000000001,0.746,0.86799999999999999,0.90100000000000002,0.128,9.9719999999999995,3.0800000000000001,0.33500000000000002,0.77600000000000002,0.59699999999999998,0.89600000000000002,0.69799999999999995,0.92500000000000004,0.39100000000000001,0.10100000000000001,,,2.6760000000000002,0.79600000000000004,0.90600000000000003,0.93300000000000005,0.090999999999999998,7.8769999999999998,1.7010000000000001,0.36699999999999999,0.83099999999999996,0.64700000000000002,0.93500000000000005,0.747,0.95599999999999996,0.34000000000000002,0.063,,,4.5819999999999999,0.52800000000000002,0.77400000000000002,0.84099999999999997,0.217,,10.994,0.442,0.65200000000000002,0.72099999999999997,0.34100000000000003, -pwcnet,sintel,2.8199999999999998,0.75900000000000001,0.88200000000000001,0.91600000000000004,0.115,7.7990000000000004,1.982,0.316,0.79300000000000004,0.60399999999999998,0.91100000000000003,0.71299999999999997,0.93999999999999995,0.38400000000000001,0.085999999999999993,,,2.2549999999999999,0.80000000000000004,0.91200000000000003,0.93799999999999994,0.085000000000000006,7.2560000000000002,1.405,0.33400000000000002,0.83899999999999997,0.63800000000000001,0.94399999999999995,0.74399999999999999,0.96299999999999997,0.34899999999999998,0.053999999999999999,,,3.327,0.628,0.83599999999999997,0.89000000000000001,0.152,,6.2489999999999997,0.55600000000000005,0.76800000000000002,0.82799999999999996,0.217, -pwcnet,things,4.827,0.67200000000000004,0.84099999999999997,0.88400000000000001,0.156,11.551,3.4990000000000001,0.25600000000000001,0.69999999999999996,0.51300000000000001,0.873,0.63400000000000001,0.91300000000000003,0.47699999999999998,0.124,,,3.3580000000000001,0.72099999999999997,0.88,0.91600000000000004,0.11700000000000001,9.625,2.1720000000000002,0.27000000000000002,0.755,0.55000000000000004,0.91500000000000004,0.67900000000000005,0.94599999999999995,0.439,0.082000000000000003,,,5.5510000000000002,0.496,0.74099999999999999,0.81000000000000005,0.252,,12.669,0.40600000000000003,0.626,0.69899999999999995,0.36699999999999999, -raft,chairs,4.2960000000000003,0.73199999999999998,0.88100000000000001,0.90900000000000003,0.11600000000000001,10.028,3.2160000000000002,0.40600000000000003,0.75600000000000001,0.65700000000000003,0.90500000000000003,0.73599999999999999,0.93000000000000005,0.33200000000000002,0.091999999999999998,,,2.1970000000000001,0.80900000000000005,0.92700000000000005,0.94899999999999995,0.070000000000000007,6.7850000000000001,1.3440000000000001,0.47799999999999998,0.83799999999999997,0.72099999999999997,0.95199999999999996,0.79600000000000004,0.96999999999999997,0.26400000000000001,0.045999999999999999,,,4.5940000000000003,0.29699999999999999,0.69499999999999995,0.81599999999999995,0.29399999999999998,,9.7490000000000006,0.34599999999999997,0.63100000000000001,0.72499999999999998,0.35999999999999999, -raft,kitti,6.2930000000000001,0.755,0.85299999999999998,0.88200000000000001,0.14499999999999999,13.709,4.8419999999999996,0.40100000000000002,0.78200000000000003,0.60899999999999999,0.877,0.68400000000000005,0.90300000000000002,0.38100000000000001,0.121,,,4.6870000000000003,0.80200000000000005,0.89400000000000002,0.91700000000000004,0.104,10.993,3.46,0.42999999999999999,0.83299999999999996,0.65000000000000002,0.92000000000000004,0.72499999999999998,0.93999999999999995,0.33900000000000002,0.079000000000000001,,,1.2709999999999999,0.81399999999999995,0.94099999999999995,0.96499999999999997,0.043999999999999997,,0.77900000000000003,0.85799999999999998,0.95799999999999996,0.97799999999999998,0.023, -raft,sintel,1.571,0.879,0.93899999999999995,0.95699999999999996,0.057000000000000002,4.6479999999999997,1.0409999999999999,0.58899999999999997,0.90500000000000003,0.78200000000000003,0.95799999999999996,0.84199999999999997,0.97199999999999998,0.20399999999999999,0.039,,,0.871,0.92000000000000004,0.96599999999999997,0.97699999999999998,0.031,3.274,0.439,0.64300000000000002,0.94799999999999995,0.82499999999999996,0.98399999999999999,0.878,0.98999999999999999,0.161,0.014999999999999999,,,1.341,0.79900000000000004,0.93600000000000005,0.96299999999999997,0.047,,1.631,0.75700000000000001,0.91300000000000003,0.94699999999999995,0.058000000000000003, -raft,things,3.0089999999999999,0.84699999999999998,0.91300000000000003,0.93300000000000005,0.083000000000000004,7.6269999999999998,2.1240000000000001,0.52400000000000002,0.875,0.72499999999999998,0.93500000000000005,0.79000000000000004,0.95099999999999996,0.26300000000000001,0.062,,,1.5069999999999999,0.89900000000000002,0.95499999999999996,0.96699999999999997,0.042999999999999997,5.1219999999999999,0.79400000000000004,0.58699999999999997,0.93000000000000005,0.78600000000000003,0.97599999999999998,0.84599999999999997,0.98399999999999999,0.20100000000000001,0.023,,,2.2610000000000001,0.70899999999999996,0.88900000000000001,0.92800000000000005,0.098000000000000004,,5.468,0.63300000000000001,0.81000000000000005,0.85299999999999998,0.17499999999999999, -raft_small,things,3.548,0.80000000000000004,0.88900000000000001,0.91500000000000004,0.108,9.0139999999999993,2.5659999999999998,0.39900000000000002,0.83099999999999996,0.64000000000000001,0.91500000000000004,0.72899999999999998,0.93700000000000006,0.34799999999999998,0.082000000000000003,,,2.1899999999999999,0.84499999999999997,0.92700000000000005,0.94699999999999995,0.070000000000000007,6.992,1.3109999999999999,0.443,0.88,0.69199999999999995,0.95399999999999996,0.77900000000000003,0.96899999999999997,0.29399999999999998,0.043999999999999997,,,3.6179999999999999,0.60299999999999998,0.83199999999999996,0.88200000000000001,0.158,,8.6359999999999992,0.52200000000000002,0.72899999999999998,0.78600000000000003,0.26100000000000001, -scopeflow,chairs,3.9500000000000002,0.81200000000000006,0.88600000000000001,0.90900000000000003,0.11,9.8819999999999997,2.8839999999999999,0.41499999999999998,0.84299999999999997,0.63400000000000001,0.91300000000000003,0.71399999999999997,0.93100000000000005,0.35499999999999998,0.085000000000000006,0.71099999999999997,,2.569,0.86099999999999999,0.92500000000000004,0.94099999999999995,0.072999999999999995,7.8959999999999999,1.5609999999999999,0.46000000000000002,0.89700000000000002,0.68899999999999995,0.95199999999999996,0.76500000000000001,0.96399999999999997,0.29999999999999999,0.047,0.73699999999999999,,4.0940000000000003,0.59599999999999997,0.81699999999999995,0.875,0.17299999999999999,,11.975,0.49099999999999999,0.67800000000000005,0.73899999999999999,0.317, -scopeflow,kitti,10.458,0.70499999999999996,0.81100000000000005,0.84099999999999997,0.187,16.911999999999999,9.423,0.36099999999999999,0.72699999999999998,0.57199999999999995,0.83199999999999996,0.65100000000000002,0.85999999999999999,0.41799999999999998,0.16700000000000001,0,,8.1110000000000007,0.753,0.84699999999999998,0.872,0.152,15.348000000000001,6.6950000000000003,0.38900000000000001,0.77900000000000003,0.60599999999999998,0.87,0.68100000000000005,0.89300000000000002,0.38500000000000001,0.129,0,,1.002,0.85199999999999998,0.94999999999999996,0.96899999999999997,0.035999999999999997,,1.337,0.81000000000000005,0.93300000000000005,0.95799999999999996,0.044999999999999998, -scopeflow,sintel,2.4609999999999999,0.84999999999999998,0.91700000000000004,0.93799999999999994,0.079000000000000001,6.7519999999999998,1.7210000000000001,0.50700000000000001,0.879,0.71799999999999997,0.93899999999999995,0.78900000000000003,0.95599999999999996,0.26800000000000002,0.057000000000000002,0.76200000000000001,,1.6299999999999999,0.89500000000000002,0.94699999999999995,0.95999999999999996,0.050999999999999997,5.7690000000000001,0.88500000000000001,0.54400000000000004,0.92700000000000005,0.753,0.96899999999999997,0.81799999999999995,0.97799999999999998,0.23300000000000001,0.029999999999999999,0.78000000000000003,,2.0760000000000001,0.749,0.90500000000000003,0.93899999999999995,0.079000000000000001,,6.0839999999999996,0.64600000000000002,0.81299999999999994,0.85299999999999998,0.16800000000000001, -scopeflow,things,3.2690000000000001,0.81299999999999994,0.89500000000000002,0.92000000000000004,0.10199999999999999,8.3260000000000005,2.3650000000000002,0.435,0.84299999999999997,0.66200000000000003,0.91900000000000004,0.74199999999999999,0.94099999999999995,0.32600000000000001,0.078,0.72599999999999998,,1.8320000000000001,0.872,0.93600000000000005,0.95299999999999996,0.060999999999999999,6.2350000000000003,1.024,0.49099999999999999,0.90600000000000003,0.72199999999999998,0.96099999999999997,0.79700000000000004,0.97299999999999998,0.26500000000000001,0.036999999999999998,0.754,,3.4649999999999999,0.67100000000000004,0.84899999999999998,0.89200000000000002,0.13900000000000001,,9.7840000000000007,0.55600000000000005,0.72499999999999998,0.77500000000000002,0.26600000000000001, -scv4,chairs,4.8970000000000002,0.73499999999999999,0.88700000000000001,0.91000000000000003,0.11,11.262,3.6640000000000001,0.42099999999999999,0.75900000000000001,0.66800000000000004,0.91100000000000003,0.73799999999999999,0.93000000000000005,0.32100000000000001,0.085999999999999993,,,2.2010000000000001,0.83899999999999997,0.93899999999999995,0.95499999999999996,0.058000000000000003,7.0529999999999999,1.2749999999999999,0.50900000000000001,0.86899999999999999,0.74299999999999999,0.96399999999999997,0.80900000000000005,0.97499999999999998,0.24299999999999999,0.034000000000000002,,,6.25,0.35499999999999998,0.69099999999999995,0.78500000000000003,0.30299999999999999,,13.192,0.39100000000000001,0.63,0.69799999999999995,0.36499999999999999, -scv4,kitti,7.1020000000000003,0.751,0.83699999999999997,0.86199999999999999,0.16200000000000001,14.108000000000001,5.7629999999999999,0.40899999999999997,0.77700000000000002,0.60099999999999998,0.85899999999999999,0.67400000000000004,0.88100000000000001,0.38900000000000001,0.14000000000000001,,,6.617,0.80300000000000005,0.875,0.89300000000000002,0.123,13.300000000000001,5.282,0.44900000000000001,0.83399999999999996,0.64300000000000002,0.90000000000000002,0.71099999999999997,0.91400000000000003,0.34599999999999997,0.099000000000000005,,,3.1269999999999998,0.82399999999999995,0.93200000000000005,0.95199999999999996,0.055,,2.4980000000000002,0.82399999999999995,0.93000000000000005,0.95299999999999996,0.045999999999999999, -scv4,sintel,2.6459999999999999,0.878,0.93000000000000005,0.94499999999999995,0.066000000000000003,6.9050000000000002,1.8080000000000001,0.59599999999999997,0.90500000000000003,0.77300000000000002,0.94999999999999996,0.82399999999999995,0.96199999999999997,0.215,0.047,,,1.5289999999999999,0.93400000000000005,0.96599999999999997,0.97399999999999998,0.032000000000000001,5.3949999999999996,0.72699999999999998,0.66900000000000004,0.96299999999999997,0.82999999999999996,0.98499999999999999,0.874,0.98899999999999999,0.157,0.014999999999999999,,,2.5270000000000001,0.81699999999999995,0.93100000000000005,0.95399999999999996,0.055,,3.5910000000000002,0.76400000000000001,0.90200000000000002,0.93100000000000005,0.071999999999999995, -scv4,things,3.8490000000000002,0.82599999999999996,0.89200000000000002,0.91500000000000004,0.105,9.5429999999999993,2.7280000000000002,0.50800000000000001,0.85399999999999998,0.70599999999999996,0.91500000000000004,0.76500000000000001,0.93400000000000005,0.28299999999999997,0.083000000000000004,,,1.796,0.88800000000000001,0.94199999999999995,0.95899999999999996,0.056000000000000001,6.1719999999999997,0.94399999999999995,0.59399999999999997,0.92000000000000004,0.78800000000000003,0.96399999999999997,0.84299999999999997,0.97599999999999998,0.19900000000000001,0.035999999999999997,,,4.2320000000000002,0.71099999999999997,0.84999999999999998,0.88500000000000001,0.14299999999999999,,9.8469999999999995,0.624,0.755,0.79100000000000004,0.23699999999999999, -starflow,kitti,7.2309999999999999,0.67100000000000004,0.81200000000000006,0.84899999999999998,0.187,14.124000000000001,5.9660000000000002,0.32000000000000001,0.69299999999999995,0.55600000000000005,0.83499999999999996,0.64300000000000002,0.871,0.435,0.16300000000000001,0.56599999999999995,,5.0110000000000001,0.72999999999999998,0.85199999999999998,0.88400000000000001,0.14499999999999999,11.722,3.7429999999999999,0.34899999999999998,0.75700000000000001,0.59399999999999997,0.879,0.68000000000000005,0.90800000000000003,0.39600000000000002,0.11899999999999999,0.56399999999999995,,1.712,0.75700000000000001,0.91900000000000004,0.94899999999999995,0.064000000000000001,,2.8180000000000001,0.66200000000000003,0.85499999999999998,0.90000000000000002,0.11600000000000001, -starflow,sintel,2.0219999999999998,0.83799999999999997,0.91600000000000004,0.93899999999999995,0.080000000000000002,5.9649999999999999,1.375,0.47999999999999998,0.86699999999999999,0.71199999999999997,0.93799999999999994,0.79000000000000004,0.95699999999999996,0.27400000000000002,0.058000000000000003,0.72899999999999998,,1.6000000000000001,0.88100000000000001,0.94299999999999995,0.95799999999999996,0.055,5.5650000000000004,0.89300000000000002,0.51500000000000001,0.91500000000000004,0.745,0.96599999999999997,0.81599999999999995,0.97699999999999998,0.24299999999999999,0.032000000000000001,0.749,,3.5449999999999999,0.68899999999999995,0.85999999999999999,0.90100000000000002,0.128,,7.819,0.58999999999999997,0.76300000000000001,0.81100000000000005,0.223, -starflow,things,3.5840000000000001,0.79600000000000004,0.89100000000000001,0.91800000000000004,0.105,8.6539999999999999,2.665,0.42299999999999999,0.82499999999999996,0.66300000000000003,0.91500000000000004,0.74399999999999999,0.93799999999999994,0.32400000000000001,0.081000000000000003,0.70299999999999996,,1.8220000000000001,0.86199999999999999,0.93600000000000005,0.95299999999999996,0.062,6.0330000000000004,1.0309999999999999,0.47899999999999998,0.89600000000000002,0.72599999999999998,0.95999999999999996,0.80200000000000005,0.97299999999999998,0.26100000000000001,0.037999999999999999,0.73499999999999999,,4.1449999999999996,0.63500000000000001,0.82299999999999995,0.86899999999999999,0.16600000000000001,,9.7400000000000002,0.53500000000000003,0.71299999999999997,0.76600000000000001,0.27600000000000002, -vcn,chairs,3.96,0.751,0.873,0.90500000000000003,0.123,10.07,2.8620000000000001,0.35999999999999999,0.78000000000000003,0.60799999999999998,0.90100000000000002,0.70099999999999996,0.93000000000000005,0.38,0.096000000000000002,,,2.8010000000000002,0.80500000000000005,0.91000000000000003,0.93400000000000005,0.086999999999999994,8.4559999999999995,1.748,0.39700000000000002,0.83899999999999997,0.65700000000000003,0.93899999999999995,0.745,0.95799999999999996,0.33100000000000002,0.058999999999999997,,,4.444,0.51800000000000002,0.79300000000000004,0.85999999999999999,0.19700000000000001,,10.773,0.436,0.65900000000000003,0.73199999999999998,0.33300000000000002, -vcn,kitti,9.1959999999999997,0.69099999999999995,0.79100000000000004,0.81899999999999995,0.20699999999999999,15.380000000000001,8.0350000000000001,0.33500000000000002,0.71499999999999997,0.54800000000000004,0.81299999999999994,0.63,0.83899999999999997,0.442,0.186,,,5.9340000000000002,0.76900000000000002,0.85099999999999998,0.874,0.14599999999999999,12.215999999999999,4.8049999999999997,0.371,0.79900000000000004,0.59399999999999997,0.876,0.67600000000000005,0.89600000000000002,0.39500000000000002,0.122,,,1.141,0.84199999999999997,0.94399999999999995,0.96499999999999997,0.041000000000000002,,1.4550000000000001,0.78600000000000003,0.92100000000000004,0.95099999999999996,0.052999999999999999, -vcn,sintel,2.2509999999999999,0.82299999999999995,0.90800000000000003,0.93300000000000005,0.087999999999999995,6.4630000000000001,1.571,0.40500000000000003,0.85699999999999998,0.67000000000000004,0.93400000000000005,0.76400000000000001,0.95399999999999996,0.316,0.063,,,1.595,0.86699999999999999,0.93600000000000005,0.95399999999999996,0.060999999999999999,5.6299999999999999,0.91800000000000004,0.435,0.90600000000000003,0.70699999999999996,0.96299999999999997,0.79500000000000004,0.97499999999999998,0.27900000000000003,0.035000000000000003,,,2.2090000000000001,0.71399999999999997,0.88300000000000001,0.92300000000000004,0.10000000000000001,,4.3140000000000001,0.60799999999999998,0.80000000000000004,0.85699999999999998,0.17499999999999999, -vcn,things,3.9660000000000002,0.76100000000000001,0.87,0.90100000000000002,0.126,9.6509999999999998,2.9689999999999999,0.35999999999999999,0.79100000000000004,0.60999999999999999,0.89700000000000002,0.70299999999999996,0.92400000000000004,0.379,0.10000000000000001,,,2.4940000000000002,0.82199999999999995,0.91300000000000003,0.93600000000000005,0.084000000000000005,7.6109999999999998,1.573,0.39800000000000002,0.85799999999999998,0.66400000000000003,0.94099999999999995,0.755,0.95999999999999996,0.32300000000000001,0.056000000000000001,,,3.4660000000000002,0.64700000000000002,0.83999999999999997,0.89000000000000001,0.14899999999999999,,8.6270000000000007,0.54100000000000004,0.72799999999999998,0.78500000000000003,0.26100000000000001, -vcn_small,chairs,4.2969999999999997,0.68300000000000005,0.84799999999999998,0.89100000000000001,0.14899999999999999,10.369,3.214,0.30299999999999999,0.70899999999999996,0.56599999999999995,0.876,0.67100000000000004,0.91600000000000004,0.42299999999999999,0.121,,,3.2170000000000001,0.72399999999999998,0.88200000000000001,0.91800000000000004,0.114,9.1470000000000002,2.1190000000000002,0.32600000000000001,0.753,0.60599999999999998,0.91300000000000003,0.70899999999999996,0.94399999999999995,0.38300000000000001,0.084000000000000005,,,5.6559999999999997,0.33400000000000002,0.68200000000000005,0.79500000000000004,0.309,,13.18,0.314,0.56499999999999995,0.66400000000000003,0.42999999999999999, -vcn_small,things,5.2469999999999999,0.68700000000000006,0.84399999999999997,0.88400000000000001,0.152,10.705,4.2990000000000004,0.28599999999999998,0.71399999999999997,0.56499999999999995,0.871,0.67000000000000004,0.90700000000000003,0.42299999999999999,0.125,,,4.7999999999999998,0.72799999999999998,0.86799999999999999,0.90200000000000002,0.128,9.5850000000000009,3.931,0.30599999999999999,0.75800000000000001,0.60399999999999998,0.89600000000000002,0.70999999999999996,0.92600000000000005,0.38300000000000001,0.10100000000000001,,,4.2480000000000002,0.48499999999999999,0.76600000000000001,0.84199999999999997,0.223,,9.6920000000000002,0.41099999999999998,0.65100000000000002,0.72799999999999998,0.33900000000000002, +model,checkpoint,sintel-final-trainval-occ-val/epe,sintel-final-trainval-occ-val/px1,sintel-final-trainval-occ-val/px3,sintel-final-trainval-occ-val/px5,sintel-final-trainval-occ-val/outlier,sintel-final-trainval-occ-val/epe_occ,sintel-final-trainval-occ-val/epe_non_occ,sintel-final-trainval-occ-val/px1_occ,sintel-final-trainval-occ-val/px1_non_occ,sintel-final-trainval-occ-val/px3_occ,sintel-final-trainval-occ-val/px3_non_occ,sintel-final-trainval-occ-val/px5_occ,sintel-final-trainval-occ-val/px5_non_occ,sintel-final-trainval-occ-val/outlier_occ,sintel-final-trainval-occ-val/outlier_non_occ,sintel-final-trainval-occ-val/occ_f1,sintel-final-trainval-occ-val/conf_f1,sintel-clean-trainval-occ-val/epe,sintel-clean-trainval-occ-val/px1,sintel-clean-trainval-occ-val/px3,sintel-clean-trainval-occ-val/px5,sintel-clean-trainval-occ-val/outlier,sintel-clean-trainval-occ-val/epe_occ,sintel-clean-trainval-occ-val/epe_non_occ,sintel-clean-trainval-occ-val/px1_occ,sintel-clean-trainval-occ-val/px1_non_occ,sintel-clean-trainval-occ-val/px3_occ,sintel-clean-trainval-occ-val/px3_non_occ,sintel-clean-trainval-occ-val/px5_occ,sintel-clean-trainval-occ-val/px5_non_occ,sintel-clean-trainval-occ-val/outlier_occ,sintel-clean-trainval-occ-val/outlier_non_occ,sintel-clean-trainval-occ-val/occ_f1,sintel-clean-trainval-occ-val/conf_f1,kitti-2012-trainval-val/epe,kitti-2012-trainval-val/px1,kitti-2012-trainval-val/px3,kitti-2012-trainval-val/px5,kitti-2012-trainval-val/outlier,kitti-2012-trainval-val/conf_f1,kitti-2015-trainval-val/epe,kitti-2015-trainval-val/px1,kitti-2015-trainval-val/px3,kitti-2015-trainval-val/px5,kitti-2015-trainval-val/outlier,kitti-2015-trainval-val/conf_f1 +craft,things,2.87,0.844,0.912,0.932,0.084,,,,,,,,,,,,,1.238,0.902,0.957,0.97,0.04,,,,,,,,,,,,,2.087,0.701,0.891,0.932,0.094,,4.926,0.626,0.813,0.86,0.169, +craft,sintel,1.399,0.883,0.943,0.961,0.052,,,,,,,,,,,,,0.706,0.925,0.969,0.979,0.028,,,,,,,,,,,,,1.283,0.806,0.939,0.965,0.045,,1.592,0.764,0.918,0.949,0.055, +craft,kitti,5.334,0.762,0.861,0.888,0.137,,,,,,,,,,,,,3.691,0.81,0.901,0.923,0.097,,,,,,,,,,,,,1.312,0.815,0.942,0.964,0.043,,0.784,0.853,0.956,0.977,0.022, +csflow,chairs,4.229,0.735,0.885,0.913,0.112,10.192,3.113,0.408,0.76,0.659,0.909,0.736,0.934,0.33,0.088,,,2.157,0.812,0.93,0.952,0.067,6.771,1.315,0.481,0.84,0.724,0.955,0.798,0.972,0.262,0.043,,,4.344,0.323,0.721,0.829,0.267,,9.251,0.364,0.644,0.735,0.346, +csflow,things,2.817,0.847,0.914,0.934,0.082,7.277,1.968,0.527,0.875,0.728,0.936,0.792,0.952,0.26,0.061,,,1.404,0.902,0.957,0.969,0.041,4.985,0.702,0.598,0.933,0.794,0.977,0.852,0.985,0.193,0.021,,,2.153,0.708,0.893,0.932,0.093,,5.24,0.633,0.813,0.858,0.17, +csflow,kitti,2.598,0.861,0.925,0.943,0.071,6.958,1.8,0.544,0.889,0.738,0.946,0.801,0.961,0.248,0.051,,,1.122,0.909,0.96,0.972,0.037,4.037,0.599,0.613,0.939,0.801,0.98,0.859,0.987,0.184,0.019,,,1.241,0.817,0.943,0.966,0.042,,0.928,0.841,0.95,0.972,0.029, +dicl,chairs,4.947,0.751,0.865,0.891,0.132,11.704,3.699,0.372,0.78,0.616,0.891,0.697,0.915,0.373,0.107,,,3.661,0.82,0.909,0.927,0.089,10.274,2.389,0.427,0.853,0.67,0.936,0.742,0.951,0.319,0.063,,,5.968,0.44,0.712,0.793,0.281,,18.445,0.376,0.541,0.602,0.456, +dicl,kitti,9.598,0.634,0.765,0.804,0.234,16.165,8.47,0.299,0.653,0.51,0.785,0.6,0.822,0.482,0.214,,,7.798,0.69,0.804,0.838,0.195,14.635,6.558,0.318,0.714,0.54,0.827,0.63,0.859,0.451,0.172,,,1.161,0.829,0.938,0.961,0.046,,1.393,0.779,0.913,0.946,0.058, +dicl,sintel,2.05,0.856,0.922,0.941,0.075,6.031,1.397,0.496,0.886,0.719,0.945,0.793,0.96,0.268,0.053,,,1.297,0.902,0.951,0.964,0.046,4.894,0.678,0.546,0.936,0.76,0.975,0.828,0.982,0.226,0.024,,,2.161,0.759,0.904,0.938,0.081,,5.656,0.663,0.816,0.859,0.164, +dicl,things,3.831,0.811,0.889,0.913,0.107,9.903,2.709,0.418,0.843,0.648,0.915,0.727,0.935,0.341,0.082,,,2.006,0.878,0.939,0.954,0.059,6.691,1.121,0.488,0.914,0.73,0.964,0.801,0.974,0.258,0.034,,,3.745,0.695,0.849,0.889,0.14,,9.906,0.599,0.75,0.791,0.24, +fastflownet,chairs,4.38,0.734,0.857,0.891,0.14,10.96,3.162,0.323,0.764,0.571,0.886,0.668,0.917,0.419,0.111,,,3.252,0.79,0.896,0.922,0.101,9.719,2.011,0.354,0.826,0.62,0.928,0.713,0.949,0.369,0.07,,,5.76,0.511,0.757,0.824,0.236,,14.374,0.428,0.628,0.695,0.368, +fastflownet,kitti,5.147,0.739,0.845,0.876,0.152,12.287,3.903,0.329,0.768,0.549,0.874,0.637,0.902,0.44,0.124,,,3.952,0.788,0.882,0.906,0.116,10.859,2.708,0.356,0.823,0.586,0.913,0.674,0.934,0.403,0.085,,,1.534,0.775,0.918,0.948,0.065,,2.198,0.723,0.885,0.924,0.087, +fastflownet,mix,2.754,0.797,0.891,0.918,0.105,7.814,1.922,0.381,0.831,0.631,0.919,0.722,0.942,0.356,0.077,,,2.641,0.837,0.918,0.938,0.08,8.289,1.617,0.413,0.875,0.665,0.948,0.748,0.963,0.323,0.051,,,2.602,0.697,0.88,0.92,0.105,,5.51,0.606,0.793,0.845,0.187, +fastflownet,sintel,2.79,0.796,0.891,0.918,0.105,7.856,1.952,0.381,0.83,0.632,0.919,0.722,0.942,0.356,0.078,,,2.729,0.836,0.917,0.937,0.081,8.557,1.66,0.413,0.874,0.664,0.947,0.746,0.962,0.325,0.052,,,5.227,0.616,0.815,0.861,0.179,,13.731,0.507,0.682,0.731,0.314, +fastflownet,things,4.27,0.739,0.86,0.893,0.136,10.553,3.116,0.317,0.77,0.573,0.889,0.671,0.919,0.416,0.108,,,2.931,0.796,0.901,0.927,0.096,8.86,1.811,0.351,0.833,0.626,0.932,0.722,0.953,0.362,0.066,,,5.528,0.557,0.775,0.833,0.218,,13.134,0.465,0.657,0.718,0.338, +flowformer,chairs,4.733,0.745,0.873,0.903,0.124,11.164,3.521,0.406,0.771,0.642,0.898,0.719,0.925,0.349,0.099,,,3.231,0.803,0.922,0.945,0.075,8.381,2.206,0.461,0.834,0.711,0.948,0.789,0.967,0.276,0.05,,,4.991,0.356,0.709,0.814,0.283,,11.386,0.362,0.62,0.713,0.373, +flowformer,things,2.797,0.838,0.914,0.935,0.082,7.098,1.998,0.522,0.865,0.731,0.936,0.796,0.953,0.256,0.062,,,1.237,0.896,0.959,0.971,0.039,4.147,0.721,0.6,0.926,0.806,0.978,0.864,0.986,0.181,0.02,,,3.239,0.694,0.893,0.931,0.092,,5.643,0.607,0.809,0.859,0.172, +flowformer,sintel,1.741,0.885,0.945,0.961,0.051,4.75,1.205,0.619,0.911,0.806,0.962,0.863,0.974,0.18,0.035,,,0.758,0.929,0.972,0.981,0.026,2.956,0.38,0.681,0.955,0.852,0.988,0.899,0.993,0.133,0.011,,,2.318,0.782,0.932,0.958,0.054,,1.747,0.761,0.916,0.949,0.054, +flowformer,kitti,5.191,0.729,0.849,0.881,0.149,11.171,4.166,0.383,0.754,0.601,0.872,0.683,0.902,0.389,0.126,,,4.063,0.789,0.89,0.916,0.108,9.626,3.096,0.435,0.817,0.654,0.915,0.733,0.937,0.336,0.084,,,2.374,0.793,0.924,0.948,0.064,,1.637,0.831,0.935,0.957,0.042, +flownet2,things,3.97,0.772,0.872,0.901,0.124,9.338,2.978,0.413,0.799,0.633,0.896,0.717,0.922,0.355,0.101,,,3.005,0.821,0.907,0.929,0.09,7.867,2.101,0.455,0.851,0.685,0.931,0.764,0.95,0.304,0.066,,,5.468,0.599,0.775,0.823,0.219,,13.143,0.499,0.65,0.699,0.346, +flownetc,things,5.647,0.466,0.76,0.832,0.238,11.412,4.579,0.189,0.481,0.491,0.782,0.614,0.854,0.501,0.216,,,4.545,0.479,0.797,0.865,0.2,9.915,3.557,0.195,0.494,0.526,0.822,0.658,0.889,0.466,0.175,,,8.104,0.198,0.57,0.689,0.426,,15.985,0.179,0.462,0.569,0.536, +flownetcs,things,4.164,0.735,0.857,0.892,0.14,9.498,3.21,0.337,0.763,0.6,0.882,0.699,0.915,0.389,0.115,,,3.131,0.787,0.893,0.921,0.103,8.089,2.226,0.375,0.818,0.652,0.919,0.747,0.943,0.336,0.078,,,5.028,0.558,0.774,0.831,0.218,,12.517,0.452,0.639,0.699,0.356, +flownetcss,things,4.021,0.754,0.864,0.898,0.132,9.299,3.061,0.366,0.782,0.617,0.889,0.711,0.919,0.372,0.108,,,2.989,0.809,0.901,0.925,0.096,7.863,2.096,0.412,0.84,0.671,0.926,0.758,0.946,0.318,0.072,,,4.657,0.601,0.79,0.841,0.202,,11.901,0.491,0.657,0.711,0.338, +flownets,things,5.244,0.567,0.793,0.851,0.205,11.214,4.153,0.23,0.584,0.517,0.816,0.633,0.874,0.475,0.182,,,3.958,0.588,0.82,0.876,0.178,9.464,2.943,0.242,0.607,0.546,0.845,0.665,0.9,0.446,0.152,,,7.723,0.212,0.561,0.683,0.435,,14.812,0.236,0.479,0.58,0.519, +flownetsd,things,7.821,0.656,0.781,0.821,0.219,14.408,6.687,0.313,0.674,0.523,0.802,0.611,0.841,0.47,0.198,,,7.616,0.687,0.8,0.836,0.2,13.937,6.483,0.333,0.707,0.552,0.821,0.642,0.856,0.441,0.179,,,17.258,0.207,0.414,0.495,0.586,,24.21,0.27,0.424,0.489,0.576, +gma,chairs,4.135,0.721,0.886,0.914,0.11,10.129,3.002,0.381,0.747,0.657,0.912,0.734,0.936,0.331,0.085,,,2.371,0.8,0.927,0.949,0.069,7.022,1.529,0.456,0.828,0.712,0.954,0.791,0.97,0.273,0.044,,,4.6,0.364,0.722,0.821,0.268,,9.983,0.38,0.639,0.725,0.352, +gma,things,2.867,0.846,0.912,0.932,0.084,7.397,2.031,0.529,0.874,0.728,0.934,0.79,0.951,0.26,0.063,,,1.41,0.902,0.956,0.969,0.041,4.679,0.808,0.599,0.933,0.795,0.977,0.853,0.985,0.191,0.022,,,2.07,0.703,0.889,0.93,0.097,,4.754,0.63,0.815,0.862,0.166, +gma,sintel,1.384,0.885,0.944,0.961,0.052,4.108,0.91,0.606,0.91,0.794,0.962,0.852,0.975,0.191,0.035,,,0.723,0.925,0.969,0.979,0.028,2.889,0.363,0.659,0.953,0.835,0.986,0.886,0.992,0.15,0.012,,,1.313,0.803,0.938,0.964,0.045,,1.549,0.769,0.92,0.951,0.053, +gma,kitti,6.577,0.762,0.858,0.885,0.14,13.73,5.215,0.4,0.79,0.6,0.884,0.677,0.908,0.389,0.115,,,4.75,0.809,0.896,0.917,0.102,11.076,3.559,0.438,0.841,0.646,0.923,0.721,0.941,0.342,0.076,,,1.526,0.812,0.937,0.96,0.049,,0.767,0.859,0.958,0.977,0.022, +gmflow,chairs,4.836,0.482,0.795,0.863,0.202,10.564,3.754,0.255,0.497,0.556,0.818,0.666,0.885,0.435,0.179,,,3.551,0.492,0.82,0.892,0.176,8.718,2.599,0.263,0.508,0.579,0.845,0.697,0.916,0.409,0.15,,,9.01,0.16,0.518,0.669,0.479,,16.541,0.226,0.469,0.576,0.529, +gmflow,things,3.055,0.786,0.9,0.927,0.097,7.333,2.29,0.422,0.813,0.689,0.923,0.77,0.947,0.3,0.073,,,1.564,0.833,0.938,0.959,0.058,4.899,0.993,0.479,0.863,0.752,0.962,0.827,0.977,0.234,0.036,,,5.917,0.432,0.712,0.786,0.281,,11.369,0.404,0.639,0.712,0.351, +gmflow,sintel,1.587,0.842,0.931,0.954,0.064,4.294,1.123,0.527,0.868,0.765,0.952,0.834,0.97,0.221,0.044,,,1.063,0.879,0.956,0.972,0.04,3.268,0.687,0.566,0.908,0.801,0.977,0.864,0.987,0.183,0.02,,,2.417,0.664,0.872,0.916,0.115,,3.248,0.63,0.831,0.88,0.147, +gmflow,kitti,4.059,0.684,0.847,0.887,0.149,9.882,3.045,0.321,0.708,0.571,0.875,0.668,0.913,0.418,0.121,,,2.948,0.716,0.879,0.917,0.116,8.353,1.991,0.342,0.744,0.6,0.911,0.698,0.946,0.388,0.084,,,2.1,0.705,0.893,0.932,0.093,,2.373,0.749,0.883,0.915,0.096, +gmflow_refine,chairs,4.789,0.466,0.801,0.868,0.195,10.804,3.719,0.175,0.482,0.486,0.83,0.631,0.893,0.504,0.166,,,3.613,0.475,0.828,0.896,0.167,9.262,2.615,0.179,0.493,0.502,0.86,0.657,0.924,0.486,0.136,,,8.978,0.155,0.526,0.679,0.471,,17.372,0.212,0.465,0.571,0.531, +gmflow_refine,things,2.962,0.753,0.884,0.919,0.112,7.808,2.173,0.273,0.786,0.58,0.913,0.717,0.942,0.407,0.083,,,1.73,0.797,0.919,0.947,0.077,5.696,1.089,0.296,0.836,0.633,0.949,0.772,0.969,0.353,0.048,,,4.566,0.483,0.763,0.831,0.228,,9.324,0.444,0.674,0.744,0.314, +gmflow_refine,sintel,1.806,0.802,0.91,0.939,0.085,5.208,1.296,0.346,0.836,0.668,0.936,0.775,0.958,0.317,0.06,,,1.368,0.838,0.933,0.955,0.063,4.489,0.903,0.364,0.878,0.695,0.959,0.8,0.974,0.288,0.039,,,2.525,0.499,0.818,0.906,0.16,,3.368,0.48,0.758,0.851,0.209, +gmflow_refine,kitti,4.28,0.647,0.841,0.882,0.155,10.483,3.216,0.241,0.673,0.524,0.872,0.642,0.91,0.465,0.125,,,2.929,0.705,0.886,0.92,0.11,8.647,1.934,0.263,0.738,0.559,0.922,0.678,0.951,0.428,0.074,,,2.413,0.603,0.854,0.905,0.129,,2.926,0.59,0.82,0.88,0.147, +gmflownet,things,2.828,0.85,0.916,0.935,0.081,7.379,1.977,0.529,0.879,0.726,0.937,0.789,0.954,0.261,0.06,,,1.341,0.906,0.958,0.97,0.04,4.61,0.716,0.609,0.936,0.801,0.978,0.858,0.986,0.185,0.021,,,2.042,0.722,0.898,0.935,0.087,,4.618,0.647,0.823,0.868,0.159, +gmflownet,kitti,6.4,0.748,0.85,0.879,0.148,13.448,5.127,0.377,0.774,0.59,0.875,0.67,0.9,0.4,0.124,,,4.616,0.799,0.895,0.918,0.103,11.61,3.265,0.415,0.83,0.635,0.923,0.712,0.942,0.353,0.076,,,1.185,0.818,0.944,0.967,0.04,,0.788,0.854,0.957,0.977,0.024, +gmflownet_mix,things,2.766,0.853,0.918,0.937,0.079,7.14,1.948,0.535,0.881,0.732,0.939,0.794,0.955,0.256,0.058,,,1.177,0.91,0.961,0.972,0.037,4.342,0.57,0.619,0.94,0.809,0.981,0.863,0.988,0.177,0.018,,,2.06,0.708,0.895,0.935,0.09,,4.91,0.623,0.812,0.861,0.169, +gmflownet_mix,sintel,1.523,0.888,0.944,0.96,0.052,4.262,1.045,0.613,0.913,0.795,0.962,0.852,0.974,0.19,0.035,,,0.732,0.93,0.97,0.98,0.027,2.871,0.354,0.67,0.958,0.842,0.987,0.891,0.992,0.143,0.012,,,1.309,0.81,0.939,0.963,0.045,,1.508,0.771,0.917,0.949,0.054, +hd3,chairs,9.722,0.605,0.834,0.857,0.163,18.102,8.202,0.281,0.628,0.558,0.863,0.637,0.882,0.432,0.135,,,4.918,0.768,0.902,0.918,0.096,11.97,3.57,0.341,0.803,0.631,0.934,0.713,0.945,0.358,0.065,,,12.219,0.312,0.574,0.635,0.423,,21.716,0.268,0.483,0.533,0.515, +hd3,kitti,44.813,0.591,0.698,0.72,0.301,57.818,42.809,0.278,0.607,0.47,0.715,0.537,0.736,0.522,0.284,,,37.062,0.64,0.732,0.75,0.267,49.044,35.034,0.296,0.662,0.501,0.752,0.568,0.768,0.49,0.247,,,1.262,0.827,0.936,0.959,0.048,,1.943,0.726,0.902,0.937,0.067, +hd3,sintel,1.603,0.86,0.931,0.951,0.065,5.007,1.033,0.488,0.894,0.732,0.956,0.811,0.97,0.254,0.041,,,2.311,0.892,0.941,0.953,0.058,7.353,1.299,0.508,0.93,0.737,0.968,0.804,0.975,0.253,0.032,,,6.301,0.647,0.814,0.849,0.181,,15.291,0.566,0.707,0.739,0.29, +hd3,things,6.501,0.779,0.86,0.883,0.137,15.18,4.838,0.355,0.812,0.576,0.89,0.654,0.911,0.414,0.107,,,3.214,0.869,0.928,0.942,0.07,9.483,1.948,0.452,0.91,0.697,0.958,0.771,0.967,0.29,0.041,,,6.854,0.626,0.791,0.83,0.203,,14.505,0.573,0.721,0.757,0.272, +hd3_ctxt,chairs,5.792,0.627,0.859,0.884,0.139,12.77,4.474,0.297,0.648,0.603,0.886,0.685,0.907,0.387,0.113,,,3.727,0.753,0.912,0.929,0.086,10.167,2.463,0.365,0.782,0.664,0.94,0.746,0.953,0.324,0.059,,,13.695,0.202,0.493,0.563,0.505,,22.971,0.226,0.429,0.482,0.57, +hd3_ctxt,kitti,7.858,0.708,0.828,0.86,0.169,15.296,6.503,0.354,0.732,0.569,0.853,0.65,0.882,0.421,0.145,,,6.016,0.769,0.872,0.895,0.126,13.042,4.715,0.388,0.798,0.612,0.899,0.689,0.918,0.377,0.099,,,0.999,0.855,0.947,0.966,0.038,,1.536,0.756,0.922,0.951,0.052, +hd3_ctxt,sintel,1.737,0.868,0.934,0.952,0.062,5.236,1.148,0.512,0.899,0.744,0.956,0.818,0.97,0.242,0.04,,,2.103,0.896,0.945,0.957,0.054,6.732,1.183,0.533,0.931,0.755,0.969,0.821,0.976,0.234,0.03,,,5.058,0.648,0.85,0.889,0.141,,13.45,0.557,0.725,0.766,0.269, +hd3_ctxt,things,4.421,0.806,0.882,0.903,0.116,11.506,3.071,0.406,0.839,0.622,0.909,0.697,0.928,0.368,0.088,,,2.072,0.884,0.94,0.953,0.058,6.81,1.162,0.51,0.92,0.741,0.965,0.807,0.974,0.247,0.034,,,4.645,0.658,0.83,0.868,0.159,,9.959,0.591,0.75,0.789,0.24, +irr_pwc,chairs_occ,3.947,0.802,0.883,0.909,0.113,10.004,2.828,0.411,0.832,0.63,0.909,0.71,0.931,0.358,0.088,0.707,,2.315,0.856,0.926,0.944,0.071,7.519,1.323,0.458,0.89,0.693,0.953,0.771,0.966,0.295,0.046,0.739,,3.887,0.566,0.81,0.87,0.18,,10.672,0.491,0.688,0.751,0.304, +irr_pwc,kitti,8.171,0.716,0.82,0.849,0.179,15.455,6.733,0.373,0.739,0.582,0.841,0.662,0.868,0.408,0.157,0.687,,7.447,0.765,0.856,0.88,0.143,15.721,5.6,0.398,0.792,0.614,0.88,0.69,0.901,0.376,0.119,0.711,,1.128,0.844,0.946,0.966,0.039,,1.522,0.798,0.923,0.95,0.052, +irr_pwc,sintel,2.443,0.849,0.916,0.937,0.08,6.666,1.715,0.502,0.878,0.718,0.937,0.789,0.955,0.269,0.059,0.762,,1.85,0.89,0.944,0.957,0.054,6.194,1.057,0.537,0.923,0.748,0.967,0.814,0.975,0.239,0.032,0.773,,2.581,0.738,0.895,0.928,0.091,,7.968,0.625,0.784,0.823,0.202, +irr_pwc,things,3.404,0.81,0.893,0.918,0.104,8.566,2.46,0.435,0.839,0.661,0.917,0.74,0.939,0.327,0.08,0.725,,1.856,0.87,0.936,0.952,0.062,6.302,1.035,0.492,0.904,0.722,0.96,0.796,0.972,0.265,0.038,0.752,,3.55,0.664,0.842,0.886,0.146,,9.508,0.556,0.726,0.777,0.264, +irr_pwcnet,things,4.411,0.77,0.868,0.897,0.129,10.869,3.211,0.358,0.802,0.597,0.896,0.687,0.922,0.392,0.102,,,3.09,0.82,0.907,0.929,0.091,9.114,1.964,0.395,0.856,0.646,0.936,0.734,0.953,0.342,0.063,,,5.967,0.538,0.774,0.834,0.219,,14.707,0.452,0.635,0.695,0.361, +irr_pwcnet_irr,things,4.052,0.755,0.867,0.9,0.129,10.112,2.968,0.357,0.784,0.599,0.894,0.691,0.924,0.39,0.103,,,2.734,0.809,0.908,0.932,0.089,8.447,1.692,0.394,0.842,0.647,0.936,0.738,0.956,0.34,0.061,,,5.145,0.509,0.76,0.83,0.232,,12.983,0.441,0.636,0.7,0.358, +lcv_raft,chairs,4.034,0.753,0.882,0.911,0.115,9.937,2.93,0.409,0.779,0.655,0.907,0.732,0.933,0.334,0.09,,,2.282,0.81,0.924,0.947,0.073,7.009,1.388,0.469,0.839,0.714,0.949,0.79,0.968,0.272,0.048,,,4.371,0.36,0.713,0.824,0.277,,9.213,0.359,0.625,0.723,0.367, +lcv_raft,things,2.981,0.828,0.909,0.931,0.087,7.792,2.099,0.48,0.858,0.702,0.932,0.775,0.951,0.285,0.065,,,1.761,0.874,0.945,0.961,0.052,5.824,0.975,0.531,0.907,0.757,0.969,0.826,0.98,0.229,0.03,,,2.509,0.678,0.874,0.917,0.114,,6.13,0.595,0.789,0.837,0.197, +liteflownet,kitti,5.917,0.738,0.842,0.872,0.156,13.327,4.57,0.331,0.767,0.551,0.869,0.639,0.897,0.439,0.129,,,4.553,0.796,0.88,0.903,0.118,11.64,3.234,0.356,0.833,0.59,0.911,0.677,0.93,0.399,0.088,,,1.164,0.829,0.942,0.965,0.041,,1.783,0.779,0.907,0.938,0.065, +liteflownet,sintel,1.845,0.856,0.923,0.943,0.073,5.549,1.229,0.492,0.889,0.724,0.947,0.799,0.963,0.261,0.049,,,1.419,0.905,0.948,0.961,0.049,5.239,0.756,0.535,0.944,0.754,0.974,0.821,0.98,0.231,0.025,,,3.661,0.736,0.866,0.9,0.125,,10.354,0.631,0.758,0.794,0.235, +liteflownet,things,4.024,0.786,0.879,0.906,0.118,10.306,2.84,0.37,0.819,0.619,0.907,0.705,0.93,0.37,0.09,,,2.504,0.851,0.924,0.942,0.074,7.918,1.45,0.417,0.891,0.682,0.954,0.764,0.967,0.306,0.045,,,4.532,0.627,0.802,0.849,0.191,,11.477,0.522,0.683,0.734,0.311, +liteflownet2,sintel,1.92,0.844,0.919,0.942,0.077,5.771,1.274,0.459,0.877,0.707,0.943,0.791,0.961,0.278,0.054,,,1.495,0.887,0.944,0.959,0.053,5.298,0.84,0.495,0.925,0.737,0.969,0.813,0.978,0.249,0.03,,,1.648,0.78,0.921,0.95,0.062,,3.233,0.71,0.873,0.911,0.1, +liteflownet2_pseudoreg,kitti,5.668,0.743,0.842,0.872,0.155,12.761,4.368,0.341,0.772,0.562,0.87,0.646,0.896,0.428,0.129,,,4.509,0.797,0.881,0.902,0.118,11.509,3.194,0.37,0.832,0.603,0.91,0.684,0.929,0.387,0.089,,,1.025,0.844,0.95,0.97,0.035,,1.389,0.797,0.927,0.955,0.046, +liteflownet3,sintel,1.884,0.847,0.923,0.944,0.073,5.618,1.24,0.469,0.879,0.716,0.947,0.798,0.963,0.27,0.05,,0.679,1.41,0.889,0.947,0.961,0.051,5.047,0.784,0.503,0.925,0.745,0.971,0.821,0.98,0.24,0.027,,0.662,1.674,0.766,0.915,0.948,0.068,0.535,3.41,0.696,0.859,0.902,0.114,0.513 +liteflownet3_pseudoreg,kitti,5.577,0.745,0.846,0.874,0.152,12.622,4.325,0.339,0.774,0.564,0.873,0.646,0.899,0.426,0.125,,0.619,4.398,0.801,0.885,0.906,0.113,11.293,3.123,0.376,0.837,0.606,0.915,0.686,0.932,0.383,0.084,,0.631,1.021,0.846,0.948,0.969,0.036,0.552,1.421,0.794,0.925,0.954,0.048,0.523 +liteflownet3s,sintel,2.019,0.844,0.921,0.942,0.075,6.02,1.336,0.461,0.876,0.709,0.944,0.791,0.962,0.277,0.052,,0.689,1.529,0.887,0.945,0.959,0.053,5.442,0.859,0.496,0.924,0.738,0.969,0.814,0.979,0.248,0.029,,0.685,1.788,0.755,0.91,0.944,0.073,0.506,3.686,0.686,0.852,0.894,0.124,0.497 +liteflownet3s_pseudoreg,kitti,5.482,0.747,0.85,0.879,0.148,12.689,4.177,0.347,0.777,0.57,0.878,0.654,0.904,0.419,0.12,,0.644,4.214,0.808,0.89,0.91,0.108,11.004,2.953,0.388,0.844,0.614,0.92,0.694,0.936,0.375,0.079,,0.654,1.04,0.84,0.948,0.969,0.036,0.536,1.538,0.784,0.918,0.949,0.053,0.512 +maskflownet,kitti,5.907,0.759,0.853,0.88,0.144,12.561,4.661,0.401,0.786,0.609,0.877,0.685,0.902,0.381,0.121,,,4.364,0.814,0.89,0.91,0.108,10.962,3.138,0.437,0.845,0.642,0.916,0.715,0.934,0.347,0.082,,,1.334,0.771,0.937,0.963,0.045,,2.845,0.661,0.858,0.908,0.106, +maskflownet,sintel,2.692,0.857,0.917,0.935,0.08,7.273,1.883,0.528,0.886,0.715,0.939,0.779,0.954,0.273,0.058,,,1.786,0.901,0.945,0.957,0.053,6.231,1.003,0.571,0.932,0.749,0.967,0.807,0.976,0.239,0.032,,,1.703,0.724,0.921,0.952,0.06,,3.126,0.638,0.858,0.909,0.101, +maskflownet_s,sintel,2.807,0.841,0.911,0.931,0.086,7.509,1.984,0.495,0.87,0.699,0.934,0.768,0.951,0.289,0.063,0.63,,1.938,0.881,0.938,0.953,0.059,6.525,1.128,0.531,0.913,0.731,0.962,0.796,0.973,0.256,0.036,0.633,,1.861,0.69,0.91,0.946,0.07,,3.457,0.61,0.843,0.899,0.119, +maskflownet_s,things,4.3,0.747,0.863,0.895,0.133,10.385,3.164,0.371,0.775,0.607,0.89,0.693,0.919,0.382,0.107,0.447,,3.002,0.81,0.904,0.927,0.094,8.677,1.929,0.412,0.844,0.653,0.932,0.738,0.951,0.334,0.066,0.457,,4.735,0.505,0.763,0.831,0.229,,11.388,0.438,0.644,0.713,0.349, +pwcdcnet,sintel,2.328,0.835,0.912,0.935,0.085,6.634,1.601,0.471,0.865,0.697,0.935,0.773,0.954,0.29,0.062,,,1.808,0.872,0.936,0.952,0.061,6.058,1.097,0.505,0.905,0.727,0.96,0.798,0.972,0.26,0.039,,,2.072,0.753,0.908,0.941,0.075,,3.159,0.695,0.867,0.909,0.104, +pwcdcnet,things,4.213,0.746,0.868,0.901,0.128,9.972,3.08,0.335,0.776,0.597,0.896,0.698,0.925,0.391,0.101,,,2.676,0.796,0.906,0.933,0.091,7.877,1.701,0.367,0.831,0.647,0.935,0.747,0.956,0.34,0.063,,,4.582,0.528,0.774,0.841,0.217,,10.994,0.442,0.652,0.721,0.341, +pwcnet,sintel,2.82,0.759,0.882,0.916,0.115,7.799,1.982,0.316,0.793,0.604,0.911,0.713,0.94,0.384,0.086,,,2.255,0.8,0.912,0.938,0.085,7.256,1.405,0.334,0.839,0.638,0.944,0.744,0.963,0.349,0.054,,,3.327,0.628,0.836,0.89,0.152,,6.249,0.556,0.768,0.828,0.217, +pwcnet,things,4.827,0.672,0.841,0.884,0.156,11.551,3.499,0.256,0.7,0.513,0.873,0.634,0.913,0.477,0.124,,,3.358,0.721,0.88,0.916,0.117,9.625,2.172,0.27,0.755,0.55,0.915,0.679,0.946,0.439,0.082,,,5.551,0.496,0.741,0.81,0.252,,12.669,0.406,0.626,0.699,0.367, +raft,chairs,4.296,0.732,0.881,0.909,0.116,10.028,3.216,0.406,0.756,0.657,0.905,0.736,0.93,0.332,0.092,,,2.197,0.809,0.927,0.949,0.07,6.785,1.344,0.478,0.838,0.721,0.952,0.796,0.97,0.264,0.046,,,4.594,0.297,0.695,0.816,0.294,,9.749,0.346,0.631,0.725,0.36, +raft,kitti,6.293,0.755,0.853,0.882,0.145,13.709,4.842,0.401,0.782,0.609,0.877,0.684,0.903,0.381,0.121,,,4.687,0.802,0.894,0.917,0.104,10.993,3.46,0.43,0.833,0.65,0.92,0.725,0.94,0.339,0.079,,,1.271,0.814,0.941,0.965,0.044,,0.779,0.858,0.958,0.978,0.023, +raft,sintel,1.571,0.879,0.939,0.957,0.057,4.648,1.041,0.589,0.905,0.782,0.958,0.842,0.972,0.204,0.039,,,0.871,0.92,0.966,0.977,0.031,3.274,0.439,0.643,0.948,0.825,0.984,0.878,0.99,0.161,0.015,,,1.341,0.799,0.936,0.963,0.047,,1.631,0.757,0.913,0.947,0.058, +raft,things,3.009,0.847,0.913,0.933,0.083,7.627,2.124,0.524,0.875,0.725,0.935,0.79,0.951,0.263,0.062,,,1.507,0.899,0.955,0.967,0.043,5.122,0.794,0.587,0.93,0.786,0.976,0.846,0.984,0.201,0.023,,,2.261,0.709,0.889,0.928,0.098,,5.468,0.633,0.81,0.853,0.175, +raft_small,things,3.548,0.8,0.889,0.915,0.108,9.014,2.566,0.399,0.831,0.64,0.915,0.729,0.937,0.348,0.082,,,2.19,0.845,0.927,0.947,0.07,6.992,1.311,0.443,0.88,0.692,0.954,0.779,0.969,0.294,0.044,,,3.618,0.603,0.832,0.882,0.158,,8.636,0.522,0.729,0.786,0.261, +scopeflow,chairs,3.95,0.812,0.886,0.909,0.11,9.882,2.884,0.415,0.843,0.634,0.913,0.714,0.931,0.355,0.085,0.711,,2.569,0.861,0.925,0.941,0.073,7.896,1.561,0.46,0.897,0.689,0.952,0.765,0.964,0.3,0.047,0.737,,4.094,0.596,0.817,0.875,0.173,,11.975,0.491,0.678,0.739,0.317, +scopeflow,kitti,10.458,0.705,0.811,0.841,0.187,16.912,9.423,0.361,0.727,0.572,0.832,0.651,0.86,0.418,0.167,0,,8.111,0.753,0.847,0.872,0.152,15.348,6.695,0.389,0.779,0.606,0.87,0.681,0.893,0.385,0.129,0,,1.002,0.852,0.95,0.969,0.036,,1.337,0.81,0.933,0.958,0.045, +scopeflow,sintel,2.461,0.85,0.917,0.938,0.079,6.752,1.721,0.507,0.879,0.718,0.939,0.789,0.956,0.268,0.057,0.762,,1.63,0.895,0.947,0.96,0.051,5.769,0.885,0.544,0.927,0.753,0.969,0.818,0.978,0.233,0.03,0.78,,2.076,0.749,0.905,0.939,0.079,,6.084,0.646,0.813,0.853,0.168, +scopeflow,things,3.269,0.813,0.895,0.92,0.102,8.326,2.365,0.435,0.843,0.662,0.919,0.742,0.941,0.326,0.078,0.726,,1.832,0.872,0.936,0.953,0.061,6.235,1.024,0.491,0.906,0.722,0.961,0.797,0.973,0.265,0.037,0.754,,3.465,0.671,0.849,0.892,0.139,,9.784,0.556,0.725,0.775,0.266, +scv4,chairs,4.897,0.735,0.887,0.91,0.11,11.262,3.664,0.421,0.759,0.668,0.911,0.738,0.93,0.321,0.086,,,2.201,0.839,0.939,0.955,0.058,7.053,1.275,0.509,0.869,0.743,0.964,0.809,0.975,0.243,0.034,,,6.25,0.355,0.691,0.785,0.303,,13.192,0.391,0.63,0.698,0.365, +scv4,kitti,7.102,0.751,0.837,0.862,0.162,14.108,5.763,0.409,0.777,0.601,0.859,0.674,0.881,0.389,0.14,,,6.617,0.803,0.875,0.893,0.123,13.3,5.282,0.449,0.834,0.643,0.9,0.711,0.914,0.346,0.099,,,3.127,0.824,0.932,0.952,0.055,,2.498,0.824,0.93,0.953,0.046, +scv4,sintel,2.646,0.878,0.93,0.945,0.066,6.905,1.808,0.596,0.905,0.773,0.95,0.824,0.962,0.215,0.047,,,1.529,0.934,0.966,0.974,0.032,5.395,0.727,0.669,0.963,0.83,0.985,0.874,0.989,0.157,0.015,,,2.527,0.817,0.931,0.954,0.055,,3.591,0.764,0.902,0.931,0.072, +scv4,things,3.849,0.826,0.892,0.915,0.105,9.543,2.728,0.508,0.854,0.706,0.915,0.765,0.934,0.283,0.083,,,1.796,0.888,0.942,0.959,0.056,6.172,0.944,0.594,0.92,0.788,0.964,0.843,0.976,0.199,0.036,,,4.232,0.711,0.85,0.885,0.143,,9.847,0.624,0.755,0.791,0.237, +starflow,kitti,7.231,0.671,0.812,0.849,0.187,14.124,5.966,0.32,0.693,0.556,0.835,0.643,0.871,0.435,0.163,0.566,,5.011,0.73,0.852,0.884,0.145,11.722,3.743,0.349,0.757,0.594,0.879,0.68,0.908,0.396,0.119,0.564,,1.712,0.757,0.919,0.949,0.064,,2.818,0.662,0.855,0.9,0.116, +starflow,sintel,2.022,0.838,0.916,0.939,0.08,5.965,1.375,0.48,0.867,0.712,0.938,0.79,0.957,0.274,0.058,0.729,,1.6,0.881,0.943,0.958,0.055,5.565,0.893,0.515,0.915,0.745,0.966,0.816,0.977,0.243,0.032,0.749,,3.545,0.689,0.86,0.901,0.128,,7.819,0.59,0.763,0.811,0.223, +starflow,things,3.584,0.796,0.891,0.918,0.105,8.654,2.665,0.423,0.825,0.663,0.915,0.744,0.938,0.324,0.081,0.703,,1.822,0.862,0.936,0.953,0.062,6.033,1.031,0.479,0.896,0.726,0.96,0.802,0.973,0.261,0.038,0.735,,4.145,0.635,0.823,0.869,0.166,,9.74,0.535,0.713,0.766,0.276, +vcn,chairs,3.96,0.751,0.873,0.905,0.123,10.07,2.862,0.36,0.78,0.608,0.901,0.701,0.93,0.38,0.096,,,2.801,0.805,0.91,0.934,0.087,8.456,1.748,0.397,0.839,0.657,0.939,0.745,0.958,0.331,0.059,,,4.444,0.518,0.793,0.86,0.197,,10.773,0.436,0.659,0.732,0.333, +vcn,kitti,9.196,0.691,0.791,0.819,0.207,15.38,8.035,0.335,0.715,0.548,0.813,0.63,0.839,0.442,0.186,,,5.934,0.769,0.851,0.874,0.146,12.216,4.805,0.371,0.799,0.594,0.876,0.676,0.896,0.395,0.122,,,1.141,0.842,0.944,0.965,0.041,,1.455,0.786,0.921,0.951,0.053, +vcn,sintel,2.251,0.823,0.908,0.933,0.088,6.463,1.571,0.405,0.857,0.67,0.934,0.764,0.954,0.316,0.063,,,1.595,0.867,0.936,0.954,0.061,5.63,0.918,0.435,0.906,0.707,0.963,0.795,0.975,0.279,0.035,,,2.209,0.714,0.883,0.923,0.1,,4.314,0.608,0.8,0.857,0.175, +vcn,things,3.966,0.761,0.87,0.901,0.126,9.651,2.969,0.36,0.791,0.61,0.897,0.703,0.924,0.379,0.1,,,2.494,0.822,0.913,0.936,0.084,7.611,1.573,0.398,0.858,0.664,0.941,0.755,0.96,0.323,0.056,,,3.466,0.647,0.84,0.89,0.149,,8.627,0.541,0.728,0.785,0.261, +vcn_small,chairs,4.297,0.683,0.848,0.891,0.149,10.369,3.214,0.303,0.709,0.566,0.876,0.671,0.916,0.423,0.121,,,3.217,0.724,0.882,0.918,0.114,9.147,2.119,0.326,0.753,0.606,0.913,0.709,0.944,0.383,0.084,,,5.656,0.334,0.682,0.795,0.309,,13.18,0.314,0.565,0.664,0.43, +vcn_small,things,5.247,0.687,0.844,0.884,0.152,10.705,4.299,0.286,0.714,0.565,0.871,0.67,0.907,0.423,0.125,,,4.8,0.728,0.868,0.902,0.128,9.585,3.931,0.306,0.758,0.604,0.896,0.71,0.926,0.383,0.101,,,4.248,0.485,0.766,0.842,0.223,,9.692,0.411,0.651,0.728,0.339, diff --git a/docs/source/results/speed_benchmark-all.csv b/docs/source/results/speed_benchmark-all.csv index 3f42ce7..bf51afe 100644 --- a/docs/source/results/speed_benchmark-all.csv +++ b/docs/source/results/speed_benchmark-all.csv @@ -1,34 +1,41 @@ Model,Params,Time(ms) -dicl,11226036,81.499 -fastflownet,1366114,20.416 -flownet2,162518834,102.711 -flownetc,39175298,46.593 -flownetcs,77870620,58.934 -flownetcss,116565942,71.323 -flownets,38676506,10.748 -flownetsd,45371666,16.331 -gma,5879873,165.717 -hd3,39561975,66.659 -hd3_ctxt,39942647,67.767 -irr_pwc,6362092,176.973 -irr_pwcnet,8639230,37.118 -irr_pwcnet_irr,3354146,40.394 -lcv_raft,5323328,130.238 -lcv_raft_small,1006674,49.608 -liteflownet,5379613,71.395 -liteflownet2,6429120,33.116 -liteflownet2_pseudoreg,6492907,37.444 -liteflownet3,7524188,55.895 -liteflownet3_pseudoreg,7587975,60.264 -liteflownet3s,8005810,57.595 -liteflownet3s_pseudoreg,8069597,62.03 -maskflownet,20655716,76.932 -maskflownet_s,10514256,42.617 -pwcnet,8243008,33.002 -pwcdcnet,9374274,38.314 -raft,5257536,129.953 -raft_small,990162,49.286 -scopeflow,6362092,182.887 -starflow,4772256,148.879 -vcn,10310781,156.231 -vcn_small,8370804,60.097 +craft,6307435,869.901 +csflow,5604672,266.595 +dicl,11226036,168.061 +fastflownet,1366114,38.734 +flowformer,16168113,969.823 +flownet2,162518834,194.865 +flownetc,39175298,81.396 +flownetcs,77870620,105.92 +flownetcss,116565942,132.075 +flownets,38676506,17.768 +flownetsd,45371666,27.348 +gma,5879873,355.223 +gmflow,4680288,198.347 +gmflow_refine,4716720,2187.918 +gmflownet,9343248,592.645 +gmflownet_mix,8687544,421.802 +hd3,39561975,127.002 +hd3_ctxt,39942647,132.77 +irr_pwc,6362092,329.515 +irr_pwcnet,8639230,62.109 +irr_pwcnet_irr,3354146,68.941 +lcv_raft,5323328,288.762 +lcv_raft_small,1006674,75.775 +liteflownet,5379613,150.792 +liteflownet2,6429120,66.144 +liteflownet2_pseudoreg,6492907,76.269 +liteflownet3,7524188,113.444 +liteflownet3_pseudoreg,7587975,124.193 +liteflownet3s,8005810,116.653 +liteflownet3s_pseudoreg,8069597,127.228 +maskflownet,20655716,139.05 +maskflownet_s,10514256,77.388 +pwcnet,8243008,57.33 +pwcdcnet,9374274,68.166 +raft,5257536,272.77 +raft_small,990162,73.929 +scopeflow,6362092,340.41 +starflow,4772256,298.348 +vcn,10310781,287.534 +vcn_small,8370804,106.829 diff --git a/docs/source/results/speed_plot.rst b/docs/source/results/speed_plot.rst index 693d3bf..8aff601 100644 --- a/docs/source/results/speed_plot.rst +++ b/docs/source/results/speed_plot.rst @@ -9,8 +9,8 @@ Inference speed vs. Trainable parameters Environment ----------- -- GPU: RTX 3060Ti +- GPU: Tesla T4 - CUDA: 11.2 -- PyTorch: 1.9.0 \ No newline at end of file +- PyTorch: 1.12.0 \ No newline at end of file diff --git a/docs/source/results/summarized_metrics-epe.csv b/docs/source/results/summarized_metrics-epe.csv index 915e3c6..178b3d9 100644 --- a/docs/source/results/summarized_metrics-epe.csv +++ b/docs/source/results/summarized_metrics-epe.csv @@ -1,39 +1,61 @@ model,checkpoint,sintel-final,sintel-clean,kitti-2012,kitti-2015 +craft,things,2.87,1.238,2.087,4.926 +craft,sintel,1.399,0.706,1.283,1.592 +craft,kitti,5.334,3.691,1.312,0.784 +csflow,chairs,4.229,2.157,4.344,9.251 +csflow,things,2.817,1.404,2.153,5.24 +csflow,kitti,2.598,1.122,1.241,0.928 dicl,chairs,4.947,3.661,5.968,18.445 dicl,kitti,9.598,7.798,1.161,1.393 dicl,sintel,2.05,1.297,2.161,5.656 dicl,things,3.831,2.006,3.745,9.906 +fastflownet,things,4.27,2.931,5.528,13.134 +fastflownet,sintel,2.79,2.729,5.227,13.731 fastflownet,chairs,4.38,3.252,5.76,14.374 fastflownet,kitti,5.147,3.952,1.534,2.198 fastflownet,mix,2.754,2.641,2.602,5.51 -fastflownet,sintel,2.79,2.729,5.227,13.731 -fastflownet,things,4.27,2.931,5.528,13.134 +flowformer,chairs,4.733,3.231,4.991,11.386 +flowformer,things,2.797,1.237,3.239,5.643 +flowformer,sintel,1.741,0.758,2.318,1.747 +flowformer,kitti,5.191,4.063,2.374,1.637 flownet2,things,3.97,3.005,5.468,13.143 flownetc,things,5.647,4.545,8.104,15.985 flownetcs,things,4.164,3.131,5.028,12.517 flownetcss,things,4.021,2.989,4.657,11.901 flownets,things,5.244,3.958,7.723,14.812 flownetsd,things,7.821,7.616,17.258,24.21 -gma,kitti,6.577,4.75,1.526,0.767 -gma,sintel,1.384,0.723,1.313,1.549 gma,chairs,4.135,2.371,4.6,9.983 gma,things,2.867,1.41,2.07,4.754 +gma,sintel,1.384,0.723,1.313,1.549 +gma,kitti,6.577,4.75,1.526,0.767 +gmflow,chairs,4.836,3.551,9.01,16.541 +gmflow,things,3.055,1.564,5.917,11.369 +gmflow,sintel,1.587,1.063,2.417,3.248 +gmflow,kitti,4.059,2.948,2.1,2.373 +gmflow_refine,kitti,4.28,2.929,2.413,2.926 +gmflow_refine,sintel,1.806,1.368,2.525,3.368 +gmflow_refine,things,2.962,1.73,4.566,9.324 +gmflow_refine,chairs,4.789,3.613,8.978,17.372 +gmflownet,things,2.828,1.341,2.042,4.618 +gmflownet,kitti,6.4,4.616,1.185,0.788 +gmflownet_mix,things,2.766,1.177,2.06,4.91 +gmflownet_mix,sintel,1.523,0.732,1.309,1.508 hd3,chairs,9.722,4.918,12.219,21.716 hd3,kitti,44.813,37.062,1.262,1.943 hd3,sintel,1.603,2.311,6.301,15.291 hd3,things,6.501,3.214,6.854,14.505 hd3_ctxt,things,4.421,2.072,4.645,9.959 -hd3_ctxt,chairs,5.792,3.727,13.695,22.971 -hd3_ctxt,kitti,7.858,6.016,0.999,1.536 hd3_ctxt,sintel,1.737,2.103,5.058,13.45 +hd3_ctxt,kitti,7.858,6.016,0.999,1.536 +hd3_ctxt,chairs,5.792,3.727,13.695,22.971 irr_pwc,chairs_occ,3.947,2.315,3.887,10.672 irr_pwc,kitti,8.171,7.447,1.128,1.522 irr_pwc,sintel,2.443,1.85,2.581,7.968 irr_pwc,things,3.404,1.856,3.55,9.508 irr_pwcnet,things,4.411,3.09,5.967,14.707 irr_pwcnet_irr,things,4.052,2.734,5.145,12.983 -lcv_raft,things,2.981,1.761,2.509,6.13 lcv_raft,chairs,4.034,2.282,4.371,9.213 +lcv_raft,things,2.981,1.761,2.509,6.13 liteflownet,things,4.024,2.504,4.532,11.477 liteflownet,sintel,1.845,1.419,3.661,10.354 liteflownet,kitti,5.917,4.553,1.164,1.783 @@ -47,29 +69,29 @@ maskflownet,kitti,5.907,4.364,1.334,2.845 maskflownet,sintel,2.692,1.786,1.703,3.126 maskflownet_s,sintel,2.807,1.938,1.861,3.457 maskflownet_s,things,4.3,3.002,4.735,11.388 -pwcdcnet,sintel,2.328,1.808,2.072,3.159 pwcdcnet,things,4.213,2.676,4.582,10.994 +pwcdcnet,sintel,2.328,1.808,2.072,3.159 pwcnet,sintel,2.82,2.255,3.327,6.249 pwcnet,things,4.827,3.358,5.551,12.669 -raft,sintel,1.571,0.871,1.341,1.631 -raft,things,3.009,1.507,2.261,5.468 raft,chairs,4.296,2.197,4.594,9.749 raft,kitti,6.293,4.687,1.271,0.779 +raft,sintel,1.571,0.871,1.341,1.631 +raft,things,3.009,1.507,2.261,5.468 raft_small,things,3.548,2.19,3.618,8.636 -scopeflow,chairs,3.95,2.569,4.094,11.975 -scopeflow,kitti,10.458,8.111,1.002,1.337 scopeflow,sintel,2.461,1.63,2.076,6.084 scopeflow,things,3.269,1.832,3.465,9.784 -scv4,things,3.849,1.796,4.232,9.847 -scv4,sintel,2.646,1.529,2.527,3.591 +scopeflow,chairs,3.95,2.569,4.094,11.975 +scopeflow,kitti,10.458,8.111,1.002,1.337 scv4,chairs,4.897,2.201,6.25,13.192 scv4,kitti,7.102,6.617,3.127,2.498 +scv4,sintel,2.646,1.529,2.527,3.591 +scv4,things,3.849,1.796,4.232,9.847 starflow,kitti,7.231,5.011,1.712,2.818 starflow,sintel,2.022,1.6,3.545,7.819 starflow,things,3.584,1.822,4.145,9.74 +vcn,sintel,2.251,1.595,2.209,4.314 vcn,chairs,3.96,2.801,4.444,10.773 vcn,kitti,9.196,5.934,1.141,1.455 -vcn,sintel,2.251,1.595,2.209,4.314 vcn,things,3.966,2.494,3.466,8.627 vcn_small,chairs,4.297,3.217,5.656,13.18 vcn_small,things,5.247,4.8,4.248,9.692 diff --git a/docs/source/results/summarized_metrics-epe_outlier.csv b/docs/source/results/summarized_metrics-epe_outlier.csv index 78e79b1..61a5d3a 100644 --- a/docs/source/results/summarized_metrics-epe_outlier.csv +++ b/docs/source/results/summarized_metrics-epe_outlier.csv @@ -1,39 +1,61 @@ model,checkpoint,sintel-final-epe,sintel-final-outlier,sintel-clean-epe,sintel-clean-outlier,kitti-2012-epe,kitti-2012-outlier,kitti-2015-epe,kitti-2015-outlier +craft,things,2.87,0.084,1.238,0.04,2.087,0.094,4.926,0.169 +craft,sintel,1.399,0.052,0.706,0.028,1.283,0.045,1.592,0.055 +craft,kitti,5.334,0.137,3.691,0.097,1.312,0.043,0.784,0.022 +csflow,chairs,4.229,0.112,2.157,0.067,4.344,0.267,9.251,0.346 +csflow,things,2.817,0.082,1.404,0.041,2.153,0.093,5.24,0.17 +csflow,kitti,2.598,0.071,1.122,0.037,1.241,0.042,0.928,0.029 dicl,chairs,4.947,0.132,3.661,0.089,5.968,0.281,18.445,0.456 dicl,kitti,9.598,0.234,7.798,0.195,1.161,0.046,1.393,0.058 dicl,sintel,2.05,0.075,1.297,0.046,2.161,0.081,5.656,0.164 dicl,things,3.831,0.107,2.006,0.059,3.745,0.14,9.906,0.24 +fastflownet,things,4.27,0.136,2.931,0.096,5.528,0.218,13.134,0.338 +fastflownet,sintel,2.79,0.105,2.729,0.081,5.227,0.179,13.731,0.314 fastflownet,chairs,4.38,0.14,3.252,0.101,5.76,0.236,14.374,0.368 fastflownet,kitti,5.147,0.152,3.952,0.116,1.534,0.065,2.198,0.087 fastflownet,mix,2.754,0.105,2.641,0.08,2.602,0.105,5.51,0.187 -fastflownet,sintel,2.79,0.105,2.729,0.081,5.227,0.179,13.731,0.314 -fastflownet,things,4.27,0.136,2.931,0.096,5.528,0.218,13.134,0.338 +flowformer,chairs,4.733,0.124,3.231,0.075,4.991,0.283,11.386,0.373 +flowformer,things,2.797,0.082,1.237,0.039,3.239,0.092,5.643,0.172 +flowformer,sintel,1.741,0.051,0.758,0.026,2.318,0.054,1.747,0.054 +flowformer,kitti,5.191,0.149,4.063,0.108,2.374,0.064,1.637,0.042 flownet2,things,3.97,0.124,3.005,0.09,5.468,0.219,13.143,0.346 flownetc,things,5.647,0.238,4.545,0.2,8.104,0.426,15.985,0.536 flownetcs,things,4.164,0.14,3.131,0.103,5.028,0.218,12.517,0.356 flownetcss,things,4.021,0.132,2.989,0.096,4.657,0.202,11.901,0.338 flownets,things,5.244,0.205,3.958,0.178,7.723,0.435,14.812,0.519 flownetsd,things,7.821,0.219,7.616,0.2,17.258,0.586,24.21,0.576 -gma,kitti,6.577,0.14,4.75,0.102,1.526,0.049,0.767,0.022 -gma,sintel,1.384,0.052,0.723,0.028,1.313,0.045,1.549,0.053 gma,chairs,4.135,0.11,2.371,0.069,4.6,0.268,9.983,0.352 gma,things,2.867,0.084,1.41,0.041,2.07,0.097,4.754,0.166 +gma,sintel,1.384,0.052,0.723,0.028,1.313,0.045,1.549,0.053 +gma,kitti,6.577,0.14,4.75,0.102,1.526,0.049,0.767,0.022 +gmflow,chairs,4.836,0.202,3.551,0.176,9.01,0.479,16.541,0.529 +gmflow,things,3.055,0.097,1.564,0.058,5.917,0.281,11.369,0.351 +gmflow,sintel,1.587,0.064,1.063,0.04,2.417,0.115,3.248,0.147 +gmflow,kitti,4.059,0.149,2.948,0.116,2.1,0.093,2.373,0.096 +gmflow_refine,kitti,4.28,0.155,2.929,0.11,2.413,0.129,2.926,0.147 +gmflow_refine,sintel,1.806,0.085,1.368,0.063,2.525,0.16,3.368,0.209 +gmflow_refine,things,2.962,0.112,1.73,0.077,4.566,0.228,9.324,0.314 +gmflow_refine,chairs,4.789,0.195,3.613,0.167,8.978,0.471,17.372,0.531 +gmflownet,things,2.828,0.081,1.341,0.04,2.042,0.087,4.618,0.159 +gmflownet,kitti,6.4,0.148,4.616,0.103,1.185,0.04,0.788,0.024 +gmflownet_mix,things,2.766,0.079,1.177,0.037,2.06,0.09,4.91,0.169 +gmflownet_mix,sintel,1.523,0.052,0.732,0.027,1.309,0.045,1.508,0.054 hd3,chairs,9.722,0.163,4.918,0.096,12.219,0.423,21.716,0.515 hd3,kitti,44.813,0.301,37.062,0.267,1.262,0.048,1.943,0.067 hd3,sintel,1.603,0.065,2.311,0.058,6.301,0.181,15.291,0.29 hd3,things,6.501,0.137,3.214,0.07,6.854,0.203,14.505,0.272 hd3_ctxt,things,4.421,0.116,2.072,0.058,4.645,0.159,9.959,0.24 -hd3_ctxt,chairs,5.792,0.139,3.727,0.086,13.695,0.505,22.971,0.57 -hd3_ctxt,kitti,7.858,0.169,6.016,0.126,0.999,0.038,1.536,0.052 hd3_ctxt,sintel,1.737,0.062,2.103,0.054,5.058,0.141,13.45,0.269 +hd3_ctxt,kitti,7.858,0.169,6.016,0.126,0.999,0.038,1.536,0.052 +hd3_ctxt,chairs,5.792,0.139,3.727,0.086,13.695,0.505,22.971,0.57 irr_pwc,chairs_occ,3.947,0.113,2.315,0.071,3.887,0.18,10.672,0.304 irr_pwc,kitti,8.171,0.179,7.447,0.143,1.128,0.039,1.522,0.052 irr_pwc,sintel,2.443,0.08,1.85,0.054,2.581,0.091,7.968,0.202 irr_pwc,things,3.404,0.104,1.856,0.062,3.55,0.146,9.508,0.264 irr_pwcnet,things,4.411,0.129,3.09,0.091,5.967,0.219,14.707,0.361 irr_pwcnet_irr,things,4.052,0.129,2.734,0.089,5.145,0.232,12.983,0.358 -lcv_raft,things,2.981,0.087,1.761,0.052,2.509,0.114,6.13,0.197 lcv_raft,chairs,4.034,0.115,2.282,0.073,4.371,0.277,9.213,0.367 +lcv_raft,things,2.981,0.087,1.761,0.052,2.509,0.114,6.13,0.197 liteflownet,things,4.024,0.118,2.504,0.074,4.532,0.191,11.477,0.311 liteflownet,sintel,1.845,0.073,1.419,0.049,3.661,0.125,10.354,0.235 liteflownet,kitti,5.917,0.156,4.553,0.118,1.164,0.041,1.783,0.065 @@ -47,29 +69,29 @@ maskflownet,kitti,5.907,0.144,4.364,0.108,1.334,0.045,2.845,0.106 maskflownet,sintel,2.692,0.08,1.786,0.053,1.703,0.06,3.126,0.101 maskflownet_s,sintel,2.807,0.086,1.938,0.059,1.861,0.07,3.457,0.119 maskflownet_s,things,4.3,0.133,3.002,0.094,4.735,0.229,11.388,0.349 -pwcdcnet,sintel,2.328,0.085,1.808,0.061,2.072,0.075,3.159,0.104 pwcdcnet,things,4.213,0.128,2.676,0.091,4.582,0.217,10.994,0.341 +pwcdcnet,sintel,2.328,0.085,1.808,0.061,2.072,0.075,3.159,0.104 pwcnet,sintel,2.82,0.115,2.255,0.085,3.327,0.152,6.249,0.217 pwcnet,things,4.827,0.156,3.358,0.117,5.551,0.252,12.669,0.367 -raft,sintel,1.571,0.057,0.871,0.031,1.341,0.047,1.631,0.058 -raft,things,3.009,0.083,1.507,0.043,2.261,0.098,5.468,0.175 raft,chairs,4.296,0.116,2.197,0.07,4.594,0.294,9.749,0.36 raft,kitti,6.293,0.145,4.687,0.104,1.271,0.044,0.779,0.023 +raft,sintel,1.571,0.057,0.871,0.031,1.341,0.047,1.631,0.058 +raft,things,3.009,0.083,1.507,0.043,2.261,0.098,5.468,0.175 raft_small,things,3.548,0.108,2.19,0.07,3.618,0.158,8.636,0.261 -scopeflow,chairs,3.95,0.11,2.569,0.073,4.094,0.173,11.975,0.317 -scopeflow,kitti,10.458,0.187,8.111,0.152,1.002,0.036,1.337,0.045 scopeflow,sintel,2.461,0.079,1.63,0.051,2.076,0.079,6.084,0.168 scopeflow,things,3.269,0.102,1.832,0.061,3.465,0.139,9.784,0.266 -scv4,things,3.849,0.105,1.796,0.056,4.232,0.143,9.847,0.237 -scv4,sintel,2.646,0.066,1.529,0.032,2.527,0.055,3.591,0.072 +scopeflow,chairs,3.95,0.11,2.569,0.073,4.094,0.173,11.975,0.317 +scopeflow,kitti,10.458,0.187,8.111,0.152,1.002,0.036,1.337,0.045 scv4,chairs,4.897,0.11,2.201,0.058,6.25,0.303,13.192,0.365 scv4,kitti,7.102,0.162,6.617,0.123,3.127,0.055,2.498,0.046 +scv4,sintel,2.646,0.066,1.529,0.032,2.527,0.055,3.591,0.072 +scv4,things,3.849,0.105,1.796,0.056,4.232,0.143,9.847,0.237 starflow,kitti,7.231,0.187,5.011,0.145,1.712,0.064,2.818,0.116 starflow,sintel,2.022,0.08,1.6,0.055,3.545,0.128,7.819,0.223 starflow,things,3.584,0.105,1.822,0.062,4.145,0.166,9.74,0.276 +vcn,sintel,2.251,0.088,1.595,0.061,2.209,0.1,4.314,0.175 vcn,chairs,3.96,0.123,2.801,0.087,4.444,0.197,10.773,0.333 vcn,kitti,9.196,0.207,5.934,0.146,1.141,0.041,1.455,0.053 -vcn,sintel,2.251,0.088,1.595,0.061,2.209,0.1,4.314,0.175 vcn,things,3.966,0.126,2.494,0.084,3.466,0.149,8.627,0.261 vcn_small,chairs,4.297,0.149,3.217,0.114,5.656,0.309,13.18,0.43 vcn_small,things,5.247,0.152,4.8,0.128,4.248,0.223,9.692,0.339 diff --git a/docs/source/starting/installation.rst b/docs/source/starting/installation.rst index ab8a27c..ed1c8d6 100644 --- a/docs/source/starting/installation.rst +++ b/docs/source/starting/installation.rst @@ -60,19 +60,6 @@ checkpoint name to ``--pretrained_ckpt`` as follows: This will show an error message with a list of the available checkpoint names. -Conda environment -================= - -It is recommended to use a virtual environment, such as ``conda`` or ``virtualenv``. -Most of the PTLFlow tests are done in ``conda``. To install PTLFlow in -a new conda environment, run: - -.. code-block:: bash - - conda create --name ptlflow-env - conda activate ptlflow-env - pip install ptlflow - Optional dependencies ===================== @@ -116,11 +103,9 @@ The instructions below show how to create a conda environment and install the re .. code-block:: bash - conda create --name ptlflow-env - conda activate ptlflow-env - pip install -r requirements.txt + conda env create -f environment.yml -Then you should be able to use PTLFlow from inside this directory. +This will create a conda environment called ``ptlflow`` with all the required dependencies already installed. Another option is to install PTLFlow to your environment. The benefit is that ptlflow will be accessible from anywhere while using the environment. The drawback is that you will have to reinstall diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..d79fd5e --- /dev/null +++ b/environment.yml @@ -0,0 +1,27 @@ +name: ptlflow +channels: + - pytorch + - defaults + - conda-forge +dependencies: + - python>=3.8,<=3.10 + - pip + - cudatoolkit>=11.0,<=11.6 + - pytorch>=1.8.1,<=1.12 + - pytorch-lightning>=1.1.0,<=1.6,!=1.3,!=1.4 + - torchmetrics>=0.2,<=0.9 + - torchvision>=0.8.2,<=0.13 + - pip: + - einops>=0.3.0,<=0.4.* + - numpy>=1.17.0,<=1.22.* + - opencv-python>=4.0.0.21,<=4.6.* + - packaging>=20.0,<=21.* + - pandas>=1.1.0,<=1.4.* + - pillow>=5.0,<=9.2.* + - plotly>=5.0.0,<=5.9.* + - pypng~=0.0.16 + - requests>=2.0.0,<=2.28.* + - scipy>=1.0.0,<=1.9.* + - tabulate~=0.8.3 + - timm~=0.6.3 + - tqdm>=4.41.0,<=4.64.* \ No newline at end of file diff --git a/ptlflow/__init__.py b/ptlflow/__init__.py index 79091ca..a2a730e 100644 --- a/ptlflow/__init__.py +++ b/ptlflow/__init__.py @@ -16,7 +16,7 @@ # limitations under the License. # ============================================================================= -__version__ = '0.2.5' +__version__ = '0.2.6' import logging from argparse import Namespace @@ -28,8 +28,11 @@ from torch import hub from ptlflow.models.base_model.base_model import BaseModel +from ptlflow.models.craft.craft import CRAFT +from ptlflow.models.csflow.csflow import CSFlow from ptlflow.models.dicl.dicl import DICL from ptlflow.models.fastflownet.fastflownet import FastFlowNet +from ptlflow.models.flowformer.flowformer import FlowFormer from ptlflow.models.flownet.flownet2 import FlowNet2 from ptlflow.models.flownet.flownetc import FlowNetC from ptlflow.models.flownet.flownetcs import FlowNetCS @@ -37,6 +40,8 @@ from ptlflow.models.flownet.flownets import FlowNetS from ptlflow.models.flownet.flownetsd import FlowNetSD from ptlflow.models.gma.gma import GMA +from ptlflow.models.gmflow.gmflow import GMFlow, GMFlowWithRefinement +from ptlflow.models.gmflownet.gmflownet import GMFlowNet, GMFlowNetMix from ptlflow.models.hd3.hd3 import HD3, HD3Context from ptlflow.models.irr.pwcnet import IRRPWCNet from ptlflow.models.irr.pwcnet_irr import IRRPWCNetIRR @@ -65,8 +70,11 @@ models_dict = { + 'craft': CRAFT, + 'csflow': CSFlow, 'dicl': DICL, 'fastflownet': FastFlowNet, + 'flowformer': FlowFormer, 'flownet2': FlowNet2, 'flownetc': FlowNetC, 'flownetcs': FlowNetCS, @@ -74,6 +82,10 @@ 'flownets': FlowNetS, 'flownetsd': FlowNetSD, 'gma': GMA, + 'gmflow': GMFlow, + 'gmflow_refine': GMFlowWithRefinement, + 'gmflownet': GMFlowNet, + 'gmflownet_mix': GMFlowNetMix, 'hd3': HD3, 'hd3_ctxt': HD3Context, 'irr_pwc': IRRPWC, diff --git a/ptlflow/data/AutoFlow_val.txt b/ptlflow/data/AutoFlow_val.txt new file mode 100644 index 0000000..56fe7b4 --- /dev/null +++ b/ptlflow/data/AutoFlow_val.txt @@ -0,0 +1,2000 @@ +table_0_batch_1 +table_0_batch_69 +table_0_batch_71 +table_0_batch_93 +table_0_batch_112 +table_0_batch_117 +table_1_batch_22 +table_1_batch_45 +table_1_batch_56 +table_1_batch_59 +table_1_batch_93 +table_1_batch_97 +table_1_batch_98 +table_2_batch_13 +table_2_batch_15 +table_2_batch_19 +table_2_batch_43 +table_2_batch_89 +table_2_batch_90 +table_2_batch_112 +table_3_batch_34 +table_3_batch_55 +table_3_batch_56 +table_3_batch_65 +table_3_batch_97 +table_3_batch_116 +table_4_batch_1 +table_4_batch_9 +table_4_batch_32 +table_4_batch_55 +table_4_batch_58 +table_4_batch_67 +table_4_batch_112 +table_5_batch_27 +table_5_batch_43 +table_5_batch_64 +table_5_batch_81 +table_5_batch_84 +table_5_batch_105 +table_5_batch_114 +table_6_batch_5 +table_6_batch_18 +table_6_batch_19 +table_6_batch_42 +table_6_batch_47 +table_6_batch_59 +table_6_batch_92 +table_6_batch_102 +table_7_batch_3 +table_7_batch_16 +table_7_batch_19 +table_7_batch_23 +table_7_batch_62 +table_7_batch_88 +table_8_batch_4 +table_8_batch_11 +table_8_batch_15 +table_8_batch_34 +table_8_batch_67 +table_8_batch_83 +table_9_batch_6 +table_9_batch_19 +table_9_batch_20 +table_9_batch_55 +table_9_batch_58 +table_9_batch_85 +table_9_batch_97 +table_10_batch_10 +table_10_batch_30 +table_10_batch_41 +table_10_batch_84 +table_10_batch_88 +table_10_batch_95 +table_10_batch_114 +table_11_batch_27 +table_11_batch_29 +table_11_batch_40 +table_11_batch_108 +table_11_batch_123 +table_11_batch_126 +table_12_batch_25 +table_12_batch_26 +table_12_batch_36 +table_12_batch_41 +table_12_batch_86 +table_12_batch_90 +table_13_batch_11 +table_13_batch_19 +table_13_batch_23 +table_13_batch_54 +table_13_batch_83 +table_13_batch_104 +table_14_batch_4 +table_14_batch_12 +table_14_batch_73 +table_14_batch_79 +table_14_batch_94 +table_14_batch_103 +table_14_batch_132 +table_15_batch_8 +table_15_batch_12 +table_15_batch_71 +table_15_batch_92 +table_15_batch_94 +table_16_batch_2 +table_16_batch_6 +table_16_batch_49 +table_16_batch_56 +table_16_batch_64 +table_16_batch_95 +table_16_batch_130 +table_16_batch_133 +table_17_batch_2 +table_17_batch_5 +table_17_batch_9 +table_17_batch_21 +table_17_batch_88 +table_17_batch_108 +table_18_batch_8 +table_18_batch_52 +table_18_batch_66 +table_18_batch_73 +table_18_batch_115 +table_18_batch_120 +table_19_batch_4 +table_19_batch_29 +table_19_batch_34 +table_19_batch_73 +table_19_batch_82 +table_19_batch_95 +table_20_batch_45 +table_20_batch_56 +table_20_batch_64 +table_20_batch_97 +table_20_batch_104 +table_20_batch_131 +table_20_batch_132 +table_20_batch_137 +table_21_batch_2 +table_21_batch_28 +table_21_batch_105 +table_21_batch_127 +table_21_batch_131 +table_21_batch_133 +table_21_batch_138 +table_22_batch_9 +table_22_batch_15 +table_22_batch_24 +table_22_batch_26 +table_22_batch_49 +table_22_batch_77 +table_22_batch_112 +table_23_batch_14 +table_23_batch_26 +table_23_batch_39 +table_23_batch_55 +table_23_batch_71 +table_23_batch_105 +table_23_batch_134 +table_24_batch_12 +table_24_batch_45 +table_24_batch_60 +table_24_batch_66 +table_24_batch_102 +table_24_batch_117 +table_24_batch_118 +table_25_batch_13 +table_25_batch_56 +table_25_batch_72 +table_25_batch_87 +table_25_batch_110 +table_25_batch_111 +table_25_batch_119 +table_25_batch_121 +table_26_batch_13 +table_26_batch_17 +table_26_batch_42 +table_26_batch_121 +table_26_batch_124 +table_26_batch_132 +table_27_batch_4 +table_27_batch_57 +table_27_batch_80 +table_27_batch_91 +table_27_batch_102 +table_27_batch_108 +table_28_batch_4 +table_28_batch_24 +table_28_batch_35 +table_28_batch_38 +table_28_batch_56 +table_28_batch_128 +table_28_batch_131 +table_29_batch_43 +table_29_batch_55 +table_29_batch_114 +table_29_batch_117 +table_29_batch_121 +table_29_batch_123 +table_30_batch_13 +table_30_batch_24 +table_30_batch_57 +table_30_batch_62 +table_30_batch_65 +table_30_batch_87 +table_30_batch_118 +table_31_batch_12 +table_31_batch_33 +table_31_batch_65 +table_31_batch_87 +table_31_batch_110 +table_31_batch_116 +table_32_batch_37 +table_32_batch_53 +table_32_batch_80 +table_32_batch_81 +table_32_batch_85 +table_32_batch_102 +table_32_batch_108 +table_33_batch_6 +table_33_batch_17 +table_33_batch_27 +table_33_batch_37 +table_33_batch_43 +table_33_batch_55 +table_33_batch_66 +table_34_batch_5 +table_34_batch_17 +table_34_batch_26 +table_34_batch_95 +table_34_batch_104 +table_34_batch_118 +table_34_batch_141 +table_35_batch_4 +table_35_batch_85 +table_35_batch_108 +table_35_batch_114 +table_35_batch_122 +table_35_batch_126 +table_36_batch_14 +table_36_batch_23 +table_36_batch_61 +table_36_batch_77 +table_36_batch_78 +table_36_batch_90 +table_36_batch_111 +table_36_batch_147 +table_37_batch_10 +table_37_batch_22 +table_37_batch_43 +table_37_batch_47 +table_37_batch_54 +table_37_batch_69 +table_37_batch_119 +table_38_batch_4 +table_38_batch_53 +table_38_batch_75 +table_38_batch_114 +table_38_batch_124 +table_38_batch_128 +table_39_batch_29 +table_39_batch_31 +table_39_batch_58 +table_39_batch_89 +table_39_batch_111 +table_39_batch_112 +table_39_batch_123 +table_40_batch_34 +table_40_batch_41 +table_40_batch_61 +table_40_batch_65 +table_40_batch_66 +table_40_batch_105 +table_40_batch_108 +table_41_batch_7 +table_41_batch_25 +table_41_batch_30 +table_41_batch_54 +table_41_batch_106 +table_41_batch_112 +table_41_batch_116 +table_42_batch_2 +table_42_batch_21 +table_42_batch_33 +table_42_batch_73 +table_42_batch_103 +table_42_batch_123 +table_43_batch_32 +table_43_batch_34 +table_43_batch_46 +table_43_batch_48 +table_43_batch_109 +table_43_batch_118 +table_43_batch_124 +table_43_batch_132 +table_44_batch_3 +table_44_batch_23 +table_44_batch_30 +table_44_batch_33 +table_44_batch_42 +table_44_batch_99 +table_44_batch_127 +table_45_batch_16 +table_45_batch_29 +table_45_batch_59 +table_45_batch_86 +table_45_batch_89 +table_45_batch_121 +table_46_batch_5 +table_46_batch_19 +table_46_batch_23 +table_46_batch_50 +table_46_batch_72 +table_46_batch_90 +table_46_batch_130 +table_47_batch_5 +table_47_batch_8 +table_47_batch_41 +table_47_batch_74 +table_47_batch_76 +table_47_batch_116 +table_47_batch_135 +table_47_batch_144 +table_48_batch_6 +table_48_batch_45 +table_48_batch_48 +table_48_batch_61 +table_48_batch_102 +table_49_batch_35 +table_49_batch_63 +table_49_batch_67 +table_49_batch_74 +table_49_batch_116 +table_49_batch_124 +table_49_batch_133 +table_50_batch_6 +table_50_batch_9 +table_50_batch_24 +table_50_batch_31 +table_50_batch_99 +table_50_batch_116 +table_50_batch_130 +table_51_batch_20 +table_51_batch_44 +table_51_batch_72 +table_51_batch_103 +table_51_batch_104 +table_51_batch_110 +table_52_batch_12 +table_52_batch_26 +table_52_batch_43 +table_52_batch_71 +table_52_batch_107 +table_52_batch_121 +table_53_batch_16 +table_53_batch_27 +table_53_batch_40 +table_53_batch_79 +table_53_batch_86 +table_53_batch_111 +table_53_batch_127 +table_54_batch_0 +table_54_batch_16 +table_54_batch_37 +table_54_batch_62 +table_54_batch_104 +table_54_batch_122 +table_55_batch_12 +table_55_batch_15 +table_55_batch_59 +table_55_batch_86 +table_55_batch_107 +table_55_batch_111 +table_55_batch_128 +table_56_batch_13 +table_56_batch_15 +table_56_batch_40 +table_56_batch_84 +table_56_batch_85 +table_56_batch_89 +table_56_batch_128 +table_57_batch_14 +table_57_batch_18 +table_57_batch_38 +table_57_batch_76 +table_57_batch_89 +table_57_batch_113 +table_57_batch_129 +table_58_batch_8 +table_58_batch_37 +table_58_batch_38 +table_58_batch_47 +table_58_batch_55 +table_58_batch_67 +table_58_batch_73 +table_59_batch_9 +table_59_batch_45 +table_59_batch_107 +table_59_batch_133 +table_59_batch_134 +table_59_batch_136 +table_59_batch_148 +table_60_batch_46 +table_60_batch_50 +table_60_batch_95 +table_60_batch_97 +table_60_batch_102 +table_60_batch_106 +table_60_batch_130 +table_61_batch_16 +table_61_batch_20 +table_61_batch_98 +table_61_batch_103 +table_61_batch_110 +table_61_batch_120 +table_61_batch_121 +table_62_batch_84 +table_62_batch_88 +table_62_batch_103 +table_62_batch_106 +table_62_batch_107 +table_62_batch_108 +table_62_batch_111 +table_63_batch_63 +table_63_batch_88 +table_63_batch_97 +table_63_batch_105 +table_63_batch_107 +table_63_batch_122 +table_63_batch_126 +table_64_batch_1 +table_64_batch_5 +table_64_batch_64 +table_64_batch_70 +table_64_batch_92 +table_64_batch_110 +table_64_batch_122 +table_65_batch_2 +table_65_batch_25 +table_65_batch_27 +table_65_batch_57 +table_65_batch_97 +table_65_batch_109 +table_66_batch_13 +table_66_batch_32 +table_66_batch_61 +table_66_batch_69 +table_66_batch_76 +table_66_batch_86 +table_67_batch_4 +table_67_batch_24 +table_67_batch_70 +table_67_batch_96 +table_67_batch_110 +table_67_batch_124 +table_68_batch_18 +table_68_batch_47 +table_68_batch_52 +table_68_batch_73 +table_68_batch_75 +table_68_batch_106 +table_68_batch_107 +table_69_batch_88 +table_69_batch_97 +table_69_batch_98 +table_69_batch_110 +table_69_batch_114 +table_69_batch_120 +table_70_batch_9 +table_70_batch_15 +table_70_batch_43 +table_70_batch_48 +table_70_batch_77 +table_70_batch_101 +table_70_batch_136 +table_71_batch_6 +table_71_batch_21 +table_71_batch_59 +table_71_batch_97 +table_71_batch_106 +table_71_batch_108 +table_71_batch_141 +table_72_batch_13 +table_72_batch_21 +table_72_batch_46 +table_72_batch_70 +table_72_batch_98 +table_72_batch_104 +table_72_batch_121 +table_73_batch_12 +table_73_batch_40 +table_73_batch_85 +table_73_batch_98 +table_73_batch_105 +table_73_batch_116 +table_74_batch_18 +table_74_batch_55 +table_74_batch_57 +table_74_batch_79 +table_74_batch_101 +table_74_batch_122 +table_74_batch_128 +table_75_batch_12 +table_75_batch_31 +table_75_batch_39 +table_75_batch_47 +table_75_batch_49 +table_75_batch_75 +table_75_batch_141 +table_76_batch_31 +table_76_batch_42 +table_76_batch_44 +table_76_batch_73 +table_76_batch_77 +table_76_batch_81 +table_77_batch_11 +table_77_batch_20 +table_77_batch_69 +table_77_batch_72 +table_77_batch_107 +table_77_batch_110 +table_77_batch_113 +table_78_batch_3 +table_78_batch_13 +table_78_batch_52 +table_78_batch_61 +table_78_batch_76 +table_78_batch_79 +table_78_batch_133 +table_79_batch_17 +table_79_batch_104 +table_79_batch_108 +table_79_batch_115 +table_79_batch_122 +table_79_batch_126 +table_80_batch_15 +table_80_batch_32 +table_80_batch_35 +table_80_batch_48 +table_80_batch_52 +table_80_batch_114 +table_80_batch_119 +table_81_batch_24 +table_81_batch_35 +table_81_batch_46 +table_81_batch_51 +table_81_batch_55 +table_81_batch_74 +table_81_batch_125 +table_82_batch_36 +table_82_batch_39 +table_82_batch_56 +table_82_batch_81 +table_82_batch_86 +table_82_batch_118 +table_82_batch_133 +table_83_batch_56 +table_83_batch_59 +table_83_batch_64 +table_83_batch_68 +table_83_batch_74 +table_83_batch_82 +table_83_batch_90 +table_84_batch_12 +table_84_batch_43 +table_84_batch_57 +table_84_batch_92 +table_84_batch_94 +table_84_batch_119 +table_84_batch_126 +table_85_batch_18 +table_85_batch_36 +table_85_batch_57 +table_85_batch_68 +table_85_batch_85 +table_85_batch_121 +table_85_batch_127 +table_85_batch_148 +table_86_batch_8 +table_86_batch_14 +table_86_batch_35 +table_86_batch_38 +table_86_batch_68 +table_87_batch_13 +table_87_batch_31 +table_87_batch_33 +table_87_batch_47 +table_87_batch_81 +table_87_batch_108 +table_88_batch_15 +table_88_batch_63 +table_88_batch_71 +table_88_batch_98 +table_88_batch_112 +table_88_batch_122 +table_89_batch_2 +table_89_batch_35 +table_89_batch_39 +table_89_batch_54 +table_89_batch_88 +table_89_batch_106 +table_89_batch_111 +table_90_batch_4 +table_90_batch_36 +table_90_batch_47 +table_90_batch_54 +table_90_batch_120 +table_90_batch_126 +table_91_batch_56 +table_91_batch_68 +table_91_batch_69 +table_91_batch_74 +table_91_batch_82 +table_91_batch_96 +table_92_batch_17 +table_92_batch_19 +table_92_batch_31 +table_92_batch_49 +table_92_batch_79 +table_92_batch_110 +table_92_batch_135 +table_93_batch_20 +table_93_batch_47 +table_93_batch_51 +table_93_batch_55 +table_93_batch_69 +table_93_batch_92 +table_94_batch_11 +table_94_batch_34 +table_94_batch_79 +table_94_batch_84 +table_94_batch_93 +table_94_batch_108 +table_95_batch_9 +table_95_batch_36 +table_95_batch_48 +table_95_batch_57 +table_95_batch_77 +table_95_batch_109 +table_96_batch_42 +table_96_batch_82 +table_96_batch_109 +table_96_batch_120 +table_96_batch_127 +table_96_batch_138 +table_96_batch_152 +table_97_batch_12 +table_97_batch_33 +table_97_batch_50 +table_97_batch_72 +table_97_batch_77 +table_97_batch_99 +table_97_batch_102 +table_98_batch_29 +table_98_batch_33 +table_98_batch_50 +table_98_batch_61 +table_98_batch_92 +table_98_batch_99 +table_98_batch_142 +table_99_batch_1 +table_99_batch_35 +table_99_batch_59 +table_99_batch_64 +table_99_batch_80 +table_99_batch_102 +table_99_batch_113 +table_100_batch_12 +table_100_batch_25 +table_100_batch_48 +table_100_batch_77 +table_100_batch_108 +table_100_batch_115 +table_100_batch_119 +table_100_batch_148 +table_101_batch_10 +table_101_batch_37 +table_101_batch_56 +table_101_batch_79 +table_101_batch_82 +table_101_batch_87 +table_101_batch_108 +table_102_batch_1 +table_102_batch_4 +table_102_batch_15 +table_102_batch_20 +table_102_batch_82 +table_102_batch_111 +table_103_batch_50 +table_103_batch_73 +table_103_batch_74 +table_103_batch_85 +table_103_batch_104 +table_103_batch_119 +table_103_batch_122 +table_104_batch_7 +table_104_batch_22 +table_104_batch_30 +table_104_batch_65 +table_104_batch_72 +table_104_batch_77 +table_105_batch_57 +table_105_batch_89 +table_105_batch_104 +table_105_batch_112 +table_105_batch_116 +table_105_batch_118 +table_106_batch_2 +table_106_batch_5 +table_106_batch_35 +table_106_batch_40 +table_106_batch_62 +table_106_batch_78 +table_106_batch_119 +table_107_batch_20 +table_107_batch_41 +table_107_batch_54 +table_107_batch_72 +table_107_batch_90 +table_107_batch_104 +table_107_batch_131 +table_108_batch_19 +table_108_batch_29 +table_108_batch_39 +table_108_batch_53 +table_108_batch_70 +table_108_batch_118 +table_108_batch_122 +table_109_batch_44 +table_109_batch_67 +table_109_batch_84 +table_109_batch_90 +table_109_batch_93 +table_109_batch_138 +table_109_batch_145 +table_110_batch_23 +table_110_batch_37 +table_110_batch_51 +table_110_batch_79 +table_110_batch_127 +table_110_batch_129 +table_111_batch_0 +table_111_batch_30 +table_111_batch_56 +table_111_batch_70 +table_111_batch_85 +table_111_batch_96 +table_111_batch_120 +table_111_batch_126 +table_112_batch_9 +table_112_batch_62 +table_112_batch_67 +table_112_batch_79 +table_112_batch_92 +table_113_batch_14 +table_113_batch_23 +table_113_batch_58 +table_113_batch_69 +table_113_batch_83 +table_113_batch_104 +table_114_batch_0 +table_114_batch_27 +table_114_batch_45 +table_114_batch_76 +table_114_batch_78 +table_115_batch_1 +table_115_batch_4 +table_115_batch_20 +table_115_batch_33 +table_116_batch_0 +table_116_batch_17 +table_116_batch_21 +table_116_batch_35 +table_116_batch_49 +table_116_batch_82 +table_116_batch_88 +table_116_batch_99 +table_117_batch_16 +table_117_batch_25 +table_117_batch_34 +table_117_batch_60 +table_117_batch_74 +table_117_batch_93 +table_117_batch_128 +table_118_batch_6 +table_118_batch_48 +table_118_batch_60 +table_118_batch_63 +table_118_batch_99 +table_118_batch_131 +table_118_batch_134 +table_119_batch_4 +table_119_batch_16 +table_119_batch_52 +table_119_batch_97 +table_119_batch_98 +table_119_batch_110 +table_119_batch_131 +table_120_batch_9 +table_120_batch_37 +table_120_batch_43 +table_120_batch_79 +table_120_batch_84 +table_120_batch_110 +table_120_batch_112 +table_121_batch_6 +table_121_batch_57 +table_121_batch_66 +table_121_batch_86 +table_121_batch_99 +table_121_batch_114 +table_121_batch_116 +table_122_batch_16 +table_122_batch_18 +table_122_batch_34 +table_122_batch_53 +table_122_batch_58 +table_122_batch_78 +table_123_batch_1 +table_123_batch_51 +table_123_batch_62 +table_123_batch_80 +table_123_batch_96 +table_123_batch_100 +table_124_batch_44 +table_124_batch_57 +table_124_batch_68 +table_124_batch_82 +table_124_batch_110 +table_124_batch_119 +table_124_batch_125 +table_125_batch_11 +table_125_batch_34 +table_125_batch_45 +table_125_batch_88 +table_125_batch_94 +table_125_batch_132 +table_126_batch_24 +table_126_batch_51 +table_126_batch_68 +table_126_batch_82 +table_126_batch_117 +table_126_batch_128 +table_126_batch_129 +table_127_batch_8 +table_127_batch_28 +table_127_batch_72 +table_127_batch_74 +table_127_batch_83 +table_127_batch_124 +table_127_batch_134 +table_128_batch_7 +table_128_batch_28 +table_128_batch_49 +table_128_batch_64 +table_128_batch_90 +table_128_batch_104 +table_128_batch_122 +table_129_batch_4 +table_129_batch_8 +table_129_batch_14 +table_129_batch_31 +table_129_batch_38 +table_129_batch_111 +table_129_batch_135 +table_130_batch_30 +table_130_batch_49 +table_130_batch_54 +table_130_batch_71 +table_130_batch_75 +table_130_batch_95 +table_130_batch_99 +table_131_batch_25 +table_131_batch_50 +table_131_batch_51 +table_131_batch_62 +table_131_batch_74 +table_131_batch_115 +table_132_batch_50 +table_132_batch_55 +table_132_batch_67 +table_132_batch_72 +table_132_batch_89 +table_132_batch_101 +table_133_batch_8 +table_133_batch_18 +table_133_batch_57 +table_133_batch_68 +table_133_batch_102 +table_133_batch_113 +table_134_batch_3 +table_134_batch_6 +table_134_batch_21 +table_134_batch_23 +table_134_batch_25 +table_134_batch_74 +table_134_batch_94 +table_134_batch_117 +table_135_batch_7 +table_135_batch_17 +table_135_batch_49 +table_135_batch_79 +table_135_batch_97 +table_135_batch_99 +table_136_batch_1 +table_136_batch_31 +table_136_batch_40 +table_136_batch_48 +table_136_batch_90 +table_136_batch_108 +table_137_batch_2 +table_137_batch_8 +table_137_batch_25 +table_137_batch_60 +table_137_batch_78 +table_137_batch_120 +table_137_batch_133 +table_138_batch_16 +table_138_batch_22 +table_138_batch_41 +table_138_batch_62 +table_138_batch_72 +table_138_batch_99 +table_139_batch_15 +table_139_batch_29 +table_139_batch_31 +table_139_batch_74 +table_139_batch_79 +table_139_batch_80 +table_139_batch_135 +table_139_batch_140 +table_140_batch_3 +table_140_batch_87 +table_140_batch_89 +table_140_batch_92 +table_140_batch_99 +table_140_batch_140 +table_140_batch_141 +table_141_batch_2 +table_141_batch_50 +table_141_batch_54 +table_141_batch_83 +table_141_batch_86 +table_141_batch_88 +table_141_batch_105 +table_142_batch_9 +table_142_batch_15 +table_142_batch_20 +table_142_batch_56 +table_142_batch_108 +table_142_batch_114 +table_142_batch_137 +table_143_batch_6 +table_143_batch_9 +table_143_batch_18 +table_143_batch_81 +table_143_batch_107 +table_143_batch_111 +table_144_batch_31 +table_144_batch_37 +table_144_batch_41 +table_144_batch_43 +table_144_batch_52 +table_144_batch_97 +table_144_batch_111 +table_144_batch_119 +table_145_batch_22 +table_145_batch_24 +table_145_batch_49 +table_145_batch_56 +table_145_batch_64 +table_146_batch_1 +table_146_batch_21 +table_146_batch_55 +table_146_batch_66 +table_146_batch_74 +table_146_batch_79 +table_146_batch_82 +table_146_batch_125 +table_147_batch_6 +table_147_batch_67 +table_147_batch_68 +table_147_batch_75 +table_147_batch_85 +table_147_batch_115 +table_148_batch_14 +table_148_batch_27 +table_148_batch_31 +table_148_batch_69 +table_148_batch_71 +table_148_batch_98 +table_148_batch_139 +table_148_batch_143 +table_149_batch_1 +table_149_batch_10 +table_149_batch_31 +table_149_batch_49 +table_149_batch_63 +table_149_batch_98 +table_150_batch_12 +table_150_batch_23 +table_150_batch_41 +table_150_batch_57 +table_150_batch_84 +table_150_batch_89 +table_150_batch_122 +table_151_batch_7 +table_151_batch_22 +table_151_batch_60 +table_151_batch_67 +table_151_batch_71 +table_151_batch_111 +table_152_batch_39 +table_152_batch_41 +table_152_batch_55 +table_152_batch_76 +table_152_batch_127 +table_152_batch_129 +table_152_batch_145 +table_153_batch_2 +table_153_batch_4 +table_153_batch_10 +table_153_batch_19 +table_153_batch_43 +table_153_batch_72 +table_153_batch_119 +table_154_batch_10 +table_154_batch_42 +table_154_batch_69 +table_154_batch_70 +table_154_batch_94 +table_154_batch_117 +table_154_batch_132 +table_155_batch_10 +table_155_batch_23 +table_155_batch_43 +table_155_batch_56 +table_155_batch_61 +table_155_batch_107 +table_155_batch_139 +table_156_batch_12 +table_156_batch_46 +table_156_batch_86 +table_156_batch_88 +table_156_batch_93 +table_156_batch_110 +table_156_batch_117 +table_156_batch_125 +table_157_batch_21 +table_157_batch_70 +table_157_batch_94 +table_157_batch_97 +table_157_batch_98 +table_157_batch_134 +table_157_batch_144 +table_158_batch_0 +table_158_batch_13 +table_158_batch_35 +table_158_batch_58 +table_158_batch_71 +table_158_batch_108 +table_159_batch_0 +table_159_batch_35 +table_159_batch_56 +table_159_batch_57 +table_159_batch_87 +table_159_batch_102 +table_159_batch_118 +table_159_batch_125 +table_160_batch_15 +table_160_batch_20 +table_160_batch_24 +table_160_batch_31 +table_160_batch_34 +table_160_batch_117 +table_161_batch_4 +table_161_batch_56 +table_161_batch_60 +table_161_batch_76 +table_161_batch_84 +table_161_batch_91 +table_161_batch_101 +table_161_batch_105 +table_162_batch_23 +table_162_batch_54 +table_162_batch_66 +table_162_batch_68 +table_162_batch_76 +table_162_batch_113 +table_163_batch_0 +table_163_batch_14 +table_163_batch_50 +table_163_batch_54 +table_163_batch_80 +table_163_batch_84 +table_163_batch_105 +table_163_batch_134 +table_164_batch_1 +table_164_batch_18 +table_164_batch_51 +table_164_batch_78 +table_164_batch_92 +table_164_batch_100 +table_165_batch_47 +table_165_batch_49 +table_165_batch_60 +table_165_batch_74 +table_165_batch_113 +table_165_batch_126 +table_165_batch_127 +table_166_batch_17 +table_166_batch_20 +table_166_batch_74 +table_166_batch_94 +table_166_batch_118 +table_166_batch_124 +table_167_batch_7 +table_167_batch_9 +table_167_batch_10 +table_167_batch_18 +table_167_batch_71 +table_167_batch_84 +table_168_batch_30 +table_168_batch_44 +table_168_batch_92 +table_168_batch_108 +table_168_batch_124 +table_168_batch_125 +table_169_batch_21 +table_169_batch_50 +table_169_batch_87 +table_169_batch_92 +table_169_batch_93 +table_169_batch_108 +table_170_batch_1 +table_170_batch_3 +table_170_batch_23 +table_170_batch_68 +table_170_batch_91 +table_170_batch_97 +table_170_batch_126 +table_171_batch_14 +table_171_batch_20 +table_171_batch_58 +table_171_batch_60 +table_171_batch_92 +table_171_batch_107 +table_171_batch_129 +table_171_batch_149 +table_172_batch_23 +table_172_batch_36 +table_172_batch_59 +table_172_batch_68 +table_172_batch_76 +table_172_batch_80 +table_173_batch_21 +table_173_batch_65 +table_173_batch_68 +table_173_batch_77 +table_173_batch_79 +table_173_batch_80 +table_173_batch_128 +table_174_batch_5 +table_174_batch_43 +table_174_batch_49 +table_174_batch_94 +table_174_batch_97 +table_174_batch_105 +table_174_batch_110 +table_175_batch_16 +table_175_batch_27 +table_175_batch_62 +table_175_batch_91 +table_175_batch_92 +table_175_batch_108 +table_175_batch_115 +table_176_batch_7 +table_176_batch_38 +table_176_batch_42 +table_176_batch_69 +table_176_batch_94 +table_176_batch_95 +table_176_batch_121 +table_177_batch_4 +table_177_batch_19 +table_177_batch_23 +table_177_batch_50 +table_177_batch_75 +table_177_batch_87 +table_178_batch_10 +table_178_batch_11 +table_178_batch_38 +table_178_batch_41 +table_178_batch_99 +table_178_batch_103 +table_179_batch_4 +table_179_batch_12 +table_179_batch_20 +table_179_batch_25 +table_179_batch_59 +table_179_batch_64 +table_179_batch_116 +table_180_batch_47 +table_180_batch_58 +table_180_batch_63 +table_180_batch_68 +table_180_batch_79 +table_180_batch_96 +table_180_batch_100 +table_181_batch_24 +table_181_batch_33 +table_181_batch_40 +table_181_batch_48 +table_181_batch_112 +table_181_batch_120 +table_181_batch_122 +table_181_batch_151 +table_181_batch_169 +table_182_batch_29 +table_182_batch_41 +table_182_batch_58 +table_182_batch_74 +table_182_batch_78 +table_182_batch_109 +table_183_batch_49 +table_183_batch_52 +table_183_batch_83 +table_183_batch_99 +table_183_batch_111 +table_183_batch_112 +table_183_batch_123 +table_184_batch_4 +table_184_batch_13 +table_184_batch_28 +table_184_batch_53 +table_184_batch_56 +table_184_batch_58 +table_185_batch_8 +table_185_batch_56 +table_185_batch_73 +table_185_batch_93 +table_185_batch_104 +table_186_batch_4 +table_186_batch_13 +table_186_batch_23 +table_186_batch_81 +table_186_batch_105 +table_186_batch_125 +table_186_batch_130 +table_187_batch_29 +table_187_batch_43 +table_187_batch_47 +table_187_batch_48 +table_187_batch_67 +table_187_batch_82 +table_187_batch_117 +table_188_batch_4 +table_188_batch_26 +table_188_batch_49 +table_188_batch_50 +table_188_batch_87 +table_188_batch_108 +table_188_batch_137 +table_189_batch_1 +table_189_batch_4 +table_189_batch_33 +table_189_batch_47 +table_189_batch_95 +table_189_batch_122 +table_189_batch_130 +table_189_batch_147 +table_190_batch_3 +table_190_batch_39 +table_190_batch_76 +table_190_batch_87 +table_190_batch_92 +table_190_batch_104 +table_190_batch_124 +table_191_batch_10 +table_191_batch_49 +table_191_batch_93 +table_191_batch_105 +table_191_batch_122 +table_191_batch_130 +table_191_batch_135 +table_192_batch_11 +table_192_batch_51 +table_192_batch_58 +table_192_batch_59 +table_192_batch_69 +table_192_batch_120 +table_193_batch_4 +table_193_batch_13 +table_193_batch_23 +table_193_batch_67 +table_193_batch_74 +table_193_batch_115 +table_194_batch_4 +table_194_batch_24 +table_194_batch_47 +table_194_batch_86 +table_194_batch_99 +table_194_batch_116 +table_194_batch_137 +table_195_batch_31 +table_195_batch_58 +table_195_batch_60 +table_195_batch_91 +table_195_batch_97 +table_195_batch_103 +table_195_batch_119 +table_196_batch_0 +table_196_batch_1 +table_196_batch_21 +table_196_batch_47 +table_196_batch_91 +table_196_batch_95 +table_197_batch_1 +table_197_batch_17 +table_197_batch_26 +table_197_batch_64 +table_197_batch_75 +table_197_batch_81 +table_198_batch_23 +table_198_batch_38 +table_198_batch_72 +table_198_batch_79 +table_198_batch_80 +table_198_batch_87 +table_199_batch_20 +table_199_batch_33 +table_199_batch_45 +table_199_batch_50 +table_199_batch_63 +table_199_batch_75 +table_199_batch_132 +table_200_batch_33 +table_200_batch_34 +table_200_batch_59 +table_200_batch_62 +table_200_batch_79 +table_200_batch_120 +table_200_batch_131 +table_200_batch_139 +table_201_batch_25 +table_201_batch_48 +table_201_batch_64 +table_201_batch_97 +table_201_batch_104 +table_201_batch_127 +table_202_batch_24 +table_202_batch_50 +table_202_batch_57 +table_202_batch_61 +table_202_batch_72 +table_202_batch_75 +table_202_batch_93 +table_203_batch_47 +table_203_batch_62 +table_203_batch_85 +table_203_batch_123 +table_203_batch_124 +table_203_batch_125 +table_204_batch_5 +table_204_batch_9 +table_204_batch_24 +table_204_batch_57 +table_204_batch_62 +table_204_batch_82 +table_204_batch_125 +table_205_batch_1 +table_205_batch_11 +table_205_batch_13 +table_205_batch_61 +table_205_batch_79 +table_205_batch_129 +table_206_batch_16 +table_206_batch_26 +table_206_batch_37 +table_206_batch_39 +table_206_batch_51 +table_206_batch_107 +table_206_batch_125 +table_207_batch_10 +table_207_batch_33 +table_207_batch_40 +table_207_batch_55 +table_207_batch_76 +table_207_batch_112 +table_208_batch_8 +table_208_batch_10 +table_208_batch_68 +table_208_batch_103 +table_208_batch_125 +table_208_batch_133 +table_208_batch_136 +table_209_batch_14 +table_209_batch_42 +table_209_batch_58 +table_209_batch_70 +table_209_batch_84 +table_209_batch_91 +table_210_batch_0 +table_210_batch_7 +table_210_batch_10 +table_210_batch_16 +table_210_batch_58 +table_210_batch_65 +table_211_batch_3 +table_211_batch_17 +table_211_batch_20 +table_211_batch_47 +table_211_batch_54 +table_211_batch_109 +table_212_batch_8 +table_212_batch_20 +table_212_batch_36 +table_212_batch_60 +table_212_batch_83 +table_212_batch_95 +table_212_batch_97 +table_213_batch_8 +table_213_batch_79 +table_213_batch_88 +table_213_batch_93 +table_213_batch_94 +table_213_batch_104 +table_214_batch_11 +table_214_batch_20 +table_214_batch_33 +table_214_batch_42 +table_214_batch_57 +table_214_batch_69 +table_214_batch_110 +table_215_batch_17 +table_215_batch_20 +table_215_batch_27 +table_215_batch_44 +table_215_batch_95 +table_215_batch_98 +table_215_batch_104 +table_215_batch_127 +table_216_batch_5 +table_216_batch_58 +table_216_batch_61 +table_216_batch_64 +table_216_batch_83 +table_216_batch_113 +table_217_batch_3 +table_217_batch_20 +table_217_batch_40 +table_217_batch_56 +table_217_batch_74 +table_217_batch_78 +table_218_batch_13 +table_218_batch_24 +table_218_batch_36 +table_218_batch_61 +table_218_batch_71 +table_218_batch_103 +table_219_batch_12 +table_219_batch_14 +table_219_batch_18 +table_219_batch_36 +table_219_batch_38 +table_219_batch_99 +table_220_batch_15 +table_220_batch_39 +table_220_batch_78 +table_220_batch_91 +table_220_batch_110 +table_220_batch_122 +table_220_batch_136 +table_221_batch_2 +table_221_batch_38 +table_221_batch_40 +table_221_batch_45 +table_221_batch_94 +table_221_batch_111 +table_221_batch_136 +table_222_batch_11 +table_222_batch_34 +table_222_batch_67 +table_222_batch_120 +table_222_batch_127 +table_222_batch_135 +table_222_batch_138 +table_223_batch_10 +table_223_batch_15 +table_223_batch_58 +table_223_batch_76 +table_223_batch_80 +table_223_batch_85 +table_223_batch_130 +table_224_batch_21 +table_224_batch_49 +table_224_batch_56 +table_224_batch_61 +table_224_batch_68 +table_224_batch_72 +table_225_batch_23 +table_225_batch_33 +table_225_batch_39 +table_225_batch_41 +table_225_batch_59 +table_225_batch_117 +table_226_batch_6 +table_226_batch_31 +table_226_batch_69 +table_226_batch_75 +table_226_batch_82 +table_226_batch_100 +table_226_batch_107 +table_227_batch_2 +table_227_batch_6 +table_227_batch_59 +table_227_batch_68 +table_227_batch_96 +table_228_batch_7 +table_228_batch_11 +table_228_batch_17 +table_228_batch_54 +table_228_batch_77 +table_228_batch_93 +table_228_batch_132 +table_229_batch_23 +table_229_batch_41 +table_229_batch_48 +table_229_batch_58 +table_229_batch_102 +table_229_batch_104 +table_229_batch_116 +table_230_batch_1 +table_230_batch_3 +table_230_batch_5 +table_230_batch_38 +table_230_batch_46 +table_230_batch_77 +table_230_batch_112 +table_231_batch_19 +table_231_batch_46 +table_231_batch_62 +table_231_batch_73 +table_231_batch_74 +table_231_batch_118 +table_232_batch_45 +table_232_batch_75 +table_232_batch_97 +table_232_batch_102 +table_232_batch_104 +table_232_batch_113 +table_232_batch_132 +table_233_batch_1 +table_233_batch_40 +table_233_batch_63 +table_233_batch_72 +table_233_batch_89 +table_233_batch_95 +table_234_batch_30 +table_234_batch_37 +table_234_batch_47 +table_234_batch_72 +table_234_batch_82 +table_234_batch_83 +table_234_batch_115 +table_235_batch_3 +table_235_batch_19 +table_235_batch_37 +table_235_batch_89 +table_235_batch_90 +table_235_batch_93 +table_236_batch_43 +table_236_batch_48 +table_236_batch_62 +table_236_batch_64 +table_236_batch_70 +table_236_batch_97 +table_236_batch_126 +table_237_batch_9 +table_237_batch_67 +table_237_batch_72 +table_237_batch_105 +table_237_batch_120 +table_237_batch_122 +table_238_batch_13 +table_238_batch_47 +table_238_batch_56 +table_238_batch_58 +table_238_batch_93 +table_238_batch_97 +table_238_batch_123 +table_239_batch_34 +table_239_batch_38 +table_239_batch_51 +table_239_batch_83 +table_239_batch_106 +table_239_batch_119 +table_240_batch_35 +table_240_batch_55 +table_240_batch_72 +table_240_batch_79 +table_240_batch_84 +table_240_batch_96 +table_240_batch_113 +table_241_batch_5 +table_241_batch_26 +table_241_batch_35 +table_241_batch_76 +table_241_batch_77 +table_241_batch_90 +table_242_batch_9 +table_242_batch_10 +table_242_batch_40 +table_242_batch_55 +table_242_batch_79 +table_242_batch_85 +table_242_batch_119 +table_242_batch_132 +table_243_batch_30 +table_243_batch_43 +table_243_batch_46 +table_243_batch_63 +table_243_batch_77 +table_243_batch_91 +table_244_batch_10 +table_244_batch_20 +table_244_batch_32 +table_244_batch_42 +table_244_batch_62 +table_244_batch_102 +table_244_batch_114 +table_245_batch_51 +table_245_batch_62 +table_245_batch_64 +table_245_batch_79 +table_245_batch_109 +table_245_batch_123 +table_245_batch_124 +table_245_batch_151 +table_246_batch_22 +table_246_batch_42 +table_246_batch_64 +table_246_batch_90 +table_246_batch_97 +table_246_batch_98 +table_246_batch_143 +table_246_batch_147 +table_247_batch_25 +table_247_batch_57 +table_247_batch_64 +table_247_batch_101 +table_247_batch_117 +table_247_batch_134 +table_248_batch_4 +table_248_batch_19 +table_248_batch_23 +table_248_batch_47 +table_248_batch_58 +table_248_batch_72 +table_248_batch_90 +table_249_batch_3 +table_249_batch_83 +table_249_batch_88 +table_249_batch_91 +table_249_batch_106 +table_249_batch_108 +table_249_batch_136 +table_250_batch_26 +table_250_batch_33 +table_250_batch_55 +table_250_batch_88 +table_250_batch_92 +table_251_batch_3 +table_251_batch_21 +table_251_batch_38 +table_251_batch_74 +table_251_batch_79 +table_251_batch_94 +table_251_batch_104 +table_252_batch_9 +table_252_batch_58 +table_252_batch_82 +table_252_batch_84 +table_252_batch_103 +table_252_batch_119 +table_252_batch_124 +table_253_batch_31 +table_253_batch_33 +table_253_batch_56 +table_253_batch_58 +table_253_batch_72 +table_253_batch_76 +table_254_batch_13 +table_254_batch_30 +table_254_batch_52 +table_254_batch_53 +table_254_batch_73 +table_254_batch_81 +table_254_batch_90 +table_255_batch_36 +table_255_batch_37 +table_255_batch_39 +table_255_batch_108 +table_255_batch_109 +table_255_batch_117 +table_256_batch_1 +table_256_batch_11 +table_256_batch_17 +table_256_batch_69 +table_256_batch_76 +table_256_batch_85 +table_256_batch_111 +table_256_batch_151 +table_257_batch_25 +table_257_batch_33 +table_257_batch_62 +table_257_batch_102 +table_257_batch_120 +table_257_batch_124 +table_257_batch_128 +table_258_batch_19 +table_258_batch_35 +table_258_batch_45 +table_258_batch_55 +table_258_batch_60 +table_258_batch_74 +table_259_batch_39 +table_259_batch_81 +table_259_batch_92 +table_259_batch_106 +table_259_batch_112 +table_259_batch_115 +table_259_batch_117 +table_260_batch_5 +table_260_batch_23 +table_260_batch_24 +table_260_batch_49 +table_260_batch_53 +table_260_batch_88 +table_260_batch_99 +table_261_batch_13 +table_261_batch_16 +table_261_batch_20 +table_261_batch_42 +table_261_batch_47 +table_261_batch_87 +table_262_batch_0 +table_262_batch_7 +table_262_batch_22 +table_262_batch_57 +table_262_batch_64 +table_262_batch_90 +table_262_batch_121 +table_262_batch_140 +table_263_batch_6 +table_263_batch_50 +table_263_batch_82 +table_263_batch_105 +table_263_batch_115 +table_263_batch_137 +table_263_batch_138 +table_264_batch_1 +table_264_batch_11 +table_264_batch_44 +table_264_batch_51 +table_264_batch_56 +table_264_batch_92 +table_264_batch_113 +table_264_batch_131 +table_265_batch_6 +table_265_batch_23 +table_265_batch_38 +table_265_batch_82 +table_265_batch_89 +table_265_batch_100 +table_266_batch_3 +table_266_batch_9 +table_266_batch_27 +table_266_batch_79 +table_266_batch_80 +table_266_batch_101 +table_267_batch_8 +table_267_batch_47 +table_267_batch_79 +table_267_batch_101 +table_267_batch_114 +table_267_batch_119 +table_267_batch_120 +table_268_batch_29 +table_268_batch_52 +table_268_batch_69 +table_268_batch_109 +table_268_batch_111 +table_268_batch_114 +table_268_batch_117 +table_269_batch_0 +table_269_batch_8 +table_269_batch_31 +table_269_batch_65 +table_269_batch_106 +table_269_batch_110 +table_270_batch_4 +table_270_batch_9 +table_270_batch_10 +table_270_batch_13 +table_270_batch_20 +table_270_batch_34 +table_270_batch_48 +table_270_batch_138 +table_271_batch_20 +table_271_batch_68 +table_271_batch_77 +table_271_batch_80 +table_271_batch_93 +table_272_batch_9 +table_272_batch_19 +table_272_batch_30 +table_272_batch_51 +table_272_batch_93 +table_272_batch_99 +table_272_batch_114 +table_273_batch_7 +table_273_batch_16 +table_273_batch_64 +table_273_batch_81 +table_273_batch_101 +table_273_batch_123 +table_273_batch_125 +table_274_batch_62 +table_274_batch_73 +table_274_batch_87 +table_274_batch_92 +table_274_batch_99 +table_274_batch_114 +table_275_batch_10 +table_275_batch_12 +table_275_batch_28 +table_275_batch_39 +table_275_batch_56 +table_275_batch_61 +table_276_batch_42 +table_276_batch_63 +table_276_batch_70 +table_276_batch_74 +table_276_batch_79 +table_276_batch_85 +table_276_batch_118 +table_277_batch_0 +table_277_batch_2 +table_277_batch_65 +table_277_batch_98 +table_277_batch_119 +table_277_batch_132 +table_277_batch_133 +table_277_batch_139 +table_278_batch_1 +table_278_batch_64 +table_278_batch_66 +table_278_batch_90 +table_278_batch_95 +table_278_batch_123 +table_279_batch_16 +table_279_batch_31 +table_279_batch_33 +table_279_batch_44 +table_279_batch_100 +table_279_batch_106 +table_280_batch_16 +table_280_batch_23 +table_280_batch_45 +table_280_batch_51 +table_280_batch_67 +table_280_batch_118 +table_280_batch_144 +table_281_batch_11 +table_281_batch_14 +table_281_batch_21 +table_281_batch_23 +table_281_batch_57 +table_281_batch_119 +table_281_batch_130 +table_281_batch_133 +table_282_batch_14 +table_282_batch_24 +table_282_batch_66 +table_282_batch_67 +table_282_batch_106 +table_282_batch_111 +table_283_batch_5 +table_283_batch_9 +table_283_batch_14 +table_283_batch_18 +table_283_batch_19 +table_283_batch_76 +table_284_batch_11 +table_284_batch_24 +table_284_batch_46 +table_284_batch_92 +table_284_batch_96 +table_284_batch_112 +table_285_batch_29 +table_285_batch_45 +table_285_batch_98 +table_285_batch_106 +table_285_batch_116 +table_285_batch_122 +table_285_batch_140 +table_286_batch_35 +table_286_batch_37 +table_286_batch_55 +table_286_batch_68 +table_286_batch_69 +table_286_batch_77 +table_286_batch_115 +table_287_batch_21 +table_287_batch_53 +table_287_batch_59 +table_287_batch_110 +table_287_batch_113 +table_287_batch_115 +table_288_batch_7 +table_288_batch_25 +table_288_batch_63 +table_288_batch_90 +table_288_batch_102 +table_288_batch_105 +table_289_batch_12 +table_289_batch_42 +table_289_batch_83 +table_289_batch_89 +table_289_batch_90 +table_289_batch_106 +table_289_batch_114 +table_290_batch_3 +table_290_batch_4 +table_290_batch_39 +table_290_batch_65 +table_290_batch_83 +table_290_batch_111 +table_290_batch_130 +table_291_batch_17 +table_291_batch_38 +table_291_batch_62 +table_291_batch_94 +table_291_batch_111 +table_291_batch_127 +table_292_batch_3 +table_292_batch_31 +table_292_batch_69 +table_292_batch_88 +table_292_batch_97 +table_292_batch_107 +table_292_batch_143 +table_293_batch_19 +table_293_batch_34 +table_293_batch_40 +table_293_batch_58 +table_293_batch_70 +table_293_batch_108 +table_293_batch_135 +table_293_batch_150 +table_294_batch_22 +table_294_batch_53 +table_294_batch_87 +table_294_batch_95 +table_294_batch_120 +table_294_batch_123 +table_295_batch_22 +table_295_batch_47 +table_295_batch_49 +table_295_batch_52 +table_295_batch_77 +table_295_batch_120 +table_295_batch_133 +table_296_batch_13 +table_296_batch_81 +table_296_batch_87 +table_296_batch_94 +table_296_batch_96 +table_296_batch_122 +table_297_batch_33 +table_297_batch_39 +table_297_batch_48 +table_297_batch_77 +table_297_batch_108 +table_297_batch_109 +table_297_batch_118 +table_298_batch_11 +table_298_batch_22 +table_298_batch_24 +table_298_batch_67 +table_298_batch_84 +table_298_batch_93 +table_299_batch_11 +table_299_batch_44 +table_299_batch_86 +table_299_batch_93 +table_299_batch_99 +table_299_batch_111 +table_299_batch_119 \ No newline at end of file diff --git a/ptlflow/data/datasets.py b/ptlflow/data/datasets.py index d61e670..6d36fbc 100644 --- a/ptlflow/data/datasets.py +++ b/ptlflow/data/datasets.py @@ -212,6 +212,80 @@ def _log_status(self) -> None: logging.info('Loading %d samples from %s dataset.', self.__len__(), self.dataset_name) +class AutoFlowDataset(BaseFlowDataset): + """Handle the AutoFlow dataset.""" + + def __init__( + self, + root_dir: str, + split: str = 'train', + transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None, + max_flow: float = 10000.0, + get_valid_mask: bool = True, + get_meta: bool = True + ) -> None: + """Initialize AutoFlowDataset. + + Parameters + ---------- + root_dir : str + path to the root directory of the AutoFlow dataset. + split : str, default 'train' + Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval'}. + transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional + Transform to be applied on the inputs. + max_flow : float, default 10000.0 + Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also marked + as zero in the valid mask. + get_valid_mask : bool, default True + Whether to get or generate valid masks. + get_meta : bool, default True + Whether to get metadata. + """ + super().__init__( + dataset_name='AutoFlow', + split_name=split, + transform=transform, + max_flow=max_flow, + get_valid_mask=get_valid_mask, + get_occlusion_mask=False, + get_motion_boundary_mask=False, + get_backward=False, + get_meta=get_meta) + self.root_dir = root_dir + self.split_file = THIS_DIR / 'AutoFlow_val.txt' + + # Read data from disk + parts_dirs = [f'static_40k_png_{i+1}_of_4' for i in range(4)] + sample_paths = [] + for pdir in parts_dirs: + sample_paths.extend([p for p in (Path(root_dir) / pdir).glob('*') if p.is_dir()]) + + with open(self.split_file, 'r') as f: + val_names = f.read().strip().splitlines() + + if split == 'trainval': + remove_names = [] + elif split == 'train': + remove_names = val_names + elif split == 'val': + remove_names = [p.stem for p in sample_paths if p.stem not in val_names] + + # Keep only data from the correct split + self.img_paths = [ + [p / 'im0.png', p / 'im1.png'] + for p in sample_paths if p.stem not in remove_names] + self.flow_paths = [ + [p / 'forward.flo'] for p in sample_paths if p.stem not in remove_names] + assert len(self.img_paths) == len(self.flow_paths), f'{len(self.img_paths)} vs {len(self.flow_paths)}' + + self.metadata = [ + {'image_paths': [str(p) for p in paths], 'is_val': paths[0].stem in val_names, 'misc': ''} + for paths in self.img_paths] + + self._log_status() + + class FlyingChairsDataset(BaseFlowDataset): """Handle the FlyingChairs dataset.""" diff --git a/ptlflow/data/split_autoflow.py b/ptlflow/data/split_autoflow.py new file mode 100644 index 0000000..7b04516 --- /dev/null +++ b/ptlflow/data/split_autoflow.py @@ -0,0 +1,89 @@ +""" + +Create a file with a list of samples names from the AutoFlow [1] dataset to be used as validation samples. + +[1] Sun, Deqing et al. “AutoFlow: Learning a Better Training Set for Optical Flow.” CVPR. 2021. + +""" + +# ============================================================================= +# Copyright 2022 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + + +from argparse import ArgumentParser, Namespace +import os +from pathlib import Path +import random + +random.seed(42) + +THIS_DIR = Path(os.path.abspath(os.path.dirname(__file__))) + + +def _init_parser() -> ArgumentParser: + parser: ArgumentParser = ArgumentParser() + parser.add_argument('--autoflow_root', type=str, required=True) + parser.add_argument('--output_file', type=str, default=str(THIS_DIR / 'AutoFlow_val.txt')) + parser.add_argument('--val_percentage', type=float, default=0.05) + return parser + + +def main(args: Namespace) -> None: + """Run the split process. + + Parameters + ---------- + args : argparse.Namespace + Arguments for configuring the splitting. + """ + parts_dirs = [f'static_40k_png_{i+1}_of_4' for i in range(4)] + sample_dirs = [] + for pdir in parts_dirs: + sample_dirs.extend(sorted([f.stem for f in (Path(args.autoflow_root) / pdir).glob('*') if f.is_dir()])) + sample_dirs.sort() + assert len(sample_dirs) == 40000, f'ERROR: AutoFlow dataset should have 40k samples, but found {len(sample_dirs)}.' + samples_per_table = {} + for sdir in sample_dirs: + table_idx = sdir.split('_')[1] + if table_idx not in samples_per_table: + samples_per_table[table_idx] = [] + samples_per_table[table_idx].append(sdir) + assert len(samples_per_table) == 300, ( + f'ERROR: AutoFlow dataset should have 300 tables, but found {len(samples_per_table)}.') + + val_samples = [] + carryover_samples = 0.0 + for dir_list in samples_per_table.values(): + num_samples = len(dir_list) + num_val_samples_float = args.val_percentage * num_samples + carryover_samples + num_val_samples = int(num_val_samples_float) + + random.shuffle(dir_list) + val_samples.extend(dir_list[:num_val_samples]) + + carryover_samples = num_val_samples_float - num_val_samples + + val_samples.sort(key=lambda x: 1000*int(x.split('_')[1]) + int(x.split('_')[-1])) + with open(args.output_file, 'w') as f: + f.write('\n'.join(val_samples)) + + print(f'Saved {len(val_samples)} sample names to {args.output_file}') + + +if __name__ == '__main__': + parser: ArgumentParser = _init_parser() + args: Namespace = parser.parse_args() + main(args) diff --git a/ptlflow/models/base_model/base_model.py b/ptlflow/models/base_model/base_model.py index 450b680..e6192a0 100644 --- a/ptlflow/models/base_model/base_model.py +++ b/ptlflow/models/base_model/base_model.py @@ -26,6 +26,7 @@ import warnings from abc import abstractmethod from argparse import ArgumentParser, Namespace +from packaging import version from typing import Any, Callable, Dict, List, Optional, Tuple, Union with warnings.catch_warnings(): @@ -37,7 +38,7 @@ from ptlflow.data import flow_transforms as ft from ptlflow.data.datasets import ( - FlyingChairsDataset, FlyingChairs2Dataset, Hd1kDataset, KittiDataset, SintelDataset, FlyingThings3DDataset, + AutoFlowDataset, FlyingChairsDataset, FlyingChairs2Dataset, Hd1kDataset, KittiDataset, SintelDataset, FlyingThings3DDataset, FlyingThings3DSubsetDataset) from ptlflow.utils.external.raft import InputPadder from ptlflow.utils.utils import config_logging, make_divisible @@ -87,7 +88,10 @@ def __init__( self.last_inputs = None self.last_predictions = None - self.save_hyperparameters() + if version.parse(pl.__version__) >= version.parse('1.6.0'): + self.save_hyperparameters(ignore=['loss_fn']) + else: + self.save_hyperparameters() @staticmethod def add_model_specific_args( @@ -602,6 +606,41 @@ def _split_train_val_metrics( # _get_datasets ########################################################################### + def _get_autoflow_dataset( + self, + is_train: bool, + *args: str + ) -> Dataset: + device = 'cuda' if self.args.train_transform_cuda else 'cpu' + md = make_divisible + + if is_train: + if self.args.train_crop_size is None: + cy, cx = (md(368, self.output_stride), md(496, self.output_stride)) + self.args.train_crop_size = (cy, cx) + logging.warning('--train_crop_size is not set. It will be set as (%d, %d).', cy, cx) + else: + cy, cx = ( + md(self.args.train_crop_size[0], self.output_stride), md(self.args.train_crop_size[1], self.output_stride)) + + # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT + transform = ft.Compose([ + ft.ToTensor(device=device, fp16=self.args.train_transform_fp16), + ft.RandomScaleAndCrop((cy, cx), (-0.1, 1.0), (-0.2, 0.2), min_pool_binary=True), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5/3.14, 0.2), + ft.GaussianNoise(0.02), + ft.RandomPatchEraser(0.5, (int(1), int(3)), (int(50), int(100)), 'mean'), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ]) + else: + transform = ft.ToTensor() + + split = 'trainval' + if len(args) > 0 and args[0] in ['train', 'val', 'trainval']: + split = args[0] + dataset = AutoFlowDataset(self.args.autoflow_root_dir, split=split, transform=transform) + return dataset + def _get_chairs_dataset( self, is_train: bool, diff --git a/ptlflow/models/craft/LICENSE b/ptlflow/models/craft/LICENSE new file mode 100644 index 0000000..06f8bb5 --- /dev/null +++ b/ptlflow/models/craft/LICENSE @@ -0,0 +1,14 @@ + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. \ No newline at end of file diff --git a/ptlflow/models/craft/README.md b/ptlflow/models/craft/README.md new file mode 100644 index 0000000..1bb9808 --- /dev/null +++ b/ptlflow/models/craft/README.md @@ -0,0 +1,24 @@ +# RAFT + +## Original code + +[https://github.com/askerlee/craft](https://github.com/askerlee/craft) + +## Code license + +See [LICENSE](LICENSE). + +## Pretrained weights license + +Not specified. + +## Citation + +``` +@InProceedings{craft, +author="Sui, Xiuchao and Li, Shaohua and Geng, Xue and Wu, Yan and Xu, Xinxing and Liu, Yong and Goh, Rick Siow Mong and Zhu, Hongyuan", +title="CRAFT: Cross-Attentional Flow Transformers for Robust Optical Flow", +booktitle="CVPR", +year="2022" +} +``` \ No newline at end of file diff --git a/ptlflow/models/craft/__init__.py b/ptlflow/models/craft/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ptlflow/models/craft/corr.py b/ptlflow/models/craft/corr.py new file mode 100644 index 0000000..aee4e87 --- /dev/null +++ b/ptlflow/models/craft/corr.py @@ -0,0 +1,203 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .utils import bilinear_sampler +from .setrans import CrossAttFeatTrans, gen_all_indices, SETransInputFeatEncoder +import os + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4, do_corr_global_norm=False): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + self.do_corr_global_norm = do_corr_global_norm + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + if self.do_corr_global_norm: + corr_3d = corr.permute(0, 3, 1, 2, 4, 5).view(batch, dim, -1) + corr_normed = F.layer_norm( corr_3d, (corr_3d.shape[2],), eps=1e-12 ) + corr = corr_normed.view(batch, dim, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5) + + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + # Save corr for visualization + if 'SAVECORR' in os.environ: + corr_savepath = os.environ['SAVECORR'] + # corr2: batch, dim, h1, w1, h2, w2. + corr2 = corr.reshape(batch, h1, w1, dim, h2, w2).permute(0, 3, 1, 2, 4, 5).detach().cpu() + torch.save(corr2, corr_savepath) + print(f"Corr tensor saved to {corr_savepath}") + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + # Concatenate the four levels (4 resolutions of neighbors), + # and permute the neighbors to the channel dimension. + out = torch.cat(out_pyramid, dim=-1) + # [batch, number of neighbors, h1, w1] + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class CorrBlockSingleScale(nn.Module): + def __init__(self, fmap1, fmap2, num_levels=4, radius=4, do_corr_global_norm=False): + super().__init__() + self.radius = radius + self.do_corr_global_norm = do_corr_global_norm + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + if self.do_corr_global_norm: + corr_3d = corr.permute(0, 3, 1, 2, 4, 5).view(B, dim, -1) + corr_normed = F.layer_norm( corr_3d, (corr_3d.shape[2],), eps=1e-12 ) + corr = corr_normed.view(batch, dim, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5) + + self.corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + self.do_corr_global_norm = do_corr_global_norm + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + corr = self.corr + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + out = corr.view(batch, h1, w1, -1) + out = out.permute(0, 3, 1, 2).contiguous().float() + return out + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + +# TransCorrBlock instance is created and destroyed in each call of raft.forward(). +# It is only for a particular pair of image features fmap1, fmap2 +class TransCorrBlock(CorrBlock, nn.Module): + def __init__(self, config, num_levels=4, radius=4, do_corr_global_norm=False): + # Do not call CorrBlock.__init__(), as corr is computed differently. + nn.Module.__init__(self) + self.num_levels = num_levels + self.radius = radius + self.config = config + self.setrans = CrossAttFeatTrans(self.config, "Inter-frame correlation block") + self.vispos_encoder = SETransInputFeatEncoder(self.config) + self.coords2 = None + self.do_corr_global_norm = do_corr_global_norm + + # Compute the correlation volume between all pairs here. + # If fmap1o, fmap2o are None, then fmap1 is conv features, fmap2 is transformer features. + # If fmap1o, fmap2o are not None, then fmap1, fmap2 are transformer features, + # and fmap1o, fmap2o are conv features. + def update(self, fmap1, fmap2, fmap1o, fmap2o, coords1, coords2=None): + self.corr_pyramid = [] + # coords1 is generated by coords_grid(), with the format + # (width index, height index) + # flip => (height index, width index) + coords1 = coords1.permute(0, 2, 3, 1).flip(-1) + if coords2 is None: + coords2 = gen_all_indices(fmap2.shape[2:], device=fmap2.device) + coords2 = coords2.unsqueeze(0).repeat(fmap2.shape[0], 1, 1, 1) + + vispos1, pos_biases = self.vispos_encoder(fmap1, coords1, return_pos_biases=True) + vispos2 = self.vispos_encoder(fmap2, coords2, return_pos_biases=False) + batch, dim, ht, wd = fmap1.shape + + # If both f1_trans and f2_trans are used, then compute the two-way correlation. + # Otherwise, only fmap2o is not None, and compute the single-way correlation. + if fmap1o is not None and fmap2o is not None: + vispos1o = self.vispos_encoder(fmap1o, coords1, return_pos_biases=False) + vispos2o = self.vispos_encoder(fmap2o, coords2, return_pos_biases=False) + # Two-way correlation. + corr_1t2o = self.corr(ht, wd, vispos1, vispos2o, pos_biases) + corr_1o2t = self.corr(ht, wd, vispos1o, vispos2, pos_biases) + # Try concatenation instead of averaging, as averaging may lose info. + corr = torch.cat([corr_1t2o, corr_1o2t], dim=3) + else: + # single-way correlation. + corr = self.corr(ht, wd, vispos1, vispos2, pos_biases) + + batch, h1, w1, dim, h2, w2 = corr.shape + # Merge batch with h1 and w1 to improve efficiency. They will be separate later. + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + if 'SAVECORR' in os.environ: + corr_savepath = os.environ['SAVECORR'] + corr2 = corr.detach().cpu().reshape(batch, h1, w1, h2, w2) + torch.save(corr2, corr_savepath) + print(f"Corr tensor saved to {corr_savepath}") + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def corr(self, ht, wd, vispos1, vispos2, pos_biases): + batch, ht_wd, dim = vispos1.shape + assert ht_wd == ht * wd + # if out_attn_only, output attention matrix is in the shape of (query unit number, key unit number) + # otherwise, output features are in the same shape as the query features. + # key features are recombined to get new query features by matmul(attention_probs, V(key features)) + # frame1 frame2 + # corr: [1, 1, 7040, 7040] + corr = self.setrans(vispos1, vispos2, pos_biases) + if self.do_corr_global_norm: + B, C, H, W = corr.shape + corr_3d = corr.view(B, C, H*W) + corr_normed = F.layer_norm( corr_3d, (corr_3d.shape[2],), eps=1e-12 ) + corr = corr_normed.view(B, C, H, W) + + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr + + # __call__() inherits from CorrBlock. + \ No newline at end of file diff --git a/ptlflow/models/craft/craft.py b/ptlflow/models/craft/craft.py new file mode 100644 index 0000000..471164d --- /dev/null +++ b/ptlflow/models/craft/craft.py @@ -0,0 +1,328 @@ +from argparse import ArgumentParser, Namespace + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import GMAUpdateBlock +from .extractor import BasicEncoder +from .corr import CorrBlock, TransCorrBlock +from .utils import coords_grid, upflow8, print0 +from .gma import Attention +from .setrans import SETransConfig, SelfAttVisPosTrans +from ..base_model.base_model import BaseModel + + +class SequenceLoss(nn.Module): + def __init__(self, args): + super().__init__() + self.gamma = args.gamma + self.max_flow = args.max_flow + + def forward(self, outputs, inputs): + """ Loss function defined over sequence of flow predictions """ + + flow_preds = outputs['flow_preds'] + flow_gt = inputs['flows'][:, 0] + valid = inputs['valids'][:, 0] + + # n_predictions = args.iters = 12 + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exclude invalid pixels and extremely large displacements. + # MAX_FLOW = 400. + valid = (valid >= 0.5) & ((flow_gt**2).sum(dim=1).sqrt() < self.max_flow) + + for i in range(n_predictions): + # Exponentially increasing weights. (Eq.7 in RAFT paper) + # As i increases, flow_preds[i] is expected to be more and more accurate, + # so we are less and less tolerant to errors through gradually increased i_weight. + i_weight = self.gamma**(n_predictions - i - 1) + i_loss = (flow_preds[i] - flow_gt).abs() + flow_loss += i_weight * (valid[:, None] * i_loss).mean() + + return flow_loss + + +class CRAFT(BaseModel): + pretrained_checkpoints = { + 'things': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/craft-things-5a41930c.ckpt', + 'sintel': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/craft-sintel-ff8e6563.ckpt', + 'kitti': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/craft-kitti-4d99b0c1.ckpt' + } + + def __init__(self, + args: Namespace) -> None: + super().__init__( + args=args, + loss_fn=SequenceLoss(args), + output_stride=8) + + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + + if 'dropout' not in self.args: + self.args.dropout = 0 + + # default CRAFT corr_radius: 4 + if args.corr_radius == -1: + args.corr_radius = 4 + print0("Lookup radius: %d" %args.corr_radius) + + if args.craft: + self.inter_trans_config = SETransConfig() + self.inter_trans_config.update_config(args) + self.inter_trans_config.in_feat_dim = 256 + self.inter_trans_config.feat_dim = 256 + self.inter_trans_config.max_pos_size = 160 + # out_attn_scores_only implies no FFN nor V projection. + self.inter_trans_config.out_attn_scores_only = True # implies no FFN and no skip. + self.inter_trans_config.attn_diag_cycles = 1000 + self.inter_trans_config.num_modes = args.inter_num_modes # default: 4 + self.inter_trans_config.tie_qk_scheme = 'shared' # Symmetric Q/K + self.inter_trans_config.qk_have_bias = args.inter_qk_have_bias # default: True + self.inter_trans_config.pos_code_type = args.inter_pos_code_type # default: bias + self.inter_trans_config.pos_code_weight = args.inter_pos_code_weight # default: 0.5 + self.args.inter_trans_config = self.inter_trans_config + print0("Inter-frame trans config:\n{}".format(self.inter_trans_config.__dict__)) + + self.corr_fn = TransCorrBlock(self.inter_trans_config, radius=self.args.corr_radius, + do_corr_global_norm=True) + + # feature network, context network, and update block + self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) + self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout) + + if args.f2trans != 'none': + # f2_trans has the same configuration as GMA att, + # except that the feature dimension is doubled, and not out_attn_probs_only. + self.f2_trans_config = SETransConfig() + self.f2_trans_config.update_config(args) + self.f2_trans_config.in_feat_dim = 256 + self.f2_trans_config.feat_dim = 256 + # f2trans(x) = attn_aggregate(v(x)) + x. Here attn_aggregate and v (first_linear) both have 4 modes. + self.f2_trans_config.has_input_skip = True + # No FFN. f2trans simply aggregates similar features. + # But there's still a V projection. + self.f2_trans_config.has_FFN = False + # When doing feature aggregation, set attn_mask_radius > 0 to exclude points that are too far apart, to reduce noises. + # E.g., 64 corresponds to 64*8=512 pixels in the image space. + self.f2_trans_config.attn_mask_radius = args.f2_attn_mask_radius + # Not tying QK performs slightly better. + self.f2_trans_config.tie_qk_scheme = None + self.f2_trans_config.qk_have_bias = False + self.f2_trans_config.out_attn_probs_only = False + self.f2_trans_config.attn_diag_cycles = 1000 + self.f2_trans_config.num_modes = args.f2_num_modes # default: 4 + self.f2_trans_config.pos_code_type = args.intra_pos_code_type # default: bias + self.f2_trans_config.pos_code_weight = args.f2_pos_code_weight # default: 0.5 + self.f2_trans = SelfAttVisPosTrans(self.f2_trans_config, "F2 transformer") + print0("F2-trans config:\n{}".format(self.f2_trans_config.__dict__)) + self.args.f2_trans_config = self.f2_trans_config + + if args.f1trans != 'none': + args.corr_multiplier = 2 + if args.f1trans == 'shared': + # f1_trans and f2_trans are shared. + self.f1_trans = self.f2_trans + elif args.f1trans == 'private': + # f1_trans is a private instance of SelfAttVisPosTrans. + self.f1_trans = SelfAttVisPosTrans(self.f2_trans_config, "F1 transformer") + else: + breakpoint() + else: + self.f1_trans = None + args.corr_multiplier = 1 + + if args.use_setrans: + self.intra_trans_config = SETransConfig() + self.intra_trans_config.update_config(args) + self.intra_trans_config.in_feat_dim = 128 + self.intra_trans_config.feat_dim = 128 + # has_FFN & has_input_skip are for GMAUpdateBlock.aggregator. + # Having FFN reduces performance. FYI, GMA also has no FFN. + self.intra_trans_config.has_FFN = False + self.intra_trans_config.has_input_skip = True + self.intra_trans_config.attn_mask_radius = -1 + # Not tying QK performs slightly better. + self.intra_trans_config.tie_qk_scheme = None + self.intra_trans_config.qk_have_bias = False + self.intra_trans_config.out_attn_probs_only = True + self.intra_trans_config.attn_diag_cycles = 1000 + self.intra_trans_config.num_modes = args.intra_num_modes # default: 4 + self.intra_trans_config.pos_code_type = args.intra_pos_code_type # default: bias + self.intra_trans_config.pos_code_weight = args.intra_pos_code_weight # default: 1 + self.att = SelfAttVisPosTrans(self.intra_trans_config, "Intra-frame attention") + self.args.intra_trans_config = self.intra_trans_config + print0("Intra-frame trans config:\n{}".format(self.intra_trans_config.__dict__)) + else: + self.att = Attention(args=self.args, dim=cdim, heads=self.args.num_heads, max_pos_size=160, dim_head=cdim) + + # if args.use_setrans, initialization of GMAUpdateBlock.aggregator needs to access self.args.intra_trans_config. + # So GMAUpdateBlock() construction has to be done after initializing intra_trans_config. + self.update_block = GMAUpdateBlock(self.args, hidden_dim=hdim) + self.call_counter = 0 + + @staticmethod + def add_model_specific_args(parent_parser=None): + parent_parser = BaseModel.add_model_specific_args(parent_parser) + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--corr_levels', type=int, default=4) + parser.add_argument('--corr_radius', type=int, default=4) + parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--gamma', type=float, default=0.8) + parser.add_argument('--max_flow', type=float, default=1000.0) + parser.add_argument('--iters', type=int, default=12) + parser.add_argument('--f1trans', type=str, choices=['none', 'shared', 'private'], default='none') + parser.add_argument('--f2trans', type=str, choices=('none', 'full'), default='full') + parser.add_argument('--f2_attn_mask_radius', type=int, default=-1) + parser.add_argument('--f2_num_modes', type=int, default=4) + parser.add_argument('--f2_pos_code_weight', type=float, default=0.5) + parser.add_argument('--inter_num_modes', type=int, default=4) + parser.add_argument('--inter_pos_code_type', type=str, choices=('bias', 'lsinu'), default='bias') + parser.add_argument('--inter_pos_code_weight', type=float, default=0.5) + parser.add_argument('--intra_pos_code_type', type=str, choices=('bias', 'lsinu'), default='bias') + parser.add_argument('--intra_pos_code_weight', type=float, default=1.0) + parser.add_argument('--intra_num_modes', type=int, default=4) + parser.add_argument('--no_craft', action='store_false', dest='craft') + parser.add_argument('--no_inter_qk_have_bias', action='store_false', dest='inter_qk_have_bias') + parser.add_argument('--num_heads', type=int, default=1) + parser.add_argument('--pos_bias_radius', type=int, default=7) + parser.add_argument('--no_use_setrans', action='store_false', dest='use_setrans') + return parser + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H // 8, W // 8).to(img.device) + coords1 = coords_grid(N, H // 8, W // 8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def forward(self, inputs, flow_init=None): + """ Estimate optical flow between pair of frames """ + image1 = inputs['images'][:, 0] + image2 = inputs['images'][:, 1] + + image1 = 2 * image1 - 1.0 + image2 = 2 * image2 - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + fmap1, fmap2 = self.fnet([image1, image2]) + fmap1o, fmap2o = None, None + if self.args.f1trans != 'none': + fmap1o = fmap1 + fmap1 = self.f1_trans(fmap1) + if self.args.f2trans != 'none': + fmap2o = fmap2 + fmap2 = self.f2_trans(fmap2) + + # fmap1, fmap2: [1, 256, 55, 128]. 1/8 size of the original image. + # correlation matrix: 7040*7040 (55*128=7040). + fmap1 = fmap1.float() + fmap2 = fmap2.float() + + # If not craft, the correlation volume is computed in the ctor. + # If craft, the correlation volume is computed in corr_fn.update(). + if not self.args.craft: + self.corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + # cnet: context network to extract features from image1 only. + # cnet arch is the same as fnet. + # fnet extracts features specifically for correlation computation. + # cnet_feat: extracted features focus on semantics of image1? + # (semantics of each pixel, used to guess its motion?) + cnet_feat = self.cnet(image1) + + # Both fnet and cnet are BasicEncoder. output is from conv (no activation function yet). + # net_feat, inp_feat: [1, 128, 55, 128] + net_feat, inp_feat = torch.split(cnet_feat, [hdim, cdim], dim=1) + net_feat = torch.tanh(net_feat) + inp_feat = torch.relu(inp_feat) + # attention, att_c, att_p = self.att(inp_feat) + attention = self.att(inp_feat) + + # coords0 is always fixed as original coords. + # coords1 is iteratively updated as coords0 + current estimated flow. + # At this moment coords0 == coords1. + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + # If craft, the correlation volume is computed in corr_fn.update(). + if self.args.craft: + # only update() once, instead of dynamically updating coords1. + self.corr_fn.update(fmap1, fmap2, fmap1o, fmap2o, coords1, coords2=None) + + flow_predictions = [] + for itr in range(self.args.iters): + coords1 = coords1.detach() + # corr: [6, 324, 50, 90]. 324: number of points in the neighborhood. + # radius = 4 -> neighbor points = (4*2+1)^2 = 81. 4 levels: x4 -> 324. + corr = self.corr_fn(coords1) # index correlation volume + flow = coords1 - coords0 + + # net_feat: hidden features of ConvGRU. + # inp_feat: input features to ConvGRU. + # up_mask is scaled to 0.25 of original values. + # update_block: GMAUpdateBlock + # In the first few iterations, delta_flow.abs().max() could be 1.3 or 0.8. Later it becomes 0.2~0.3. + net_feat, up_mask, delta_flow = self.update_block(net_feat, inp_feat, corr, flow, attention) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + # coords0 is fixed as original coords. + # upflow8: upsize to 8 * height, 8 * width. + # flow value also *8 (scale the offsets proportionally to the resolution). + flow_up = upflow8(coords1 - coords0) + else: + # The final high resolution flow field is found + # by using the mask to take a weighted combination over the neighborhood. + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if self.training: + outputs = { + 'flows': flow_up[:, None], + 'flow_preds': flow_predictions + } + else: + outputs = { + 'flows': flow_up[:, None], + 'flow_small': coords1 - coords0 + } + + return outputs diff --git a/ptlflow/models/craft/extractor.py b/ptlflow/models/craft/extractor.py new file mode 100644 index 0000000..c2314b1 --- /dev/null +++ b/ptlflow/models/craft/extractor.py @@ -0,0 +1,195 @@ +import torch +import torch.nn as nn + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + use_old_downsampling_scheme = True + if use_old_downsampling_scheme: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + # kernel size = 3 performs worse than 1. Disabled. + elif stride == 2: + # The trick proposed in "Detail Preserving Residual Feature Pyramid Modules for Optical Flow". + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=3, stride=2, padding=1), self.norm3) + else: + breakpoint() + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + # if input is list, x = [image1, image2]. + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/ptlflow/models/craft/gma.py b/ptlflow/models/craft/gma.py new file mode 100644 index 0000000..f7235fc --- /dev/null +++ b/ptlflow/models/craft/gma.py @@ -0,0 +1,150 @@ +import torch +from torch import nn, einsum +from einops import rearrange + +# max_pos_size = 160 +class RelPosEmb(nn.Module): + def __init__( + self, + max_pos_size, + dim_head + ): + super().__init__() + self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) + self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) + + deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) + # rel_ind[i, j] = j - i + 159. + rel_ind = deltas + max_pos_size - 1 + self.register_buffer('rel_ind', rel_ind) + + def forward(self, q): + # q: [8, 1, 46, 62, 128] + batch, heads, h, w, c = q.shape + # self.rel_ind[:h, :h]: [46, 46] + # self.rel_ind[:w, :w]: [62, 62] + # rel_ind[i,j] = j - i + 159, precomputed distance between i, j. + # This assumes the input x (from which q is derived) is precisely on the grid. + # This is fine when we do self-attention on x. + # However, it will be somewhat limiting if we use RelPosEmb on cross-attention between two frames, + # particularly when we use flow_init != 0 (on sintel), + # we better get the positional encodings of x according to flow_init, instead of the grid of x. + # However, an accurate computation of the relative distances between all input units is expensive. + # Since values in flow_init are usually small, this inaccuracy may be negligible. + height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) + width_emb = self.rel_width( self.rel_ind[:w, :w].reshape(-1)) + + # height_emb: [46*46, 128] => [46, 46, 1, 128] + # width_emb: [62*62, 128] => [62, 1, 62, 128] + # height_emb[i, j]: the embedding of element at (i,j) as a function of the height difference (i-j). + # width_emb[i, j]: the embedding of element at (i,j) as a function of the width difference (i-j). + height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) + width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) + + # outer product? y, uv -> y u v b h x y d x u v d + # height_score: [8, 1, 46, 62, 46, 1] <= [8, 1, 46, 62, 128] * [46, 46, 1, 128] + # width_score: [8, 1, 46, 62, 1, 62] + height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) + width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) + # height_score + width_score: [8, 1, 46, 62, 46, 62], 65071232 elements. + return height_score + width_score + + +class Attention(nn.Module): + def __init__( + self, + *, + args, + dim, + max_pos_size = 100, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) + + self.pos_emb = RelPosEmb(max_pos_size, dim_head) + self.pos_embed_weight = 1.0 + + def forward(self, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + # q, k: [8, 128, 46, 62] + q, k = self.to_qk(fmap).chunk(2, dim=1) + + # q, k: [8, 1, 46, 62, 128] + q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) + # Why not scale k? + q = self.scale * q + + if self.args.position_only: + sim = self.pos_emb(q) + + elif self.args.position_and_content: + # [..., 46, 62, ...] . [..., 46, 62, ...] => [..., 46, 62, 46, 62] + sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + sim_pos = self.pos_emb(q) + sim = sim_content + self.pos_embed_weight * sim_pos + + else: + # q, k: [B, 1, 46, 62, 128] + # sim: [B, 1, 46, 62, 46, 62] + sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + + sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') + attn = sim.softmax(dim=-1) + + return attn + +# Aggregate output is dim-dimensional, same as the input. No FFN is used. +class Aggregate(nn.Module): + def __init__( + self, + args, + dim, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) + + self.gamma = nn.Parameter(torch.zeros(1)) + + if dim != inner_dim: + self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) + else: + self.project = None + + def forward(self, attn, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + v = self.to_v(fmap) + v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) + + # project is None for GMA. + if self.project is not None: + out = self.project(out) + + out = fmap + self.gamma * out + + return out + + +if __name__ == "__main__": + att = Attention(dim=128, heads=1) + fmap = torch.randn(2, 128, 40, 90) + out = att(fmap) + + print(out.shape) diff --git a/ptlflow/models/craft/setrans.py b/ptlflow/models/craft/setrans.py new file mode 100644 index 0000000..d3b7fd9 --- /dev/null +++ b/ptlflow/models/craft/setrans.py @@ -0,0 +1,796 @@ +import os +import math +import copy + +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +from .setrans_ablation import RandPosEmbedder, SinuPosEmbedder, ZeroEmbedder, MultiHeadFeatTrans +from .utils import print0 +torch.set_printoptions(sci_mode=False) + +bb2_stage_dims = { 'raft-small': [32, 32, 64, 96, 128], + 'raft-basic': [64, 64, 96, 128, 256], + 'resnet34': [64, 64, 128, 256, 512], + 'resnet50': [64, 256, 512, 1024, 2048], + 'resnet101': [64, 256, 512, 1024, 2048], + 'resibn101': [64, 256, 512, 1024, 2048], # resibn: resnet + IBN layers + 'eff-b0': [16, 24, 40, 112, 1280], # input: 224 + 'eff-b1': [16, 24, 40, 112, 1280], # input: 240 + 'eff-b2': [16, 24, 48, 120, 1408], # input: 260 + 'eff-b3': [24, 32, 48, 136, 1536], # input: 300 + 'eff-b4': [24, 32, 56, 160, 1792], # input: 380 + 'i3d': [64, 192, 480, 832, 1024] # input: 224 + } + +# Can also be implemented using torch.meshgrid(). +def gen_all_indices(shape, device): + indices = torch.arange(shape.numel(), device=device).view(shape) + + out = [] + for dim_size in reversed(shape): + out.append(indices % dim_size) + indices = torch.div(indices, dim_size, rounding_mode='trunc') + return torch.stack(tuple(reversed(out)), len(shape)) + +# drop_path and DropPath are copied from timm/models/layers/drop.py +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class SETransConfig(object): + def __init__(self): + self.feat_dim = -1 + self.in_feat_dim = -1 + # self.backbone_type = 'eff-b4' # resnet50, resnet101, resibn101, eff-b1~b4 + # self.bb_stage_idx = 4 # Last stage of the five stages. Index starts from 0. + # self.set_backbone_type(self.backbone_type) + # self.use_pretrained = True + + # Positional encoding settings. + self.pos_dim = 2 + self.pos_code_weight = 1 + + # Architecture settings + # Number of modes in the expansion attention block. + # When doing ablation study of multi-head, num_modes means num_heads, + # to avoid introducing extra config parameters. + self.num_modes = 4 + self.tie_qk_scheme = 'shared' # shared, loose, or none. + self.trans_output_type = 'private' # shared or private. + self.act_fun = F.gelu + + self.attn_clip = 100 + self.attn_diag_cycles = 1000 + self.base_initializer_range = 0.02 + + self.qk_have_bias = False + # Without the bias term, V projection often performs better. + self.v_has_bias = False + # Add an identity matrix (*0.02*query_idbias_scale) to query/key weights + # to make a bias towards identity mapping. + # Set to 0 to disable the identity bias. + self.query_idbias_scale = 10 + self.feattrans_lin1_idbias_scale = 10 + + # Pooling settings + self.pool_modes_feat = 'softmax' # softmax, max, mean, or none. + + # Randomness settings + self.hidden_dropout_prob = 0.1 + self.attention_probs_dropout_prob = 0.2 + self.drop_path_prob = 0 # Drop path reduces performance greatly. + self.pos_code_type = 'bias' + self.ablate_multihead = False + self.out_attn_probs_only = False + # When out_attn_scores_only, dropout is not applied to attention scores. + self.out_attn_scores_only = False + self.attn_mask_radius = -1 + + def set_backbone_type(self, args): + if self.try_assign(args, 'backbone_type'): + self.bb_stage_dims = bb2_stage_dims[self.backbone_type] + self.in_feat_dim = self.bb_stage_dims[-1] + + # return True if any parameter is successfully set, and False if none is set. + def try_assign(self, args, *keys): + is_successful = False + + for key in keys: + if key in args: + if isinstance(args, dict): + self.__dict__[key] = args[key] + else: + self.__dict__[key] = args.__dict__[key] + is_successful = True + + return is_successful + + def update_config(self, args): + self.set_backbone_type(args) + self.try_assign(args, 'use_pretrained', 'apply_attn_stage', 'num_modes', + 'trans_output_type', 'base_initializer_range', + 'pos_code_type', 'ablate_multihead', 'attn_clip', 'attn_diag_cycles', + 'tie_qk_scheme', 'feattrans_lin1_idbias_scale', 'qk_have_bias', 'v_has_bias', + # out_attn_probs_only/out_attn_scores_only are only True for the optical flow correlation block. + 'out_attn_probs_only', 'out_attn_scores_only', + 'in_feat_dim', 'pos_bias_radius') + + if self.try_assign(args, 'out_feat_dim'): + self.feat_dim = self.out_feat_dim + else: + self.feat_dim = self.in_feat_dim + + if 'dropout_prob' in args and args.dropout_prob >= 0: + self.hidden_dropout_prob = args.dropout_prob + self.attention_probs_dropout_prob = args.dropout_prob + print0("Dropout prob: %.2f" %(args.dropout_prob)) + +CONFIG = SETransConfig() + + +# =================================== SETrans Initialization ====================================# +class SETransInitWeights(nn.Module): + """ An abstract class to handle weights initialization """ + def __init__(self, config, *inputs, **kwargs): + super(SETransInitWeights, self).__init__() + self.config = config + + def init_weights(self, module): + """ Initialize the weights. + type(module.weight) # + type(module.weight.data) # + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + base_initializer_range = self.config.base_initializer_range + module.weight.data.normal_(mean=0.0, std=base_initializer_range) + # Slightly different from the TF version which uses truncated_normal + # for initialization cf https://github.com/pytorch/pytorch/pull/5617 + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def tie_qk(module): + if isinstance(module, CrossAttFeatTrans) \ + and module.tie_qk_scheme != 'none' and module.tie_qk_scheme != None: + module.tie_qk() + +def add_identity_bias(module): + if isinstance(module, CrossAttFeatTrans) or isinstance(module, ExpandedFeatTrans): + module.add_identity_bias() + +#====================================== SETrans Shared Modules ========================================# + +class MMSharedMid(nn.Module): + def __init__(self, config): + super(MMSharedMid, self).__init__() + self.num_modes = config.num_modes + self.feat_dim = config.feat_dim + self.shared_linear = nn.Linear(self.feat_dim, self.feat_dim) + self.mid_act_fn = config.act_fun + # This dropout is not presented in huggingface transformers. + # Added to conform with lucidrains and rwightman's implementations. + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # x: [B0, 1792*4, U] + def forward(self, x): + # shape_4d: [B0, 4, 1792, U]. + shape_4d = ( x.shape[0], self.num_modes, self.feat_dim, x.shape[2] ) + # x_4d: [B0, 4, U, 1792]. + x_4d = x.view(shape_4d).permute([0, 1, 3, 2]) + + x_trans = self.shared_linear(x_4d) + x_act = self.mid_act_fn(x_trans) + x_drop = self.dropout(x_act) + + # restore the original shape + x_drop = x_drop.permute([0, 1, 3, 2]).reshape(x.shape) + + return x_drop + +# MMPrivateOutput/MMSharedOutput <- MMandedFeatTrans <- CrossAttFeatTrans +# MM***Output has a shortcut (residual) connection. +class MMPrivateOutput(nn.Module): + def __init__(self, config): + super(MMPrivateOutput, self).__init__() + self.num_modes = config.num_modes + self.feat_dim = config.feat_dim + feat_dim_allmode = self.feat_dim * self.num_modes + # Each group (mode) is applied a linear transformation, respectively. + self.group_linear = nn.Conv1d(feat_dim_allmode, feat_dim_allmode, 1, groups=self.num_modes) + self.resout_norm_layer = nn.LayerNorm(self.feat_dim, eps=1e-12, elementwise_affine=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # x, shortcut: [B0, 1792*4, U] + def forward(self, x, shortcut): + x = self.group_linear(x) + # x_comb: [B0, 1792*4, U]. Residual connection. + x_comb = x + shortcut + shape_4d = ( x.shape[0], self.num_modes, self.feat_dim, x.shape[2] ) + # x_comb_4d, x_drop_4d: [B0, 4, U, 1792]. + x_comb_4d = x.view(shape_4d).permute([0, 1, 3, 2]) + x_drop_4d = self.dropout(x_comb_4d) + x_normed = self.resout_norm_layer(x_drop_4d) + return x_normed + +# MMPrivateOutput/MMSharedOutput <- MMandedFeatTrans <- CrossAttFeatTrans +# MM***Output has a shortcut (residual) connection. +class MMSharedOutput(nn.Module): + # feat_dim_allmode is not used. Just to keep the ctor arguments the same as MMPrivateOutput. + def __init__(self, config): + super(MMSharedOutput, self).__init__() + self.num_modes = config.num_modes + self.feat_dim = config.feat_dim + self.shared_linear = nn.Linear(self.feat_dim, self.feat_dim) + self.resout_norm_layer = nn.LayerNorm(self.feat_dim, eps=1e-12, elementwise_affine=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # x, shortcut: [B0, 1792*4, U] or [B0, 4, U, 1792] + def forward(self, x, shortcut): + # shape_4d: [B0, 4, 1792, U]. + shape_4d = ( x.shape[0], self.num_modes, self.feat_dim, x.shape[2] ) + if len(x.shape) == 3: + x_4d = x.view(shape_4d).permute([0, 1, 3, 2]) + else: + x_4d = x + if len(shortcut.shape) == 3: + shortcut_4d = shortcut.view(shape_4d).permute([0, 1, 3, 2]) + else: + shortcut_4d = shortcut + + # x_4d, shortcut_4d: [B0, 4, U, 1792]. + x_trans = self.shared_linear(x_4d) + # x_4d, x_comb: [B0, 4, U, 1792]. Residual connection. + x_comb = x_trans + shortcut_4d + x_drop = self.dropout(x_comb) + x_normed = self.resout_norm_layer(x_drop) + return x_normed + +# group_dim: the tensor dimension that corresponds to the multiple groups. +class LearnedSoftAggregate(nn.Module): + def __init__(self, num_feat, group_dim, keepdim=False): + super(LearnedSoftAggregate, self).__init__() + self.group_dim = group_dim + # num_feat = 1: element-wise score function & softmax. + # num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor. + self.num_feat = num_feat + self.feat2score = nn.Linear(num_feat, 1) + self.keepdim = keepdim + + def forward(self, x, score_basis=None): + # Assume the last dim of x is the feature dim. + if score_basis is None: + score_basis = x + + if self.num_feat == 1: + mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1) + else: + mode_scores = self.feat2score(score_basis) + attn_probs = mode_scores.softmax(dim=self.group_dim) + x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim) + return x_aggr + +# ExpandedFeatTrans <- CrossAttFeatTrans. +# ExpandedFeatTrans has a residual connection. +class ExpandedFeatTrans(nn.Module): + def __init__(self, config, name): + super(ExpandedFeatTrans, self).__init__() + self.config = config + self.name = name + self.in_feat_dim = config.in_feat_dim + self.feat_dim = config.feat_dim + self.num_modes = config.num_modes + self.feat_dim_allmode = self.feat_dim * self.num_modes + # first_linear is the value projection in other transformer implementations. + # The output of first_linear will be divided into num_modes groups. + # first_linear is always 'private' for each group, i.e., + # parameters are not shared (parameter sharing makes no sense). + self.first_linear = nn.Linear(self.in_feat_dim, self.feat_dim_allmode, bias=config.v_has_bias) + + self.base_initializer_range = config.base_initializer_range + self.has_FFN = getattr(config, 'has_FFN', True) + self.has_input_skip = getattr(config, 'has_input_skip', False) + self.drop_path = DropPath(config.drop_path_prob) if config.drop_path_prob > 0. else nn.Identity() + + print0("{}: v_has_bias: {}, has_FFN: {}, has_input_skip: {}".format( + self.name, config.v_has_bias, self.has_FFN, self.has_input_skip)) + + self.pool_modes_keepdim = False + self.pool_modes_feat = config.pool_modes_feat + + if self.pool_modes_feat == 'softmax': + agg_basis_feat_dim = self.feat_dim + + # group_dim = 1, i.e., features will be aggregated across the modes. + self.feat_softaggr = LearnedSoftAggregate(agg_basis_feat_dim, group_dim=1, + keepdim=self.pool_modes_keepdim) + + if self.has_FFN: + self.intermediate = MMSharedMid(self.config) + + if config.trans_output_type == 'shared': + self.output = MMSharedOutput(config) + elif config.trans_output_type == 'private': + self.output = MMPrivateOutput(config) + + # Have to ensure U1 == U2. + if self.has_input_skip: + self.input_skip_coeff = Parameter(torch.ones(1)) + self.skip_layer_norm = nn.LayerNorm(self.feat_dim, eps=1e-12, elementwise_affine=False) + + def add_identity_bias(self): + if self.config.feattrans_lin1_idbias_scale > 0: + # first_linear dimension is num_modes * feat_dim. + # If in_feat_dim == feat_dim, only add identity bias to the first mode. + # If in_feat_dim > feat_dim, expand to more modes until all in_feat_dim dimensions are covered. + identity_weight = torch.diag(torch.ones(self.feat_dim)) * self.base_initializer_range \ + * self.config.feattrans_lin1_idbias_scale + # Only bias the weight of the first mode. + # The total initial "weight mass" in each row is reduced by 1792*0.02*0.5. + self.first_linear.weight.data[:self.feat_dim, :self.feat_dim] = \ + self.first_linear.weight.data[:self.feat_dim, :self.feat_dim] * 0.5 + identity_weight + + # input_feat is usually key_feat. + # input_feat: [3, 4416, 128]. attention_probs: [3, 4, 4416, 4416]. + def forward(self, input_feat, attention_probs): + # input_feat: [B, U2, 1792], mm_first_feat: [B, Units, 1792*4] + # B: batch size, U2: number of the 2nd group of units, + # IF: in_feat_dim, could be different from feat_dim, due to layer compression + # (different from squeezed attention). + B, U2, IF = input_feat.shape + U1 = attention_probs.shape[2] + F = self.feat_dim + M = self.num_modes + mm_first_feat = self.first_linear(input_feat) + # mm_first_feat after transpose: [B, 1792*4, U2] + mm_first_feat = mm_first_feat.transpose(1, 2) + + # mm_first_feat_4d: [B, 4, U2, 1792] + mm_first_feat_4d = mm_first_feat.view(B, M, F, U2).transpose(2, 3) + + # attention_probs: [B, 4, U1, U2]. On sintel: [1, 4, 7040, 7040] + # mm_first_feat_fusion: [B, 4, U2, F]. On sintel: [1, 4, 7040, 256] + mm_first_feat_fusion = torch.matmul(attention_probs, mm_first_feat_4d) + mm_first_feat_fusion_3d = mm_first_feat_fusion.transpose(2, 3).reshape(B, M*F, U1) + mm_first_feat = mm_first_feat_fusion_3d + + if self.has_FFN: + # mm_mid_feat: [B, 1792*4, U1]. Group linear & gelu of mm_first_feat. + mm_mid_feat = self.intermediate(mm_first_feat) + # mm_last_feat: [B, 4, U1, 1792]. Group/shared linear & residual & Layernorm + mm_last_feat = self.output(mm_mid_feat, mm_first_feat) + mm_trans_feat = mm_last_feat + else: + mm_trans_feat = mm_first_feat_fusion + + if self.pool_modes_feat == 'softmax': + trans_feat = self.feat_softaggr(mm_trans_feat) + elif self.pool_modes_feat == 'max': + trans_feat = mm_trans_feat.max(dim=1)[0] + elif self.pool_modes_feat == 'mean': + trans_feat = mm_trans_feat.mean(dim=1) + elif self.pool_modes_feat == 'none': + trans_feat = mm_trans_feat + + # Have to ensure U1 == U2. + if self.has_input_skip: + trans_feat = self.input_skip_coeff * input_feat + self.drop_path(trans_feat) + trans_feat = self.skip_layer_norm(trans_feat) + + # trans_feat: [B, U1, 1792] + return trans_feat + +class CrossAttFeatTrans(SETransInitWeights): + def __init__(self, config, name): + super(CrossAttFeatTrans, self).__init__(config) + self.config = config + self.name = name + self.num_modes = config.num_modes + self.in_feat_dim = config.in_feat_dim + self.feat_dim = config.feat_dim + self.attention_mode_dim = self.in_feat_dim // self.num_modes # 448 + # att_size_allmode: 512 * modes + self.att_size_allmode = self.num_modes * self.attention_mode_dim + self.query = nn.Linear(self.in_feat_dim, self.att_size_allmode, bias=config.qk_have_bias) + self.key = nn.Linear(self.in_feat_dim, self.att_size_allmode, bias=config.qk_have_bias) + self.base_initializer_range = config.base_initializer_range + + self.out_attn_scores_only = config.out_attn_scores_only + self.out_attn_probs_only = config.out_attn_probs_only + self.ablate_multihead = config.ablate_multihead + + # out_attn_scores_only / out_attn_probs_only implies no FFN nor V projection. + if self.out_attn_scores_only or self.out_attn_probs_only: + self.out_trans = None + if self.num_modes > 1: + # Each attention value is a scalar. So num_feat = 1. + self.attn_softaggr = LearnedSoftAggregate(1, group_dim=1, keepdim=True) + + elif self.ablate_multihead: + self.out_trans = MultiHeadFeatTrans(config, name + "-out_trans") + else: + self.out_trans = ExpandedFeatTrans(config, name + "-out_trans") + + self.att_dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.tie_qk_scheme = config.tie_qk_scheme + print0("{}: in_feat_dim: {}, feat_dim: {}, modes: {}, qk_have_bias: {}".format( + self.name, self.in_feat_dim, self.feat_dim, self.num_modes, config.qk_have_bias)) + + # if using SlidingPosBiases2D, then add positional embeddings in CrossAttFeatTrans.forward(). + if config.pos_code_type == 'bias': + self.pos_code_weight = config.pos_code_weight + print0("Positional biases weight: {:.3}".format(self.pos_code_weight)) + else: + self.pos_code_weight = 1 + + self.attn_clip = config.attn_clip + if 'attn_diag_cycles' in config.__dict__: + self.attn_diag_cycles = config.attn_diag_cycles + else: + self.attn_diag_cycles = 1000 + self.max_attn = 0 + self.clamp_count = 0 + self.call_count = 0 + self.apply(self.init_weights) + self.apply(tie_qk) + # tie_qk() has to be executed after weight initialization. + self.apply(add_identity_bias) + + # if tie_qk_scheme is not None, it overrides the initialized self.tie_qk_scheme + def tie_qk(self, tie_qk_scheme=None): + # override config.tie_qk_scheme + if tie_qk_scheme is not None: + self.tie_qk_scheme = tie_qk_scheme + + if self.tie_qk_scheme == 'shared': + self.key.weight = self.query.weight + if self.key.bias is not None: + self.key.bias = self.query.bias + + elif self.tie_qk_scheme == 'loose': + self.key.weight.data.copy_(self.query.weight) + if self.key.bias is not None: + self.key.bias.data.copy_(self.query.bias) + + def add_identity_bias(self): + identity_weight = torch.diag(torch.ones(self.attention_mode_dim)) * self.base_initializer_range \ + * self.config.query_idbias_scale + repeat_count = self.in_feat_dim // self.attention_mode_dim + identity_weight = identity_weight.repeat([1, repeat_count]) + # only bias the weight of the first mode + # The total initial "weight mass" in each row is reduced by 1792*0.02*0.5. + self.key.weight.data[:self.attention_mode_dim] = \ + self.key.weight.data[:self.attention_mode_dim] * 0.5 + identity_weight + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_modes, -1) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + # pos_biases: [1, 1, U1, U2]. + def forward(self, query_feat, key_feat=None, pos_biases=None, attention_mask=None): + # query_feat: [B, U1, 1792] + # if key_feat == None: self attention. + if key_feat is None: + key_feat = query_feat + # mixed_query_layer, mixed_key_layer: [B, U1, 1792], [B, U2, 1792] + mixed_query_layer = self.query(query_feat) + mixed_key_layer = self.key(key_feat) + # query_layer, key_layer: [B, 4, U1, 448], [B, 4, U2, 448] + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [B0, 4, U1, 448] [B0, 4, 448, U2] + attention_scores = attention_scores / math.sqrt(self.attention_mode_dim) # [B0, 4, U1, U2] + + #if self.call_count == 0: + # print0(f"{self.name} query: {list(query_feat.shape)}, attn: {list(attention_scores.shape)}") + + with torch.no_grad(): + curr_max_attn = attention_scores.max().item() + curr_avg_attn = attention_scores.abs().mean().item() + + if curr_max_attn > self.max_attn: + self.max_attn = curr_max_attn + + if curr_max_attn > self.attn_clip: + attention_scores = torch.clamp(attention_scores, -self.attn_clip, self.attn_clip) + self.clamp_count += 1 + + self.call_count += 1 + if self.training: + if self.call_count % self.attn_diag_cycles == 0: + print0("max-attn: {:.2f}, avg-attn: {:.2f}, clamp-count: {}".format(self.max_attn, curr_avg_attn, self.clamp_count)) + self.max_attn = 0 + self.clamp_count = 0 + + if pos_biases is not None: + #[B0, 8, U1, U2] = [B0, 8, U1, U2] + [1, 1, U1, U2]. + attention_scores = attention_scores + self.pos_code_weight * pos_biases + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # When out_attn_scores_only, dropout is not applied to attention scores. + if self.out_attn_scores_only: + if self.num_modes > 1: + # [3, num_modes=4, 4500, 4500] => [3, 1, 4500, 4500] + attention_scores = self.attn_softaggr(attention_scores) + # attention_scores = self.att_dropout(attention_scores) + return attention_scores + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + # self.attention_probs = attention_probs + + # lucidrains doesn't have this dropout but rwightman has. Will keep it. + attention_probs = self.att_dropout(attention_probs) #[B0, 4, U1, U2] + + if self.out_attn_probs_only: + # [6, 4, 4500, 4500] + return attention_probs + + else: + # out_feat: [B0, U1, 1792], in the same size as query_feat. + out_feat = self.out_trans(key_feat, attention_probs) + return out_feat + +class SelfAttVisPosTrans(nn.Module): + def __init__(self, config, name): + nn.Module.__init__(self) + self.config = copy.copy(config) + self.name = name + self.out_attn_only = config.out_attn_scores_only or config.out_attn_probs_only + self.attn_mask_radius = config.attn_mask_radius + self.setrans = CrossAttFeatTrans(self.config, name) + self.vispos_encoder = SETransInputFeatEncoder(self.config) + + def forward(self, x): + coords = gen_all_indices(x.shape[2:], device=x.device) + if self.attn_mask_radius > 0: + coords2 = coords.reshape(-1, 2) + coords_diff = coords2.unsqueeze(0) - coords2.unsqueeze(1) + attn_mask = (coords_diff.abs().max(dim=2)[0] > self.attn_mask_radius).float() + attn_mask = (attn_mask * -1e9).unsqueeze(0).unsqueeze(0) + else: + attn_mask = None + + coords = coords.unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + + x_vispos, pos_biases = self.vispos_encoder(x, coords, return_pos_biases=True) + + + # if out_attn_scores_only/out_attn_probs_only, + # then x_trans is an attention matrix in the shape of (query unit number, key unit number) + # otherwise, output features are in the same shape as the query features. + # key features are recombined to get new query features by matmul(attention_probs, V(key features)) + # frame1 frame2 + # x_vispos, x_trans: [4, 2852, 256] + # Here key_feat is omitted (None), i.e., key_feat = query_feat = x_vispos. + x_trans = self.setrans(x_vispos, pos_biases=pos_biases, attention_mask=attn_mask) + + # Save f2 attention for visualization + if self.name == 'F2 transformer' and 'SAVEF2' in os.environ: + # save the attention scores + f2_attention_probs = self.setrans.attention_probs.detach().cpu() + # [B0, 4, U1, U2] => [B0, U1, U2] + f2_attention_probs = f2_attention_probs.mean(dim=1, keepdim=False) + f2_savepath = os.environ['SAVEF2'] + batch, C, h1, w1 = x.shape + f2attn = f2_attention_probs.reshape(batch, h1, w1, h1, w1) + torch.save(f2attn, f2_savepath) + print0(f"F2 attention tensor saved to {f2_savepath}") + + # reshape x_trans to the input shape. + if not self.out_attn_only: + x_trans_shape = x_trans.shape + x_trans = x_trans.permute(0, 2, 1).reshape(x.shape) + + return x_trans + +# =================================== SETrans BackBone Components ==============================# + +class LearnedSinuPosEmbedder(nn.Module): + def __init__(self, pos_dim, pos_embed_dim, omega=1, affine=True): + super().__init__() + self.pos_dim = pos_dim + self.pos_embed_dim = pos_embed_dim + self.pos_fc = nn.Linear(self.pos_dim, self.pos_embed_dim, bias=True) + self.pos_mix_norm_layer = nn.LayerNorm(self.pos_embed_dim, eps=1e-12, elementwise_affine=affine) + self.omega = omega + print0("Learnable Sinusoidal positional encoding") + + def forward(self, pos_normed): + pos_embed_sum = 0 + pos_embed0 = self.pos_fc(pos_normed) + pos_embed_sin = torch.sin(self.omega * pos_embed0[:, :, 0::2]) + pos_embed_cos = torch.cos(self.omega * pos_embed0[:, :, 1::2]) + # Interlace pos_embed_sin and pos_embed_cos. + pos_embed_mix = torch.stack((pos_embed_sin, pos_embed_cos), dim=3).view(pos_embed0.shape) + pos_embed_out = self.pos_mix_norm_layer(pos_embed_mix) + + return pos_embed_out + +class SlidingPosBiases2D(nn.Module): + def __init__(self, pos_dim=2, pos_bias_radius=7, max_pos_size=(200, 200)): + super().__init__() + self.pos_dim = pos_dim + self.R = R = pos_bias_radius + # biases: [15, 15] + pos_bias_shape = [ pos_bias_radius * 2 + 1 for i in range(pos_dim) ] + self.biases = Parameter(torch.zeros(pos_bias_shape)) + # Currently only feature maps with a 2D spatial shape (i.e., 2D images) are supported. + if self.pos_dim == 2: + all_h1s, all_w1s, all_h2s, all_w2s = [], [], [], [] + for i in range(max_pos_size[0]): + i_h1s, i_w1s, i_h2s, i_w2s = [], [], [], [] + for j in range(max_pos_size[1]): + h1s, w1s, h2s, w2s = torch.meshgrid(torch.tensor(i), torch.tensor(j), + torch.arange(i, i+2*R+1), torch.arange(j, j+2*R+1)) + i_h1s.append(h1s) + i_w1s.append(w1s) + i_h2s.append(h2s) + i_w2s.append(w2s) + + i_h1s = torch.cat(i_h1s, dim=1) + i_w1s = torch.cat(i_w1s, dim=1) + i_h2s = torch.cat(i_h2s, dim=1) + i_w2s = torch.cat(i_w2s, dim=1) + all_h1s.append(i_h1s) + all_w1s.append(i_w1s) + all_h2s.append(i_h2s) + all_w2s.append(i_w2s) + + all_h1s = torch.cat(all_h1s, dim=0) + all_w1s = torch.cat(all_w1s, dim=0) + all_h2s = torch.cat(all_h2s, dim=0) + all_w2s = torch.cat(all_w2s, dim=0) + else: + breakpoint() + + # Put indices on GPU to speed up. + # But if without persistent=False, they will be saved to checkpoints, + # making the checkpoints unnecessarily huge. + self.register_buffer('all_h1s', all_h1s, persistent=False) + self.register_buffer('all_w1s', all_w1s, persistent=False) + self.register_buffer('all_h2s', all_h2s, persistent=False) + self.register_buffer('all_w2s', all_w2s, persistent=False) + print0(f"Sliding-window Positional Biases, r: {R}, max size: {max_pos_size}") + + def forward(self, feat_shape, device): + R = self.R + spatial_shape = feat_shape[-self.pos_dim:] + # [H, W, H, W] => [H+2R, W+2R, H+2R, W+2R]. + padded_pos_shape = list(spatial_shape) + [ 2*R + spatial_shape[i] for i in range(self.pos_dim) ] + padded_pos_biases = torch.zeros(padded_pos_shape, device=device) + + if self.pos_dim == 2: + H, W = spatial_shape + all_h1s = self.all_h1s[:H, :W] + all_w1s = self.all_w1s[:H, :W] + all_h2s = self.all_h2s[:H, :W] + all_w2s = self.all_w2s[:H, :W] + padded_pos_biases[(all_h1s, all_w1s, all_h2s, all_w2s)] = self.biases + + # Remove padding. [H+2R, W+2R, H+2R, W+2R] => [H, W, H, W]. + pos_biases = padded_pos_biases[:, :, R:-R, R:-R] + + return pos_biases + +class SETransInputFeatEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.feat_dim = config.in_feat_dim # 256 + self.pos_embed_dim = self.feat_dim + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.comb_norm_layer = nn.LayerNorm(self.feat_dim, eps=1e-12, elementwise_affine=False) + self.pos_code_type = config.pos_code_type + + # if using SlidingPosBiases2D, do not add positional embeddings here. + if config.pos_code_type != 'bias': + self.pos_code_weight = config.pos_code_weight + print0("Positional embedding weight: {:.3}".format(self.pos_code_weight)) + else: + self.pos_code_weight = 0 + + # Box position encoding. no affine, but could have bias. + # 2 channels => 1792 channels + if config.pos_code_type == 'lsinu': + self.pos_coder = LearnedSinuPosEmbedder(config.pos_dim, self.pos_embed_dim, omega=1, affine=False) + elif config.pos_code_type == 'rand': + self.pos_coder = RandPosEmbedder(config.pos_dim, self.pos_embed_dim, shape=(36, 36), affine=False) + elif config.pos_code_type == 'sinu': + self.pos_coder = SinuPosEmbedder(config.pos_dim, self.pos_embed_dim, shape=(36, 36), affine=False) + elif config.pos_code_type == 'zero': + self.pos_coder = ZeroEmbedder(self.pos_embed_dim) + elif config.pos_code_type == 'bias': + self.pos_coder = SlidingPosBiases2D(config.pos_dim, config.pos_bias_radius) + + self.cached_pos_code = None + self.cached_feat_shape = None + + # Cache the pos_code and feat_shape to avoid unnecessary generation time. + # This is only used during inference. During training, pos_code is always generated each time it's used. + # Otherwise the cached pos_code cannot receive proper gradients. + def pos_code_lookup_cache(self, vis_feat_shape, device, voxels_pos_normed): + if self.pos_code_type == 'bias': + # Cache miss for 'bias' type of positional codes. + if self.training or self.cached_pos_code is None or self.cached_feat_shape != vis_feat_shape: + self.cached_pos_code = self.pos_coder(vis_feat_shape, device) + self.cached_feat_shape = vis_feat_shape \ + # else: self.cached_pos_code exists, and self.cached_feat_shape == vis_feat_shape. + # Just return the cached pos_code. + else: + # Cache miss for all other type of positional codes. + if self.training or self.cached_pos_code is None or self.cached_feat_shape != voxels_pos_normed.shape: + self.cached_pos_code = self.pos_coder(voxels_pos_normed) + self.cached_feat_shape = voxels_pos_normed.shape + # else: self.cached_pos_code exists, and self.cached_feat_shape == voxels_pos_normed.shape. + # Just return the cached pos_code. + return self.cached_pos_code + + # return: [B0, num_voxels, 256] + def forward(self, vis_feat, voxels_pos, return_pos_biases=True): + # vis_feat: [8, 256, 46, 62] + batch, dim, ht, wd = vis_feat.shape + + if self.pos_code_type != 'bias': + # voxels_pos: [8, 46, 62, 2] + voxels_pos_normed = voxels_pos / voxels_pos.max() + # voxels_pos_normed: [B0, num_voxels, 2] + # pos_embed: [B0, num_voxels, 256] + voxels_pos_normed = voxels_pos_normed.view(batch, ht * wd, -1) + pos_embed = self.pos_code_lookup_cache(vis_feat.shape, vis_feat.device, voxels_pos_normed) + pos_biases = None + else: + pos_embed = 0 + # SlidingPosBiases2D() may be a bit slow. So only generate when necessary. + if return_pos_biases: + # pos_biases: [1, 1, H, W, H, W] + pos_biases = self.pos_code_lookup_cache(vis_feat.shape, vis_feat.device, None) + # pos_biases: [1, 1, H*W, H*W] + pos_biases = pos_biases.reshape(1, 1, ht*wd, ht*wd) + else: + # Simply discard pos_biases. Used when encoding the 2nd frame. + # As for cross-frame attention, only one group of pos_biases is required + # (added to the cross-frame attentio scores). + # When encoding the 1st frame, pos_biases is already returned, no need + # another group of pos_biases. + pass + + vis_feat = vis_feat.view(batch, dim, ht * wd).transpose(1, 2) + + feat_comb = vis_feat + self.pos_code_weight * pos_embed + feat_normed = self.comb_norm_layer(feat_comb) + feat_normed = self.dropout(feat_normed) + + if return_pos_biases: + return feat_normed, pos_biases + else: + return feat_normed diff --git a/ptlflow/models/craft/setrans_ablation.py b/ptlflow/models/craft/setrans_ablation.py new file mode 100644 index 0000000..2f415e6 --- /dev/null +++ b/ptlflow/models/craft/setrans_ablation.py @@ -0,0 +1,249 @@ +import math + +import torch +import torch.nn as nn + +def positionalencoding2d(pos_embed_dim, height, width): + """ + :param pos_embed_dim: dimension of the model embeddings + :param height: height of the positions + :param width: width of the positions + :return: height * width * pos_embed_dim matrix + """ + if pos_embed_dim % 4 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dimension (got dim={:d})".format(pos_embed_dim)) + pe = torch.zeros(pos_embed_dim, height, width) + # Each dimension use half of pos_embed_dim + pos_embed_dim = int(pos_embed_dim / 2) + div_term = torch.exp(torch.arange(0., pos_embed_dim, 2) * + -(math.log(10000.0) / pos_embed_dim)) + pos_w = torch.arange(0., width).unsqueeze(1) + pos_h = torch.arange(0., height).unsqueeze(1) + pe[0:pos_embed_dim:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + pe[1:pos_embed_dim:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + pe[pos_embed_dim::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + pe[pos_embed_dim + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + pe = pe.permute(1, 2, 0) + return pe + +class RandPosEmbedder(nn.Module): + def __init__(self, pos_dim, pos_embed_dim, shape, affine): + super().__init__() + self.pos_dim = pos_dim + self.pos_embed_dim = pos_embed_dim + height, width = shape + self.pos_embed = nn.Embedding(height * width, pos_embed_dim) + self.pos_embed_norm_layer = nn.LayerNorm(self.pos_embed_dim, eps=1e-12, elementwise_affine=affine) + print("Random discrete embedder for positional encoding ablation created") + + def forward(self, pos_normed): + B, N, D = pos_normed.shape + pos_embed_1 = self.pos_embed.weight + pos_embed_out_1 = self.pos_embed_norm_layer(pos_embed_1) + pos_embed_out = pos_embed_out_1.unsqueeze(0).repeat((B, 1, 1)) + return pos_embed_out + +class SinuPosEmbedder(nn.Module): + def __init__(self, pos_dim, pos_embed_dim, shape, affine): + super().__init__() + self.pos_dim = pos_dim + self.pos_embed_dim = pos_embed_dim + self.pos_embed = positionalencoding2d(pos_embed_dim, shape[0], shape[1]) + self.pos_embed = self.pos_embed.cuda().reshape((shape[0] * shape[1], pos_embed_dim)) + print("Sinu embedder for positional encoding ablation created") + + def forward(self, pos_normed): + B, N, D = pos_normed.shape + pos_embed_out = self.pos_embed.unsqueeze(0).repeat((B, 1, 1)) + return pos_embed_out + +class ZeroEmbedder(nn.Module): + def __init__(self, pos_embed_dim): + super().__init__() + self.pos_embed_dim = pos_embed_dim + print("Zero embedder for positional encoding ablation created") + + def forward(self, pos_normed): + B, N, D = pos_normed.shape + zero_pos_embed = torch.zeros(B, N, self.pos_embed_dim, requires_grad=False).cuda() + return zero_pos_embed + +# MM*Mid, MM*Output are the same as in segtran. Just for use by ExpandedFeatTrans. +class MMPrivateMid(nn.Module): + def __init__(self, config): + super(MMPrivateMid, self).__init__() + # Use 1x1 convolution as a group linear layer. + # Equivalent to each group going through a respective nn.Linear(). + self.num_modes = config.num_modes + self.feat_dim = config.feat_dim + feat_dim_allmode = self.feat_dim * self.num_modes + self.group_linear = nn.Conv1d(feat_dim_allmode, feat_dim_allmode, 1, groups=self.num_modes) + self.mid_act_fn = config.act_fun + + def forward(self, x): + x_trans = self.group_linear(x) # [B0, 1024*8, 50] -> [B0, 1024*8, 50] + x_act = self.mid_act_fn(x_trans) # [B0, 1024*8, 50] + return x + +class MMSharedMid(nn.Module): + def __init__(self, config): + super(MMSharedMid, self).__init__() + self.num_modes = config.num_modes + self.feat_dim = config.feat_dim + feat_dim_allmode = self.feat_dim * self.num_modes + self.shared_linear = nn.Linear(self.feat_dim, self.feat_dim) + self.mid_act_fn = config.act_fun + + # x: [B0, 1024*8, 50] or [B0, 8, 50, 1024] + def forward(self, x): + if len(x.shape) == 3: + # shape_4d: [B0, 8, 1024, 50]. + shape_4d = ( x.shape[0], self.num_modes, self.feat_dim, x.shape[2] ) + # x_4d: [B0, 8, 50, 1024]. + x_4d = x.view(shape_4d).permute([0, 1, 3, 2]) + reshaped = True + else: + x_4d = x + reshaped = False + + x_trans = self.shared_linear(x_4d) + x_act = self.mid_act_fn(x_trans) + + if reshaped: + # restore the original shape + x_act = x_act.permute([0, 1, 3, 2]).reshape(x.shape) + + return x_act + +# MMPrivateOutput/MMSharedOutput <- ExpandedFeatTrans <- SelfAttFeatTrans <- SegtranFusionEncoder. +# MM***Output has a shortcut (residual) connection. +class MMPrivateOutput(nn.Module): + def __init__(self, config): + super(MMPrivateOutput, self).__init__() + self.num_modes = config.num_modes + self.feat_dim = config.feat_dim + feat_dim_allmode = self.feat_dim * self.num_modes + self.group_linear = nn.Conv1d(feat_dim_allmode, feat_dim_allmode, 1, groups=self.num_modes) + self.resout_norm_layer = nn.LayerNorm(self.feat_dim, eps=1e-12, elementwise_affine=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # x, shortcut: [B0, 1024*8, 50] + def forward(self, x, shortcut): + x = self.group_linear(x) + # x_comb: [B0, 1024*8, 50]. Residual connection. + x_comb = x + shortcut + shape_4d = ( x.shape[0], self.num_modes, self.feat_dim, x.shape[2] ) + # x_comb_4d, x_drop_4d: [B0, 8, 50, 1024]. + x_comb_4d = x.view(shape_4d).permute([0, 1, 3, 2]) + x_drop_4d = self.dropout(x_comb_4d) + x_normed = self.resout_norm_layer(x_drop_4d) + return x_normed + +# MMPrivateOutput/MMSharedOutput <- ExpandedFeatTrans <- SelfAttFeatTrans <- SegtranFusionEncoder. +# MM***Output has a shortcut (residual) connection. +class MMSharedOutput(nn.Module): + # feat_dim_allmode is not used. Just to keep the ctor arguments the same as MMPrivateOutput. + def __init__(self, config): + super(MMSharedOutput, self).__init__() + self.num_modes = config.num_modes + self.feat_dim = config.feat_dim + self.shared_linear = nn.Linear(self.feat_dim, self.feat_dim) + self.resout_norm_layer = nn.LayerNorm(self.feat_dim, eps=1e-12, elementwise_affine=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # x, shortcut: [B0, 1024*8, 50] or [B0, 8, 50, 1024] + def forward(self, x, shortcut): + # shape_4d: [B0, 8, 1024, 50]. + shape_4d = ( x.shape[0], self.num_modes, self.feat_dim, x.shape[2] ) + if len(x.shape) == 3: + x_4d = x.view(shape_4d).permute([0, 1, 3, 2]) + else: + x_4d = x + if len(shortcut.shape) == 3: + shortcut_4d = shortcut.view(shape_4d).permute([0, 1, 3, 2]) + else: + shortcut_4d = shortcut + + # x_4d, shortcut_4d: [B0, 8, 50, 1024]. + x_trans = self.shared_linear(x_4d) + # x_4d, x_comb: [B0, 8, 50, 1024]. Residual connection. + x_comb = x_trans + shortcut_4d + x_drop = self.dropout(x_comb) + x_normed = self.resout_norm_layer(x_drop) + return x_normed + +# MultiHeadFeatTrans <- SelfAttFeatTrans. +# MultiHeadFeatTrans has a residual connection. +# We "misuse" num_modes for num_heads, to avoid introducing extra config parameters. +class MultiHeadFeatTrans(nn.Module): + def __init__(self, config, name): + super(MultiHeadFeatTrans, self).__init__() + self.config = config + self.name = name + self.in_feat_dim = config.in_feat_dim + self.feat_dim = config.feat_dim + self.num_modes = config.num_modes + self.feat_dim_onehead = self.feat_dim // self.num_modes + self.feat_dim_allhead = self.feat_dim_onehead * self.num_modes + # first_linear is the value projection in other transformer implementations. + # The output of first_linear will be divided into num_modes groups. + # first_linear is always 'private' for each group, i.e., + # parameters are not shared (parameter sharing makes no sense). + self.first_linear = nn.Linear(self.in_feat_dim, self.feat_dim_allhead) + + print("%s: pool_modes_feat=concat, trans_output_type=%s" % \ + (self.name, config.trans_output_type)) + + # Disable multiple modes for intermediate and output layers. + config.num_modes = 1 + self.intermediate = MMSharedMid(self.config) + + if config.trans_output_type == 'shared': + self.output = MMSharedOutput(config) + elif config.trans_output_type == 'private': + self.output = MMPrivateOutput(config) + + self.apply_attn_early = config.apply_attn_early + + def add_identity_bias(self): + if self.config.feattrans_lin1_idbias_scale > 0: + identity_weight = torch.diag(torch.ones(self.feat_dim)) * self.config.initializer_range \ + * self.config.feattrans_lin1_idbias_scale + # Only bias the weight of the first mode. + # The total initial "weight mass" in each row is reduced by 1024*0.02*0.5. + self.first_linear.weight.data[:self.feat_dim] = \ + self.first_linear.weight.data[:self.feat_dim] * 0.5 + identity_weight + + def forward(self, input_feat, attention_probs, attention_scores): + # input_feat: [B0, 50, 1024], mm_first_feat: [B0, 50, 1024*8] + mm_first_feat = self.first_linear(input_feat) + # mm_first_feat_act after permute: [B0, 1024*8, 50] + mm_first_feat = mm_first_feat.permute(0, 2, 1) + + if self.apply_attn_early: + # shape_4d: [B0, 8, 1024, 50] + shape_4d = ( mm_first_feat.shape[0], self.num_modes, self.feat_dim_onehead, mm_first_feat.shape[2] ) + # mm_first_feat_4d: [B0, 8, 50, 1024] + mm_first_feat_4d = mm_first_feat.view(shape_4d).permute([0, 1, 3, 2]) + mm_first_feat_fusion = torch.matmul(attention_probs, mm_first_feat_4d) + mm_first_feat_fusion_3d = mm_first_feat_fusion.permute([0, 1, 3, 2]).reshape(mm_first_feat.shape) + mm_first_feat = mm_first_feat_fusion_3d + + # mm_mid_feat: [B0, 1024*8, 50]. Group linear & gelu of mm_first_feat. + mm_mid_feat = self.intermediate(mm_first_feat) + # mm_last_feat: [B0, 8, 50, 1024]. Group/shared linear & residual & Layernorm + mm_last_feat = self.output(mm_mid_feat, mm_first_feat) + + if (attention_probs is not None) and (not self.apply_attn_early): + # matmul(t1, t2): (h1, w1), (w1, w2) => (h1, w2) + # [B0, 8, 50, 50][B0, 8, 50, 1024] -> mm_trans_feat: [B0, 8, 50, 1024] + mm_trans_feat = torch.matmul(attention_probs, mm_last_feat) + else: + mm_trans_feat = mm_last_feat + + trans_feat = mm_trans_feat.squeeze(1) + + # trans_feat: [B0, 50, 1024], if pool_modes_feat != 'none', + # or [B0, 8, 50, 1024] if pool_modes_feat == 'none'. + return trans_feat diff --git a/ptlflow/models/craft/update.py b/ptlflow/models/craft/update.py new file mode 100644 index 0000000..0b8cdcb --- /dev/null +++ b/ptlflow/models/craft/update.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .gma import Aggregate +from .setrans import ExpandedFeatTrans + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=128+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + # When both f1 and f2 are applied SS-Trans, corr_multiplier = 2. + # Otherwise corr_multiplier = 1. + cor_planes = args.corr_levels * args.corr_multiplier * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(input_dim=hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net_feat, inp_feat, corr, flow, upsample=True): + # motion_features: (256+2)-dimensional. + motion_features = self.encoder(flow, corr) + inp_feat = torch.cat([inp_feat, motion_features], dim=1) + + net_feat = self.gru(net_feat, inp_feat) + delta_flow = self.flow_head(net_feat) + + # scale mask to balance gradients + mask = .25 * self.mask(net_feat) + return net_feat, mask, delta_flow + + +class GMAUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128): + super().__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + self.use_setrans = args.use_setrans + if self.use_setrans: + self.intra_trans_config = args.intra_trans_config + self.aggregator = ExpandedFeatTrans(self.intra_trans_config, 'Motion Aggregator') + else: + # Aggregate is attention with a (learnable-weighted) skip connection, without FFN. + self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=self.args.num_heads) + + def forward(self, net, inp, corr, flow, attention): + # encoder: BasicMotionEncoder + # corr: [3, 676, 50, 90] + motion_features = self.encoder(flow, corr) + # motion_features: 128-dim + if self.use_setrans: + # attention is multi-mode. ExpandedFeatTrans takes multi-mode attention. + B, C, H, W = motion_features.shape + motion_features_3d = motion_features.view(B, C, H*W).permute(0, 2, 1) + # motion_features_3d: [1, 7040, 128], attention: [1, 4, 7040, 7040] + motion_features_global_3d = self.aggregator(motion_features_3d, attention) + motion_features_global = motion_features_global_3d.view(B, H, W, C).permute(0, 3, 1, 2) + else: + # attention: [8, 1, 2852, 2852]. motion_features: [8, 128, 46, 62]. + motion_features_global = self.aggregator(attention, motion_features) + + inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) + + # Attentional update + net = self.gru(net, inp_cat) + + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + + diff --git a/ptlflow/models/craft/utils.py b/ptlflow/models/craft/utils.py new file mode 100644 index 0000000..5440654 --- /dev/null +++ b/ptlflow/models/craft/utils.py @@ -0,0 +1,196 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate +import os + +# Only print on GPU0. Avoid duplicate messages. +def print0(*print_args, **kwargs): + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + if local_rank == 0: + print(*print_args, **kwargs) + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel', mod=8): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // mod) + 1) * mod - self.ht) % mod + pad_wd = (((self.wd // mod) + 1) * mod - self.wd) % mod + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +# Map the flow of each pixel to new locations indicated by the flow. +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].expand(batch, -1, -1, -1) + + +def coords_grid_y_first(batch, ht, wd): + """Place y grid before x grid""" + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords, dim=0).int() + return coords[None].expand(batch, -1, -1, -1) + +# soft_argmax is not used anywhere. +def soft_argmax(corr_me, B, H1, W1): + # Implement soft argmin + coords, feats = corr_me.decomposed_coordinates_and_features + + # Computing soft argmin + flow_pred = torch.zeros(B, 2, H1, W1).to(corr_me.device) + for batch, (coord, feat) in enumerate(zip(coords, feats)): + coord_img_1 = coord[:, :2].to(corr_me.device) + coord_img_2 = coord[:, 2:].to(corr_me.device) + # relative positions (flow hypotheses) + rel_pos = (coord_img_2 - coord_img_1) + # augmented indices + aug_coord_img_1 = (coord_img_1[:, 0:1] * W1 + coord_img_1[:, 1:2]).long() + # run softmax on the score + weight = scatter_softmax(feat, aug_coord_img_1, dim=0) + rel_pos_weighted = weight * rel_pos + out = scatter_add(rel_pos_weighted, aug_coord_img_1, dim=0) + # Need to permute (y, x) to (x, y) for flow + flow_pred[batch] = out[:, [1,0]].view(H1, W1, 2).permute(2, 0, 1) + return flow_pred + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +def upflow4(flow, mode='bilinear'): + new_size = (4 * flow.shape[2], 4 * flow.shape[3]) + return 4 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +def upflow2(flow, mode='bilinear'): + new_size = (2 * flow.shape[2], 2 * flow.shape[3]) + return 2 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +def downflow8(flow, mode='bilinear'): + new_size = (flow.shape[2] // 8, flow.shape[3] // 8) + return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) / 8 + + +def downflow4(flow, mode='bilinear'): + new_size = (flow.shape[2] // 4, flow.shape[3] // 4) + return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) / 4 + + +def compute_interpolation_weights(yx_warped): + # yx_warped: [N, 2] + y_warped = yx_warped[:, 0] + x_warped = yx_warped[:, 1] + + # elementwise operations below + y_f = torch.floor(y_warped) + y_c = y_f + 1 + x_f = torch.floor(x_warped) + x_c = x_f + 1 + + w0 = (y_c - y_warped) * (x_c - x_warped) + w1 = (y_warped - y_f) * (x_c - x_warped) + w2 = (y_c - y_warped) * (x_warped - x_f) + w3 = (y_warped - y_f) * (x_warped - x_f) + + weights = [w0, w1, w2, w3] + indices = [torch.stack([y_f, x_f], dim=1), torch.stack([y_c, x_f], dim=1), + torch.stack([y_f, x_c], dim=1), torch.stack([y_c, x_c], dim=1)] + weights = torch.cat(weights, dim=1) + indices = torch.cat(indices, dim=2) + # indices = torch.cat(indices, dim=0) # [4*N, 2] + + return weights, indices + +# weights, indices = compute_interpolation_weights(xy_warped, b, h_i, w_i) + + +def compute_inverse_interpolation_img(weights, indices, img, b, h_i, w_i): + """ + weights: [b, h*w] + indices: [b, h*w] + img: [b, h*w, a, b, c, ...] + """ + w0, w1, w2, w3 = weights + ff_idx, cf_idx, fc_idx, cc_idx = indices + + k = len(img.size()) - len(w0.size()) + img_0 = w0[(...,) + (None,) * k] * img + img_1 = w1[(...,) + (None,) * k] * img + img_2 = w2[(...,) + (None,) * k] * img + img_3 = w3[(...,) + (None,) * k] * img + + img_out = torch.zeros(b, h_i * w_i, *img.shape[2:]).type_as(img) + + ff_idx = torch.clamp(ff_idx, min=0, max=h_i * w_i - 1) + cf_idx = torch.clamp(cf_idx, min=0, max=h_i * w_i - 1) + fc_idx = torch.clamp(fc_idx, min=0, max=h_i * w_i - 1) + cc_idx = torch.clamp(cc_idx, min=0, max=h_i * w_i - 1) + + img_out.scatter_add_(1, ff_idx[(...,) + (None,) * k].expand_as(img_0), img_0) + img_out.scatter_add_(1, cf_idx[(...,) + (None,) * k].expand_as(img_1), img_1) + img_out.scatter_add_(1, fc_idx[(...,) + (None,) * k].expand_as(img_2), img_2) + img_out.scatter_add_(1, cc_idx[(...,) + (None,) * k].expand_as(img_3), img_3) + + return img_out # [b, h_i*w_i, ...] diff --git a/ptlflow/models/csflow/LICENSE b/ptlflow/models/csflow/LICENSE new file mode 100644 index 0000000..5c84e85 --- /dev/null +++ b/ptlflow/models/csflow/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Hao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/ptlflow/models/csflow/README.md b/ptlflow/models/csflow/README.md new file mode 100644 index 0000000..e2b6e1a --- /dev/null +++ b/ptlflow/models/csflow/README.md @@ -0,0 +1,24 @@ +# RAFT + +## Original code + +[https://github.com/MasterHow/CSFlow](https://github.com/MasterHow/CSFlow) + +## Code license + +See [LICENSE](LICENSE). + +## Pretrained weights license + +Not specified. + +## Citation + +``` +@article{shi2022csflow, + title={CSFlow: Learning optical flow via cross strip correlation for autonomous driving}, + author={Shi, Hao and Zhou, Yifan and Yang, Kailun and Yin, Xiaoting and Wang, Kaiwei}, + journal={arXiv preprint arXiv:2202.00909}, + year={2022} +} +``` \ No newline at end of file diff --git a/ptlflow/models/csflow/__init__.py b/ptlflow/models/csflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ptlflow/models/csflow/csflow.py b/ptlflow/models/csflow/csflow.py new file mode 100644 index 0000000..60ccc55 --- /dev/null +++ b/ptlflow/models/csflow/csflow.py @@ -0,0 +1,865 @@ +from argparse import ArgumentParser, Namespace + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import DeformConv2d + +from ..base_model.base_model import BaseModel + + +class SequenceLoss(nn.Module): + def __init__(self, args): + super().__init__() + self.gamma = args.gamma + self.max_flow = args.max_flow + + def forward(self, outputs, inputs): + """ Loss function defined over sequence of flow predictions """ + + flow_preds = outputs['flow_preds'] + flow_gt = inputs['flows'][:, 0] + valid = inputs['valids'][:, 0] + + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exlude invalid pixels and extremely large diplacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < self.max_flow) + + for i in range(n_predictions): + i_weight = self.gamma**(n_predictions - i - 1) + i_loss = (flow_preds[i] - flow_gt).abs() + flow_loss += i_weight * (valid[:, None] * i_loss).mean() + + return flow_loss + + +class CSFlow(BaseModel): + pretrained_checkpoints = { + 'chairs': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/csflow-chairs-458a9436.ckpt', + 'things': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/csflow-things-ebdd403b.ckpt', + 'kitti': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/csflow-kitti-dc66357a.ckpt' + } + + def __init__(self, + args: Namespace) -> None: + super().__init__( + args=args, + loss_fn=SequenceLoss(args), + output_stride=8) + + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + + if 'dropout' not in self.args: + self.args.dropout = 0 + + # feature network, context network, and update block + self.fnet = BasicEncoder( + output_dim=256, norm_fn='instance', dropout=args.dropout) + + self.cnet = BasicEncoder( + output_dim=hdim + cdim, + norm_fn='batch', + dropout=args.dropout) + + self.strip_corr_block_v2 = StripCrossCorrMap_v2( + in_chan=256, out_chan=256) + self.update_block = BasicUpdateBlock( + self.args, hidden_dim=hdim) + + @staticmethod + def add_model_specific_args(parent_parser=None): + parent_parser = BaseModel.add_model_specific_args(parent_parser) + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--corr_levels', type=int, default=4) + parser.add_argument('--corr_radius', type=int, default=4) + parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--gamma', type=float, default=0.8) + parser.add_argument('--max_flow', type=float, default=400.0) + parser.add_argument('--iters', type=int, default=12) + parser.add_argument('--gen_fmap', action='store_true') + parser.add_argument('--skip_encode', action='store_true') + return parser + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """Flow is represented as difference between two coordinate grids. + flow = coords1 - coords0, Modified by Hao + """ + N, C, H, W = img.shape + + coords0 = coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex + combination.""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def forward(self, inputs, flow_init=None): + """ Estimate optical flow between pair of frames """ + image1 = inputs['images'][:, 0] + image2 = inputs['images'][:, 1] + """Estimate optical flow between pair of frames.""" + + if not self.args.skip_encode: + # Modified, take image pairs as input + image1 = 2 * image1 - 1.0 + image2 = 2 * image2 - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + # run the feature network + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + else: + fmap1 = image1 * 255.0 + fmap2 = image2 * 255.0 + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the context network + if not self.args.skip_encode: + cnet = self.cnet(image1) + + if not self.training: + if self.args.gen_fmap: + return fmap1, fmap2, cnet + else: + cnet = inputs['images'][:, 2] * 255.0 + + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + strip_coor_map, strip_corr_map_w, strip_corr_map_h = self.strip_corr_block_v2( + [fmap1, fmap2]) + corr_fn = CorrBlock_v2( + fmap1, fmap2, strip_coor_map, radius=self.args.corr_radius) + + if not self.args.skip_encode: + coords0, coords1 = self.initialize_flow(image1) + else: + b, c, h, w = fmap1.shape + image1 = torch.zeros(b, c, 8 * h, 8 * w).cuda() + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + + # init flow with regression before GRU iters + corr_w_act = torch.nn.functional.softmax( + strip_corr_map_w, dim=3) # B H1 W1 1 W2 + corr_h_act = torch.nn.functional.softmax( + strip_corr_map_h, dim=4) # B H1 W1 H2 1 + + flo_v = corr_w_act.mul(strip_corr_map_w) # B H1 W1 1 W2 + flo_u = corr_h_act.mul(strip_corr_map_h) # B H1 W1 H2 1 + + flow_v = torch.sum(flo_v, dim=4).squeeze(dim=3) # B H1 W1 + flow_u = torch.sum(flo_u, dim=3).squeeze(dim=3) # B H1 W1 + + corr_init = torch.stack((flow_u, flow_v), dim=1) # B 2 H1 W1 + + coords1 = coords1.detach() + coords1 = coords1 + corr_init + + # add loss + flow_up = upflow8(coords1 - coords0) + flow_predictions.append(flow_up) + + for itr in range(self.args.iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + net, up_mask, delta_flow = self.update_block( + net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if self.training: + outputs = { + 'flows': flow_up[:, None], + 'flow_preds': flow_predictions + } + else: + outputs = { + 'flows': flow_up[:, None], + 'flow_small': coords1 - coords0 + } + + return outputs + + +class StripCrossCorrMap_v2(nn.Module): + """Strip Cross Corr Augmentation Module by Hao, version2.0""" + + def __init__(self, in_chan=256, out_chan=256, *args, **kwargs): + super(StripCrossCorrMap_v2, self).__init__() + self.conv1_1 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1_2 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv2_1 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv2_2 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + fmap1, fmap2 = x + + # vertical query map + fmap1_w = self.conv1_1(fmap1) # B, 64, H, W + batchsize, c_middle, h, w = fmap1_w.size() + fmap1_w = fmap1_w.view(batchsize, c_middle, -1) + + # horizontal query map + fmap1_h = self.conv1_2(fmap1) # B, 64, H, W + batchsize, c_middle, h, w = fmap1_h.size() + fmap1_h = fmap1_h.view(batchsize, c_middle, -1) + + # vertical striping map + fmap2_w = self.conv2_1(fmap2) # B, 64, H, W + fmap2_w = F.avg_pool2d(fmap2_w, [h, 1]) + fmap2_w = fmap2_w.view(batchsize, c_middle, -1).permute(0, 2, 1) + + # horizontal striping map + fmap2_h = self.conv2_2(fmap2) # B, 64, H, W + fmap2_h = F.avg_pool2d(fmap2_h, [1, w]) + fmap2_h = fmap2_h.view(batchsize, c_middle, -1).permute(0, 2, 1) + + # cross strip corr map + strip_corr_map_w = torch.bmm(fmap2_w, fmap1_w).\ + view(batchsize, w, h, w, 1).permute(0, 2, 3, 4, 1) # B H1 W1 1 W2 + strip_corr_map_h = torch.bmm(fmap2_h, fmap1_h).\ + view(batchsize, h, h, w, 1).permute(0, 2, 3, 1, 4) # B H1 W1 H2 1 + + return (strip_corr_map_w + strip_corr_map_h).view( + batchsize, h, w, 1, h, w), strip_corr_map_w, strip_corr_map_h + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, torch.nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class ConvBNReLU(nn.Module): + """Conv with BN and ReLU, used for Strip Corr Module""" + + def __init__(self, + in_chan, + out_chan, + ks=3, + stride=1, + padding=1, + *args, + **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d( + in_chan, + out_chan, + kernel_size=ks, + stride=stride, + padding=padding, + bias=False) + self.bn = torch.nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class SmallUpdateBlock(nn.Module): + + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + + +class BasicUpdateBlock(nn.Module): + """Modified by Hao, support for CSFlow""" + + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder_v2(args) + self.gru = SepConvGRU( + hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + +def pool2x(x): + return F.avg_pool2d(x, 3, stride=2, padding=1) + + +def interp(x, dest): + interp_args = {'mode': 'bilinear', 'align_corners': True} + return F.interpolate(x, dest.shape[2:], **interp_args) + + +class BasicEncoder(nn.Module): + + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + from torch.nn.modules.utils import _pair + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, + (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock( + self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, + (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock( + self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class CorrBlock_v2: + """Corr Block, modified by Hao, concat SC with 4D corr""" + + def __init__(self, + fmap1, + fmap2, + strip_coor_map=None, + num_levels=4, + radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock_v2.corr(fmap1, fmap2) + + if strip_coor_map is not None: + # strip correlation augmentation with concat + corr = torch.cat((corr, strip_coor_map), dim=3) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """Wrapper for grid_sample, uses pixel coordinates.""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid( + torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate( + flow, size=new_size, mode=mode, align_corners=True) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), + self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), + self.norm4) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicMotionEncoder_v2(nn.Module): + """Get Motion Feature from CSFlow, by Hao""" + + def __init__(self, args): + super(BasicMotionEncoder_v2, self).__init__() + # double cor_plances due to concat aug + cor_planes = 2 * (args.corr_levels * (2 * args.corr_radius + 1)**2) + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class SmallMotionEncoder(nn.Module): + + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class SepConvGRU(nn.Module): + + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class FlowHead(nn.Module): + + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h diff --git a/ptlflow/models/flowformer/LICENSE b/ptlflow/models/flowformer/LICENSE new file mode 100644 index 0000000..f49a4e1 --- /dev/null +++ b/ptlflow/models/flowformer/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/ptlflow/models/flowformer/README.md b/ptlflow/models/flowformer/README.md new file mode 100644 index 0000000..3e5a808 --- /dev/null +++ b/ptlflow/models/flowformer/README.md @@ -0,0 +1,24 @@ +# RAFT + +## Original code + +[https://github.com/drinkingcoder/FlowFormer-Official](https://github.com/drinkingcoder/FlowFormer-Official) + +## Code license + +See [LICENSE](LICENSE). + +## Pretrained weights license + +Not specified. + +## Citation + +``` +@article{huang2022flowformer, + title={{FlowFormer}: A Transformer Architecture for Optical Flow}, + author={Huang, Zhaoyang and Shi, Xiaoyu and Zhang, Chao and Wang, Qiang and Cheung, Ka Chun and Qin, Hongwei and Dai, Jifeng and Li, Hongsheng}, + journal={{ECCV}}, + year={2022} +} +``` \ No newline at end of file diff --git a/ptlflow/models/flowformer/__init__.py b/ptlflow/models/flowformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ptlflow/models/flowformer/attention.py b/ptlflow/models/flowformer/attention.py new file mode 100644 index 0000000..7da17e9 --- /dev/null +++ b/ptlflow/models/flowformer/attention.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +from torch import einsum + +from einops import rearrange + +class BroadMultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(BroadMultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q.squeeze(), 'i (heads d) -> heads i d', heads=self.heads) + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) + + dots = einsum('hid, bhjd -> bhij', Q, K) * self.scale # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, _, _ = K.shape + _, N, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads n d -> b n (heads d)', b=B, n=N) + + return out + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) + + dots = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, HW, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + + return out + +class MultiHeadAttentionRelative(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttentionRelative, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K, Q_r, K_r): + """ + Q: [BH1W1, 1, dim] + K: [BH1W1, H3W3, dim] + Q_r: [BH1W1, H3W3, dim] + K_r: [BH1W1, H3W3, dim] + """ + + Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, 1, dim] + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + + # context-context similarity + c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads 1 H3W3] + # context-position similarity + c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3] + # position-context similarity + p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:]) * self.scale + p_c = torch.squeeze(p_c, dim=4) + p_c = p_c.permute(0, 1, 3, 2) + dots = c_c + c_p + p_c + return self.attend(dots) + + def forward(self, Q, K, V, Q_r, K_r): + attn = self.attend_with_rpe(Q, K, Q_r, K_r) + B, HW, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + + return out + +def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1/200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim//4-1, dim//4).to(x.device) + return torch.cat([torch.sin(3.14*x[..., -2:-1]*freq_bands*NORMALIZE_FACOR), torch.cos(3.14*x[..., -2:-1]*freq_bands*NORMALIZE_FACOR), torch.sin(3.14*x[..., -1:]*freq_bands*NORMALIZE_FACOR), torch.cos(3.14*x[..., -1:]*freq_bands*NORMALIZE_FACOR)], dim=-1) + +def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1/200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim//4-1, dim//4).to(x.device) + return torch.cat([torch.sin(x[..., -2:-1]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.cos(x[..., -2:-1]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.sin(x[..., -1:]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.cos(x[..., -1:]*(NORMALIZE_FACOR * 2 ** freq_bands))], dim=-1) diff --git a/ptlflow/models/flowformer/cnn.py b/ptlflow/models/flowformer/cnn.py new file mode 100644 index 0000000..2acf002 --- /dev/null +++ b/ptlflow/models/flowformer/cnn.py @@ -0,0 +1,577 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_ +import math +import numpy as np + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + mul = input_dim // 3 + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64 * mul) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64 * mul) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64 * mul) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, 64 * mul, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 * mul + self.layer1 = self._make_layer(64 * mul, stride=1) + self.layer2 = self._make_layer(96 * mul, stride=2) + self.layer3 = self._make_layer(128 * mul, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class ConvNets(nn.Module): + def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1): + super(ConvNets, self).__init__() + + self.conv_first = nn.Conv2d(in_dim, inter_dim, kernel_size=3, padding=1, stride=stride) + self.conv_last = nn.Conv2d(inter_dim, out_dim, kernel_size=3, padding=1, stride=stride) + self.relu = nn.ReLU(inplace=True) + self.inter_convs = nn.ModuleList( + [ResidualBlock(inter_dim, inter_dim, norm_fn='none', stride=1) for i in range(depth)]) + + def forward(self, x): + x = self.relu(self.conv_first(x)) + for inter_conv in self.inter_convs: + x = inter_conv(x) + x = self.conv_last(x) + return x + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.motion_feature_dim + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicFuseMotion(nn.Module): + def __init__(self, args): + super(BasicFuseMotion, self).__init__() + cor_planes = args.motion_feature_dim + out_planes = args.query_latent_dim + + self.normf1 = nn.InstanceNorm2d(128) + self.normf2 = nn.InstanceNorm2d(128) + + self.convf1 = nn.Conv2d(2, 128, 3, padding=1) + self.convf2 = nn.Conv2d(128, 128, 3, padding=1) + self.convf3 = nn.Conv2d(128, 64, 3, padding=1) + + s = 1 + self.normc1 = nn.InstanceNorm2d(256*s) + self.normc2 = nn.InstanceNorm2d(256*s) + self.normc3 = nn.InstanceNorm2d(256*s) + + self.convc1 = nn.Conv2d(cor_planes+128, 256*s, 1, padding=0) + self.convc2 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.convc3 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.convc4 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.conv = nn.Conv2d(256*s + 64, out_planes, 1, padding=0) + + def forward(self, flow, feat, context1=None): + flo = F.relu(self.normf1(self.convf1(flow))) + flo = F.relu(self.normf2(self.convf2(flo))) + flo = self.convf3(flo) + + feat = torch.cat([feat, context1], dim=1) + feat = F.relu(self.normc1(self.convc1(feat))) + feat = F.relu(self.normc2(self.convc2(feat))) + feat = F.relu(self.normc3(self.convc3(feat))) + feat = self.convc4(feat) + + feat = torch.cat([flo, feat], dim=1) + feat = F.relu(self.conv(feat)) + + return feat + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + +class DirectMeanMaskPredictor(nn.Module): + def __init__(self, args): + super(DirectMeanMaskPredictor, self).__init__() + self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256) + self.mask = nn.Sequential( + nn.Conv2d(args.predictor_dim, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, motion_features): + delta_flow = self.flow_head(motion_features) + mask = .25 * self.mask(motion_features) + + return mask, delta_flow + +class BaiscMeanPredictor(nn.Module): + def __init__(self, args, hidden_dim=128): + super(BaiscMeanPredictor, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, latent, flow): + motion_features = self.encoder(flow, latent) + delta_flow = self.flow_head(motion_features) + mask = .25 * self.mask(motion_features) + + return mask, delta_flow + +class BasicRPEEncoder(nn.Module): + def __init__(self, args): + super(BasicRPEEncoder, self).__init__() + self.args = args + dim = args.query_latent_dim + self.encoder = nn.Sequential( + nn.Linear(2, dim // 2), + nn.ReLU(inplace=True), + nn.Linear(dim // 2, dim), + nn.ReLU(inplace=True), + nn.Linear(dim, dim) + ) + + def forward(self, rpe_tokens): + return self.encoder(rpe_tokens) + +from .twins import Block, CrossBlock + +class TwinsSelfAttentionLayer(nn.Module): + def __init__(self, args): + super(TwinsSelfAttentionLayer, self).__init__() + self.args = args + embed_dim = 256 + num_heads = 8 + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = 0. + attn_drop_rate=0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True) + self.global_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward(self, x, tgt, size): + x = self.local_block(x, size) + x = self.global_block(x, size) + + tgt = self.local_block(tgt, size) + tgt = self.global_block(tgt, size) + return x, tgt + +class TwinsCrossAttentionLayer(nn.Module): + def __init__(self, args): + super(TwinsCrossAttentionLayer, self).__init__() + self.args = args + embed_dim = 256 + num_heads = 8 + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = 0. + attn_drop_rate=0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True) + self.global_block = CrossBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward(self, x, tgt, size): + x = self.local_block(x, size) + tgt = self.local_block(tgt, size) + x, tgt = self.global_block(x, tgt, size) + + return x, tgt diff --git a/ptlflow/models/flowformer/convnext.py b/ptlflow/models/flowformer/convnext.py new file mode 100644 index 0000000..7740b2e --- /dev/null +++ b/ptlflow/models/flowformer/convnext.py @@ -0,0 +1,86 @@ +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + +class ConvNextLayer(nn.Module): + def __init__(self, dim, depth=4): + super().__init__() + self.net = nn.Sequential( + *[ConvNextBlock(dim=dim) for j in range(depth)] + ) + + def forward(self, x): + return self.net(x) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + +class ConvNextBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + def __init__(self, dim, layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + # print(f"conv next layer") + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + x + return x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x \ No newline at end of file diff --git a/ptlflow/models/flowformer/decoder.py b/ptlflow/models/flowformer/decoder.py new file mode 100644 index 0000000..77235bc --- /dev/null +++ b/ptlflow/models/flowformer/decoder.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +from .utils import coords_grid, bilinear_sampler +from .attention import MultiHeadAttention, LinearPositionEmbeddingSine, ExpPositionEmbeddingSine + +from timm.models.layers import DropPath + +from .gru import BasicUpdateBlock, GMAUpdateBlock +from .gma import Attention + +def initialize_flow(img): + """ Flow is represented as difference between two means flow = mean1 - mean0""" + N, C, H, W = img.shape + mean = coords_grid(N, H, W).to(img.device) + mean_init = coords_grid(N, H, W).to(img.device) + + # optical flow computed as difference: flow = mean1 - mean0 + return mean, mean_init + +class CrossAttentionLayer(nn.Module): + # def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + def __init__(self, qk_dim, v_dim, query_token_dim, tgt_token_dim, add_flow_token=True, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0., pe='linear'): + super(CrossAttentionLayer, self).__init__() + + head_dim = qk_dim // num_heads + self.scale = head_dim ** -0.5 + self.query_token_dim = query_token_dim + self.pe = pe + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads) + self.q, self.k, self.v = nn.Linear(query_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, v_dim, bias=True) + + self.proj = nn.Linear(v_dim*2, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout) + ) + self.add_flow_token = add_flow_token + self.dim = qk_dim + def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3): + """ + query_coord [B, 2, H1, W1] + """ + B, _, H1, W1 = query_coord.shape + + if key is None and value is None: + key = self.k(memory) + value = self.v(memory) + + # [B, 2, H1, W1] -> [BH1W1, 1, 2] + query_coord = query_coord.contiguous() + query_coord = query_coord.view(B, 2, -1).permute(0, 2, 1)[:,:,None,:].contiguous().view(B*H1*W1, 1, 2) + if self.pe == 'linear': + query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim) + elif self.pe == 'exp': + query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim) + + short_cut = query + query = self.norm1(query) + + if self.add_flow_token: + q = self.q(query+query_coord_enc) + else: + q = self.q(query_coord_enc) + k, v = key, value + + x = self.multi_head_attn(q, k, v) + + x = self.proj(torch.cat([x, short_cut],dim=2)) + x = short_cut + self.proj_drop(x) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x, k, v + +class MemoryDecoderLayer(nn.Module): + def __init__(self, dim, cfg): + super(MemoryDecoderLayer, self).__init__() + self.cfg = cfg + self.patch_size = cfg.patch_size # for converting coords into H2', W2' space + + query_token_dim, tgt_token_dim = cfg.query_latent_dim, cfg.cost_latent_dim + qk_dim, v_dim = query_token_dim, query_token_dim + self.cross_attend = CrossAttentionLayer(qk_dim, v_dim, query_token_dim, tgt_token_dim, add_flow_token=cfg.add_flow_token, dropout=cfg.dropout) + + def forward(self, query, key, value, memory, coords1, size, size_h3w3): + """ + x: [B*H1*W1, 1, C] + memory: [B*H1*W1, H2'*W2', C] + coords1 [B, 2, H2, W2] + size: B, C, H1, W1 + 1. Note that here coords0 and coords1 are in H2, W2 space. + Should first convert it into H2', W2' space. + 2. We assume the upper-left point to be [0, 0], instead of letting center of upper-left patch to be [0, 0] + """ + x_global, k, v = self.cross_attend(query, key, value, memory, coords1, self.patch_size, size_h3w3) + B, C, H1, W1 = size + C = self.cfg.query_latent_dim + x_global = x_global.view(B, H1, W1, C).permute(0, 3, 1, 2) + return x_global, k, v + +class ReverseCostExtractor(nn.Module): + def __init__(self, cfg): + super(ReverseCostExtractor, self).__init__() + self.cfg = cfg + + def forward(self, cost_maps, coords0, coords1): + """ + cost_maps - B*H1*W1, cost_heads_num, H2, W2 + coords - B, 2, H1, W1 + """ + BH1W1, heads, H2, W2 = cost_maps.shape + B, _, H1, W1 = coords1.shape + + assert (H1 == H2) and (W1 == W2) + assert BH1W1 == B*H1*W1 + + cost_maps = cost_maps.reshape(B, H1* W1*heads, H2, W2) + coords = coords1.permute(0, 2, 3, 1) + corr = bilinear_sampler(cost_maps, coords) # [B, H1*W1*heads, H2, W2] + corr = rearrange(corr, 'b (h1 w1 heads) h2 w2 -> (b h2 w2) heads h1 w1', b=B, heads=heads, h1=H1, w1=W1, h2=H2, w2=W2) + + r = 4 + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords0.device) + centroid = coords0.permute(0, 2, 3, 1).reshape(BH1W1, 1, 1, 2) + delta = delta.view(1, 2*r+1, 2*r+1, 2) + coords = centroid + delta + corr = bilinear_sampler(corr, coords) + corr = corr.view(B, H1, W1, -1).permute(0, 3, 1, 2) + return corr + +class MemoryDecoder(nn.Module): + def __init__(self, cfg): + super(MemoryDecoder, self).__init__() + dim = self.dim = cfg.query_latent_dim + self.cfg = cfg + + self.flow_token_encoder = nn.Sequential( + nn.Conv2d(81*cfg.cost_heads_num, dim, 1, 1), + nn.GELU(), + nn.Conv2d(dim, dim, 1, 1) + ) + self.proj = nn.Conv2d(256, 256, 1) + self.depth = cfg.decoder_depth + self.decoder_layer = MemoryDecoderLayer(dim, cfg) + + if self.cfg.gma: + self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128) + self.att = Attention(args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128) + else: + self.update_block = BasicUpdateBlock(self.cfg, hidden_dim=128) + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + def encode_flow_token(self, cost_maps, coords): + """ + cost_maps - B*H1*W1, cost_heads_num, H2, W2 + coords - B, 2, H1, W1 + """ + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + r = 4 + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid = coords.reshape(batch*h1*w1, 1, 1, 2) + delta = delta.view(1, 2*r+1, 2*r+1, 2) + coords = centroid + delta + corr = bilinear_sampler(cost_maps, coords) + corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2) + return corr + + def forward(self, cost_memory, context, data={}, flow_init=None): + """ + memory: [B*H1*W1, H2'*W2', C] + context: [B, D, H1, W1] + """ + cost_maps = data['cost_maps'] + coords0, coords1 = initialize_flow(context) + + if flow_init is not None: + #print("[Using warm start]") + coords1 = coords1 + flow_init + + #flow = coords1 + + flow_predictions = [] + + context = self.proj(context) + net, inp = torch.split(context, [128, 128], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + if self.cfg.gma: + attention = self.att(inp) + + size = net.shape + key, value = None, None + + for idx in range(self.depth): + coords1 = coords1.detach() + + cost_forward = self.encode_flow_token(cost_maps, coords1) + #cost_backward = self.reverse_cost_extractor(cost_maps, coords0, coords1) + + query = self.flow_token_encoder(cost_forward) + query = query.permute(0, 2, 3, 1).contiguous().view(size[0]*size[2]*size[3], 1, self.dim) + cost_global, key, value = self.decoder_layer(query, key, value, cost_memory, coords1, size, data['H3W3']) + if self.cfg.only_global: + corr = cost_global + else: + corr = torch.cat([cost_global, cost_forward], dim=1) + + flow = coords1 - coords0 + + if self.cfg.gma: + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention) + else: + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # flow = delta_flow + coords1 = coords1 + delta_flow + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + flow_predictions.append(flow_up) + + if self.training: + return flow_predictions + else: + return flow_predictions[-1], coords1-coords0 diff --git a/ptlflow/models/flowformer/encoder.py b/ptlflow/models/flowformer/encoder.py new file mode 100644 index 0000000..13bc4e3 --- /dev/null +++ b/ptlflow/models/flowformer/encoder.py @@ -0,0 +1,356 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +import numpy as np + +from einops import rearrange + +from .utils import coords_grid +from .attention import BroadMultiHeadAttention, MultiHeadAttention, LinearPositionEmbeddingSine, ExpPositionEmbeddingSine +from .encoders import twins_svt_large +from typing import Tuple +from .twins import Size_ +from .cnn import BasicEncoder +from .mlpmixer import MLPMixerLayer +from .convnext import ConvNextLayer + +from timm.models.layers import DropPath + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe='linear'): + super().__init__() + self.patch_size = patch_size + self.dim = embed_dim + self.pe = pe + + # assert patch_size == 8 + if patch_size == 8: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim//4, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim//4, embed_dim//2, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim//2, embed_dim, kernel_size=6, stride=2, padding=2), + ) + elif patch_size == 4: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim//4, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim//4, embed_dim, kernel_size=6, stride=2, padding=2), + ) + else: + print(f"patch size = {patch_size} is unacceptable.") + + self.ffn_with_coord = nn.Sequential( + nn.Conv2d(embed_dim*2, embed_dim*2, kernel_size=1), + nn.ReLU(), + nn.Conv2d(embed_dim*2, embed_dim*2, kernel_size=1) + ) + self.norm = nn.LayerNorm(embed_dim*2) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape # C == 1 + + pad_l = pad_t = 0 + pad_r = (self.patch_size - W % self.patch_size) % self.patch_size + pad_b = (self.patch_size - H % self.patch_size) % self.patch_size + x = F.pad(x, (pad_l, pad_r, pad_t, pad_b)) + + x = self.proj(x) + out_size = x.shape[2:] + + patch_coord = coords_grid(B, out_size[0], out_size[1]).to(x.device) * self.patch_size + self.patch_size/2 # in feature coordinate space + patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1) + if self.pe == 'linear': + patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim) + elif self.pe == 'exp': + patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim) + patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(B, -1, out_size[0], out_size[1]) + + x_pe = torch.cat([x, patch_coord_enc], dim=1) + x = self.ffn_with_coord(x_pe) + x = self.norm(x.flatten(2).transpose(1, 2)) + + return x, out_size + +from .twins import Block, CrossBlock + +class GroupVerticalSelfAttentionLayer(nn.Module): + def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(GroupVerticalSelfAttentionLayer, self).__init__() + self.cfg = cfg + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + embed_dim = dim + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = dropout + attn_drop_rate=0. + + self.block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True, vert_c_dim=cfg.vert_c_dim, groupattention=True, cfg=self.cfg) + + def forward(self, x, size, context=None): + x = self.block(x, size, context) + + return x + +class VerticalSelfAttentionLayer(nn.Module): + def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(VerticalSelfAttentionLayer, self).__init__() + self.cfg = cfg + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + embed_dim = dim + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = dropout + attn_drop_rate=0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True, vert_c_dim=cfg.vert_c_dim) + self.global_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True, vert_c_dim=cfg.vert_c_dim) + + def forward(self, x, size, context=None): + x = self.local_block(x, size, context) + x = self.global_block(x, size, context) + + return x + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + +class SelfAttentionLayer(nn.Module): + def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(SelfAttentionLayer, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.multi_head_attn = MultiHeadAttention(dim, num_heads) + self.q, self.k, self.v = nn.Linear(dim, dim, bias=True), nn.Linear(dim, dim, bias=True), nn.Linear(dim, dim, bias=True) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = x + x = self.norm1(x) + + q, k, v = self.q(x), self.k(x), self.v(x) + + x = self.multi_head_attn(q, k, v) + + x = self.proj(x) + x = short_cut + self.proj_drop(x) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + +class CrossAttentionLayer(nn.Module): + def __init__(self, qk_dim, v_dim, query_token_dim, tgt_token_dim, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(CrossAttentionLayer, self).__init__() + assert qk_dim % num_heads == 0, f"dim {qk_dim} should be divided by num_heads {num_heads}." + assert v_dim % num_heads == 0, f"dim {v_dim} should be divided by num_heads {num_heads}." + """ + Query Token: [N, C] -> [N, qk_dim] (Q) + Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V) + """ + self.num_heads = num_heads + head_dim = qk_dim // num_heads + self.scale = head_dim ** -0.5 + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads) + self.q, self.k, self.v = nn.Linear(query_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, v_dim, bias=True) + + self.proj = nn.Linear(v_dim, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout) + ) + + def forward(self, query, tgt_token): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = query + query = self.norm1(query) + + q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token) + + x = self.multi_head_attn(q, k, v) + + x = short_cut + self.proj_drop(self.proj(x)) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + +class CostPerceiverEncoder(nn.Module): + def __init__(self, cfg): + super(CostPerceiverEncoder, self).__init__() + self.cfg = cfg + self.patch_size = cfg.patch_size + self.patch_embed = PatchEmbed(in_chans=self.cfg.cost_heads_num, patch_size=self.patch_size, embed_dim=cfg.cost_latent_input_dim, pe=cfg.pe) + + self.depth = cfg.encoder_depth + + self.latent_tokens = nn.Parameter(torch.randn(1, cfg.cost_latent_token_num, cfg.cost_latent_dim)) + + query_token_dim, tgt_token_dim = cfg.cost_latent_dim, cfg.cost_latent_input_dim*2 + qk_dim, v_dim = query_token_dim, query_token_dim + self.input_layer = CrossAttentionLayer(qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=cfg.dropout) + + if cfg.use_mlp: + self.encoder_layers = nn.ModuleList([MLPMixerLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) for idx in range(self.depth)]) + else: + self.encoder_layers = nn.ModuleList([SelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) for idx in range(self.depth)]) + + if self.cfg.vertical_conv: + self.vertical_encoder_layers = nn.ModuleList([ConvNextLayer(cfg.cost_latent_dim) for idx in range(self.depth)]) + else: + self.vertical_encoder_layers = nn.ModuleList([VerticalSelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) for idx in range(self.depth)]) + self.cost_scale_aug = None + # if ('cost_scale_aug' in cfg.keys()): + # self.cost_scale_aug = cfg.cost_scale_aug + # print("[Using cost_scale_aug: {}]".format(self.cost_scale_aug)) + + + + def forward(self, cost_volume, data, context=None): + B, heads, H1, W1, H2, W2 = cost_volume.shape + cost_maps = cost_volume.permute(0, 2, 3, 1, 4, 5).contiguous().view(B*H1*W1, self.cfg.cost_heads_num, H2, W2) + data['cost_maps'] = cost_maps + + if self.cost_scale_aug is not None: + scale_factor = torch.FloatTensor(B*H1*W1, self.cfg.cost_heads_num, H2, W2).uniform_(self.cost_scale_aug[0], self.cost_scale_aug[1]).cuda() + cost_maps = cost_maps * scale_factor + + x, size = self.patch_embed(cost_maps) # B*H1*W1, size[0]*size[1], C + data['H3W3'] = size + H3, W3 = size + + x = self.input_layer(self.latent_tokens, x) + + short_cut = x + + for idx, layer in enumerate(self.encoder_layers): + x = layer(x) + if self.cfg.vertical_conv: + # B, H1*W1, K, D -> B, K, D, H1*W1 -> B*K, D, H1, W1 + x = x.view(B, H1*W1, self.cfg.cost_latent_token_num, -1).permute(0, 3, 1, 2).reshape(B*self.cfg.cost_latent_token_num, -1, H1, W1) + x = self.vertical_encoder_layers[idx](x) + # B*K, D, H1, W1 -> B, K, D, H1*W1 -> B, H1*W1, K, D + x = x.view(B, self.cfg.cost_latent_token_num, -1, H1*W1).permute(0, 2, 3, 1).reshape(B*H1*W1, self.cfg.cost_latent_token_num, -1) + else: + x = x.view(B, H1*W1, self.cfg.cost_latent_token_num, -1).permute(0, 2, 1, 3).reshape(B*self.cfg.cost_latent_token_num, H1*W1, -1) + x = self.vertical_encoder_layers[idx](x, (H1, W1), context) + x = x.view(B, self.cfg.cost_latent_token_num, H1*W1, -1).permute(0, 2, 1, 3).reshape(B*H1*W1, self.cfg.cost_latent_token_num, -1) + + if self.cfg.cost_encoder_res is True: + x = x + short_cut + #print("~~~~") + return x + +class MemoryEncoder(nn.Module): + def __init__(self, cfg): + super(MemoryEncoder, self).__init__() + self.cfg = cfg + + if cfg.fnet == 'twins': + self.feat_encoder = twins_svt_large(pretrained=self.cfg.pretrain) + elif cfg.fnet == 'basicencoder': + self.feat_encoder = BasicEncoder(output_dim=256, norm_fn='instance') + else: + exit() + self.channel_convertor = nn.Conv2d(cfg.encoder_latent_dim, cfg.encoder_latent_dim, 1, padding=0, bias=False) + self.cost_perceiver_encoder = CostPerceiverEncoder(cfg) + + def corr(self, fmap1, fmap2): + + batch, dim, ht, wd = fmap1.shape + fmap1 = rearrange(fmap1, 'b (heads d) h w -> b heads (h w) d', heads=self.cfg.cost_heads_num) + fmap2 = rearrange(fmap2, 'b (heads d) h w -> b heads (h w) d', heads=self.cfg.cost_heads_num) + corr = einsum('bhid, bhjd -> bhij', fmap1, fmap2) + corr = corr.permute(0, 2, 1, 3).view(batch*ht*wd, self.cfg.cost_heads_num, ht, wd) + #corr = self.norm(self.relu(corr)) + corr = corr.view(batch, ht*wd, self.cfg.cost_heads_num, ht*wd).permute(0, 2, 1, 3) + corr = corr.view(batch, self.cfg.cost_heads_num, ht, wd, ht, wd) + + return corr + + def forward(self, img1, img2, data, context=None): + feat_s = self.feat_encoder(img1) + feat_t = self.feat_encoder(img2) + + feat_s = self.channel_convertor(feat_s) + feat_t = self.channel_convertor(feat_t) + + B, C, H, W = feat_s.shape + size = (H, W) + + if self.cfg.feat_cross_attn: + feat_s = feat_s.flatten(2).transpose(1, 2) + feat_t = feat_t.flatten(2).transpose(1, 2) + + for layer in self.layers: + feat_s, feat_t = layer(feat_s, feat_t, size) + + feat_s = feat_s.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + feat_t = feat_t.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + cost_volume = self.corr(feat_s, feat_t) + x = self.cost_perceiver_encoder(cost_volume, data, context) + + return x diff --git a/ptlflow/models/flowformer/encoders.py b/ptlflow/models/flowformer/encoders.py new file mode 100644 index 0000000..6455e24 --- /dev/null +++ b/ptlflow/models/flowformer/encoders.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn +import timm +import numpy as np + +class twins_svt_large(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model('twins_svt_large', pretrained=pretrained) + + del self.svt.head + del self.svt.patch_embeds[2] + del self.svt.patch_embeds[2] + del self.svt.blocks[2] + del self.svt.blocks[2] + del self.svt.pos_block[2] + del self.svt.pos_block[2] + + def forward(self, x, data=None, layer=2): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j==0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + if i == layer-1: + break + + return x + + def compute_params(self, layer=2): + num = 0 + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + for param in embed.parameters(): + num += np.prod(param.size()) + + for param in drop.parameters(): + num += np.prod(param.size()) + + for param in blocks.parameters(): + num += np.prod(param.size()) + + for param in pos_blk.parameters(): + num += np.prod(param.size()) + + if i == layer-1: + break + + for param in self.svt.head.parameters(): + num += np.prod(param.size()) + + return num + +class twins_svt_large_context(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model('twins_svt_large_context', pretrained=pretrained) + + def forward(self, x, data=None, layer=2): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j==0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + if i == layer-1: + break + + return x + + +if __name__ == "__main__": + m = twins_svt_large() + input = torch.randn(2, 3, 400, 800) + out = m.extract_feature(input) + print(out.shape) \ No newline at end of file diff --git a/ptlflow/models/flowformer/flowformer.py b/ptlflow/models/flowformer/flowformer.py new file mode 100644 index 0000000..315aa2d --- /dev/null +++ b/ptlflow/models/flowformer/flowformer.py @@ -0,0 +1,129 @@ +from argparse import ArgumentParser, Namespace + +import torch +import torch.nn as nn + +from .cnn import BasicEncoder +from .encoder import MemoryEncoder +from .encoders import twins_svt_large +from .decoder import MemoryDecoder +from ..base_model.base_model import BaseModel + + +class SequenceLoss(nn.Module): + def __init__(self, args): + super().__init__() + self.gamma = args.gamma + self.max_flow = args.max_flow + + def forward(self, outputs, inputs): + """ Loss function defined over sequence of flow predictions """ + flow_preds = outputs['flow_preds'] + flow_gt = inputs['flows'][:, 0] + valid = inputs['valids'][:, 0] + + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exlude invalid pixels and extremely large diplacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < self.max_flow) + + for i in range(n_predictions): + i_weight = self.gamma**(n_predictions - i - 1) + i_loss = (flow_preds[i] - flow_gt).abs() + flow_loss += i_weight * (valid[:, None] * i_loss).mean() + + return flow_loss + + +class FlowFormer(BaseModel): + pretrained_checkpoints = { + 'chairs': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/flowformer-chairs-2b34ea4b.ckpt', + 'things': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/flowformer-things-ab5f3255.ckpt', + 'sintel': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/flowformer-sintel-27cc959a.ckpt', + 'kitti': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/flowformer-kitti-1e45a6c8.ckpt' + } + + def __init__(self, + args: Namespace) -> None: + super().__init__( + args=args, + loss_fn=SequenceLoss(args), + output_stride=8) + + if self.args.gma is None: + self.args.gma = True # Use GMA by default, unless + + self.memory_encoder = MemoryEncoder(args) + self.memory_decoder = MemoryDecoder(args) + if args.cnet == 'twins': + self.context_encoder = twins_svt_large(pretrained=self.args.pretrain) + elif args.cnet == 'basicencoder': + self.context_encoder = BasicEncoder(output_dim=256, norm_fn='instance') + + @staticmethod + def add_model_specific_args(parent_parser=None): + parent_parser = BaseModel.add_model_specific_args(parent_parser) + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--add_flow_token', action='store_true') + parser.add_argument('--cnet', type=str, choices=('basicencoder', 'twins'), default='twins') + parser.add_argument('--context_concat', action='store_true') + parser.add_argument('--cost_encoder_res', action='store_true') + parser.add_argument('--cost_heads_num', type=int, default=1) + parser.add_argument('--cost_latent_dim', type=int, default=128) + parser.add_argument('--cost_latent_input_dim', type=int, default=64) + parser.add_argument('--cost_latent_token_num', type=int, default=8) + parser.add_argument('--decoder_depth', type=int, default=12) + parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--encoder_depth', type=int, default=3) + parser.add_argument('--encoder_latent_dim', type=int, default=256) + parser.add_argument('--feat_cross_attn', action='store_true') + parser.add_argument('--fnet', type=str, choices=('basicencoder', 'twins'), default='twins') + parser.add_argument('--gamma', type=float, default=0.8) + parser.add_argument('--max_flow', type=float, default=400.0) + parser.add_argument('--no_gma', action='store_false', dest='gma') + parser.add_argument('--only_global', action='store_true') + parser.add_argument('--patch_size', type=int, default=8) + parser.add_argument('--pe', type=str, choices=('exp', 'linear'), default='linear') + parser.add_argument('--pretrain', action='store_true') + parser.add_argument('--query_latent_dim', type=int, default=64) + parser.add_argument('--use_mlp', action='store_true') + parser.add_argument('--vert_c_dim', type=int, default=64) + parser.add_argument('--vertical_conv', action='store_true') + return parser + + + def forward(self, inputs, flow_init=None): + """ Estimate optical flow between pair of frames """ + image1 = inputs['images'][:, 0] + image2 = inputs['images'][:, 1] + + image1 = 2 * image1 - 1.0 + image2 = 2 * image2 - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + data = {} + + if self.args.context_concat: + context = self.context_encoder(torch.cat([image1, image2], dim=1)) + else: + context = self.context_encoder(image1) + + cost_memory = self.memory_encoder(image1, image2, data, context) + + flow_predictions = self.memory_decoder(cost_memory, context, data, flow_init=flow_init) + + if self.training: + outputs = { + 'flows': flow_predictions[0][:, None], + 'flow_preds': flow_predictions + } + else: + outputs = { + 'flows': flow_predictions[0][:, None] + } + + return outputs diff --git a/ptlflow/models/flowformer/gma.py b/ptlflow/models/flowformer/gma.py new file mode 100644 index 0000000..c0bbba5 --- /dev/null +++ b/ptlflow/models/flowformer/gma.py @@ -0,0 +1,123 @@ +import torch +from torch import nn, einsum +from einops import rearrange + + +class RelPosEmb(nn.Module): + def __init__( + self, + max_pos_size, + dim_head + ): + super().__init__() + self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) + self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) + + deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) + rel_ind = deltas + max_pos_size - 1 + self.register_buffer('rel_ind', rel_ind) + + def forward(self, q): + batch, heads, h, w, c = q.shape + height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) + width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) + + height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) + width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) + + height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) + width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) + + return height_score + width_score + + +class Attention(nn.Module): + def __init__( + self, + *, + args, + dim, + max_pos_size = 100, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) + + self.pos_emb = RelPosEmb(max_pos_size, dim_head) + + def forward(self, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + q, k = self.to_qk(fmap).chunk(2, dim=1) + + q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) + q = self.scale * q + + # if self.args.position_only: + # sim = self.pos_emb(q) + + # elif self.args.position_and_content: + # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + # sim_pos = self.pos_emb(q) + # sim = sim_content + sim_pos + + # else: + sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + + sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') + attn = sim.softmax(dim=-1) + + return attn + + +class Aggregate(nn.Module): + def __init__( + self, + args, + dim, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) + + self.gamma = nn.Parameter(torch.zeros(1)) + + if dim != inner_dim: + self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) + else: + self.project = None + + def forward(self, attn, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + v = self.to_v(fmap) + v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) + + if self.project is not None: + out = self.project(out) + + out = fmap + self.gamma * out + + return out + + +if __name__ == "__main__": + att = Attention(dim=128, heads=1) + fmap = torch.randn(2, 128, 40, 90) + out = att(fmap) + + print(out.shape) diff --git a/ptlflow/models/flowformer/gru.py b/ptlflow/models/flowformer/gru.py new file mode 100644 index 0000000..1537561 --- /dev/null +++ b/ptlflow/models/flowformer/gru.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + if args.only_global: + print("[Decoding with only global cost]") + cor_planes = args.query_latent_dim + else: + cor_planes = 81+args.query_latent_dim + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + +from .gma import Aggregate +class GMAUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128): + super().__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) + + def forward(self, net, inp, corr, flow, attention): + motion_features = self.encoder(flow, corr) + motion_features_global = self.aggregator(attention, motion_features) + inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) + + # Attentional update + net = self.gru(net, inp_cat) + + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow diff --git a/ptlflow/models/flowformer/mlpmixer.py b/ptlflow/models/flowformer/mlpmixer.py new file mode 100644 index 0000000..88efd07 --- /dev/null +++ b/ptlflow/models/flowformer/mlpmixer.py @@ -0,0 +1,49 @@ +from torch import nn +from functools import partial +import numpy as np + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + return self.fn(self.norm(x)) + x + +def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear): + return nn.Sequential( + dense(dim, dim * expansion_factor), + nn.GELU(), + nn.Dropout(dropout), + dense(dim * expansion_factor, dim), + nn.Dropout(dropout) + ) + +class MLPMixerLayer(nn.Module): + def __init__(self, dim, cfg, drop_path=0., dropout=0.): + super(MLPMixerLayer, self).__init__() + + # print(f"use mlp mixer layer") + K = cfg.cost_latent_token_num + expansion_factor = cfg.mlp_expansion_factor + chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear + + self.mlpmixer = nn.Sequential( + PreNormResidual(dim, FeedForward(K, expansion_factor, dropout, chan_first)), + PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last)), + ) + + def compute_params(self): + num = 0 + for param in self.mlpmixer.parameters(): + num += np.prod(param.size()) + + return num + + def forward(self, x): + """ + x: [BH1W1, K, D] + """ + + return self.mlpmixer(x) diff --git a/ptlflow/models/flowformer/twins.py b/ptlflow/models/flowformer/twins.py new file mode 100644 index 0000000..978ef66 --- /dev/null +++ b/ptlflow/models/flowformer/twins.py @@ -0,0 +1,931 @@ +""" Twins +A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers` + - https://arxiv.org/pdf/2104.13840.pdf +Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below +""" +# -------------------------------------------------------- +# Twins +# Copyright (c) 2021 Meituan +# Licensed under The Apache 2.0 License [see LICENSE for details] +# Written by Xinjie Li, Xiangxiang Chu +# -------------------------------------------------------- +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from timm.models.vision_transformer import Attention +from .attention import LinearPositionEmbeddingSine +from .utils import coords_grid + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'twins_pcpvt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth', + ), + 'twins_pcpvt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth', + ), + 'twins_pcpvt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth', + ), + 'twins_svt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth', + ), + 'twins_svt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth', + ), + 'twins_svt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth', + ), +} + +Size_ = Tuple[int, int] + +class GroupAttnRPEContext(nn.Module): + """ Latent cost tokens attend to different group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1, cfg=None, vert_c_dim=0): + super(GroupAttnRPEContext, self).__init__() + assert ws != 1 + assert cfg is not None + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + assert cfg.cost_latent_token_num % 5 == 0, "cost_latent_token_num should be divided by 5." + assert vert_c_dim > 0, "vert_c_dim should not be 0" + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.vert_c_dim = vert_c_dim + + self.cfg = cfg + + self.context_proj = nn.Linear(256, vert_c_dim) + self.q = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + C_qk = C+self.vert_c_dim + H, W = size + batch_num = B // 5 + + context = context.repeat(B//context.shape[0], 1, 1, 1) + context = context.view(B, -1, H*W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + padded_N = Hp*Wp + + coords = coords_grid(B, Hp, Wp).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk) + coords_enc = coords_enc.reshape(B, Hp, Wp, C_qk) + + q = self.q(x_qk + coords_enc).reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + q = q.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = self.v(x) + k = self.k(x_qk + coords_enc) + # concate and do shifting operation together + kv = torch.cat([k, v], dim=-1) + kv_up = torch.cat([kv[:batch_num, self.ws:Hp, :, :], kv[:batch_num, Hp-self.ws:Hp, :, :]], dim=1) + kv_down = torch.cat([kv[batch_num:batch_num*2, :self.ws, :, :], kv[batch_num:batch_num*2, :Hp-self.ws, :, :]], dim=1) + kv_left = torch.cat([kv[batch_num*2:batch_num*3, :, self.ws:Wp, :], kv[batch_num*2:batch_num*3, :, Wp-self.ws:Wp, :]], dim=2) + kv_right = torch.cat([kv[batch_num*3:batch_num*4, :, :self.ws, :], kv[batch_num*3:batch_num*4, :, :Wp-self.ws, :]], dim=2) + kv_center = kv[batch_num*4:batch_num*5, :, :, :] + kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0) + k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1) + + k = k.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + k = k.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = v.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + v = v.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class GroupAttnRPE(nn.Module): + """ Latent cost tokens attend to different group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1, cfg=None): + super(GroupAttnRPE, self).__init__() + assert ws != 1 + assert cfg is not None + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + assert cfg.cost_latent_token_num % 5 == 0, "cost_latent_token_num should be divided by 5." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.cfg = cfg + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + H, W = size + batch_num = B // 5 + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + padded_N = Hp*Wp + + coords = coords_grid(B, Hp, Wp).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + coords_enc = coords_enc.reshape(B, Hp, Wp, C) + + q = self.q(x + coords_enc).reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + q = q.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = self.v(x) + k = self.k(x + coords_enc) + # concate and do shifting operation together + kv = torch.cat([k, v], dim=-1) + kv_up = torch.cat([kv[:batch_num, self.ws:Hp, :, :], kv[:batch_num, Hp-self.ws:Hp, :, :]], dim=1) + kv_down = torch.cat([kv[batch_num:batch_num*2, :self.ws, :, :], kv[batch_num:batch_num*2, :Hp-self.ws, :, :]], dim=1) + kv_left = torch.cat([kv[batch_num*2:batch_num*3, :, self.ws:Wp, :], kv[batch_num*2:batch_num*3, :, Wp-self.ws:Wp, :]], dim=2) + kv_right = torch.cat([kv[batch_num*3:batch_num*4, :, :self.ws, :], kv[batch_num*3:batch_num*4, :, :Wp-self.ws, :]], dim=2) + kv_center = kv[batch_num*4:batch_num*5, :, :, :] + kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0) + k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1) + + k = k.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + k = k.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = v.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + v = v.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class LocallyGroupedAttnRPEContext(nn.Module): + """ LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1, vert_c_dim=0): + assert ws != 1 + super(LocallyGroupedAttnRPEContext, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.vert_c_dim = vert_c_dim + + self.context_proj = nn.Linear(256, vert_c_dim) + # context are not added to value + self.q = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + C_qk = C+self.vert_c_dim + + context = context.repeat(B//context.shape[0], 1, 1, 1) + context = context.view(B, -1, H*W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + x_qk = x_qk.reshape(B, _h, self.ws, _w, self.ws, C_qk).transpose(2, 3) + + v = self.v(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + + coords = coords_grid(B, self.ws, self.ws).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk).view(B, self.ws, self.ws, C_qk) + # coords_enc: B, ws, ws, C + # x: B, _h, _w, self.ws, self.ws, C + x_qk = x_qk + coords_enc[:, None, None, :, :, :] + + q = self.q(x_qk).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + k = self.k(x_qk).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class GlobalSubSampleAttnRPEContext(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1, vert_c_dim=0): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.vert_c_dim = vert_c_dim + self.context_proj = nn.Linear(256, vert_c_dim) + self.q = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr_key = nn.Conv2d(dim+vert_c_dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.sr_value = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + C_qk = C + self.vert_c_dim + H, W = size + context = context.repeat(B//context.shape[0], 1, 1, 1) + context = context.view(B, -1, H*W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + pad_l = pad_t = 0 + pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio + pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + padded_size = (Hp, Wp) + padded_N = Hp*Wp + x = x.view(B, -1, C) + x_qk = x_qk.view(B, -1, C_qk) + + coords = coords_grid(B, *padded_size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk) + # coords_enc: B, Hp*Wp, C + # x: B, Hp*Wp, C + q = self.q(x_qk + coords_enc).reshape(B, padded_N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_key is not None: + x = x.permute(0, 2, 1).reshape(B, C, *padded_size) + x_qk = x_qk.permute(0, 2, 1).reshape(B, C_qk, *padded_size) + x = self.sr_value(x).reshape(B, C, -1).permute(0, 2, 1) + x_qk = self.sr_key(x_qk).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + x_qk = self.norm(x_qk) + + coords = coords_grid(B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(x_qk + coords_enc).reshape(B, (padded_size[0] // self.sr_ratio)*(padded_size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(x).reshape(B, (padded_size[0] // self.sr_ratio)*(padded_size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class LocallyGroupedAttnRPE(nn.Module): + """ LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(LocallyGroupedAttnRPE, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + v = self.v(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + + coords = coords_grid(B, self.ws, self.ws).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C).view(B, self.ws, self.ws, C) + # coords_enc: B, ws, ws, C + # x: B, _h, _w, self.ws, self.ws, C + x = x + coords_enc[:, None, None, :, :, :] + + q = self.q(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + k = self.k(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class GlobalSubSampleAttnRPE(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio + pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + padded_size = (Hp, Wp) + padded_N = Hp*Wp + x = x.view(B, -1, C) + + coords = coords_grid(B, *padded_size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, Hp*Wp, C + # x: B, Hp*Wp, C + q = self.q(x + coords_enc).reshape(B, padded_N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *padded_size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + + coords = coords_grid(B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(x + coords_enc).reshape(B, (padded_size[0] // self.sr_ratio)*(padded_size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(x).reshape(B, (padded_size[0] // self.sr_ratio)*(padded_size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class CrossGlobalSubSampleAttnRPE(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + coords = coords_grid(B, *size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, H*W, C + # x: B, H*W, C + q = self.q(x + coords_enc).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + coords = coords_grid(B, size[0] // self.sr_ratio, size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(tgt + coords_enc).reshape(B, (size[0] // self.sr_ratio)*(size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(tgt).reshape(B, (size[0] // self.sr_ratio)*(size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class LocallyGroupedAttn(nn.Module): + """ LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(LocallyGroupedAttn, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + qkv = self.qkv(x).reshape( + B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class GlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class CrossGlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + kv = self.kv(tgt).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class CrossBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None, with_rpe=True): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossGlobalSubSampleAttnRPE(dim, num_heads, attn_drop, drop, sr_ratio) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, src, tgt, size: Size_): + src_shortcut, tgt_shortcut = src, tgt + + src, tgt = self.norm1(src), self.norm1(tgt) + src = src_shortcut + self.drop_path(self.attn(src, tgt, size)) + tgt = tgt_shortcut + self.drop_path(self.attn(tgt, src, size)) + + src = src + self.drop_path(self.mlp(self.norm2(src))) + tgt = tgt + self.drop_path(self.mlp(self.norm2(tgt))) + return src, tgt + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None, with_rpe=False, vert_c_dim=0, groupattention=False, cfg=None): + super().__init__() + self.norm1 = norm_layer(dim) + if groupattention: + assert with_rpe, "Not implementing groupattention without rpe" + if vert_c_dim > 0: + self.attn = GroupAttnRPEContext(dim, num_heads, attn_drop, drop, ws, cfg, vert_c_dim) + else: + self.attn = GroupAttnRPE(dim, num_heads, attn_drop, drop, ws, cfg) + elif ws is None: + self.attn = Attention(dim, num_heads, False, None, attn_drop, drop) + elif ws == 1: + if with_rpe: + if vert_c_dim > 0: + self.attn = GlobalSubSampleAttnRPEContext(dim, num_heads, attn_drop, drop, sr_ratio, vert_c_dim) + else: + self.attn = GlobalSubSampleAttnRPE(dim, num_heads, attn_drop, drop, sr_ratio) + else: + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) + else: + if with_rpe: + if vert_c_dim > 0: + self.attn = LocallyGroupedAttnRPEContext(dim, num_heads, attn_drop, drop, ws, vert_c_dim) + else: + self.attn = LocallyGroupedAttnRPE(dim, num_heads, attn_drop, drop, ws) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, size: Size_, context=None): + x = x + self.drop_path(self.attn(self.norm1(x), size, context)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), ) + self.stride = stride + + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ['proj.%d.weight' % i for i in range(4)] + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + out_size = (H // self.patch_size[0], W // self.patch_size[1]) + + return x, out_size + + +class Twins(nn.Module): + """ Twins Vision Transfomer (Revisiting Spatial Attention) + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None, + block_cls=Block, init_weight=True): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + + img_size = to_2tuple(img_size) + prev_chs = in_chans + self.patch_embeds = nn.ModuleList() + self.pos_drops = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + + self.blocks = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList([block_cls( + dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])]) + self.blocks.append(_block) + cur += depths[k] + + self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) + + self.norm = norm_layer(self.num_features) + + # classification head + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # init weights + if init_weight: + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()]) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x.mean(dim=1) # GAP here + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x diff --git a/ptlflow/models/flowformer/utils.py b/ptlflow/models/flowformer/utils.py new file mode 100644 index 0000000..5eabe7d --- /dev/null +++ b/ptlflow/models/flowformer/utils.py @@ -0,0 +1,101 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + elif mode == 'kitti400': + self._pad = [0, 0, 0, 400 - self.ht] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def indexing(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + """ + TODO: directly indexing features instead of sampling + """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True, mode='nearest') + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) \ No newline at end of file diff --git a/ptlflow/models/gmflow/LICENSE b/ptlflow/models/gmflow/LICENSE new file mode 100644 index 0000000..9185816 --- /dev/null +++ b/ptlflow/models/gmflow/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022, Haofei Xu + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/ptlflow/models/gmflow/README.md b/ptlflow/models/gmflow/README.md new file mode 100644 index 0000000..7ca0bc0 --- /dev/null +++ b/ptlflow/models/gmflow/README.md @@ -0,0 +1,25 @@ +# RAFT + +## Original code + +[https://github.com/haofeixu/gmflow](https://github.com/haofeixu/gmflow) + +## Code license + +See [LICENSE](LICENSE). + +## Pretrained weights license + +Not specified. + +## Citation + +``` +@inproceedings{xu2022gmflow, + title={GMFlow: Learning Optical Flow via Global Matching}, + author={Xu, Haofei and Zhang, Jing and Cai, Jianfei and Rezatofighi, Hamid and Tao, Dacheng}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={8121-8130}, + year={2022} +} +``` \ No newline at end of file diff --git a/ptlflow/models/gmflow/__init__.py b/ptlflow/models/gmflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ptlflow/models/gmflow/backbone.py b/ptlflow/models/gmflow/backbone.py new file mode 100644 index 0000000..a30942e --- /dev/null +++ b/ptlflow/models/gmflow/backbone.py @@ -0,0 +1,117 @@ +import torch.nn as nn + +from .trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, stride=stride, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__(self, output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer(feature_dims[2], stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/ptlflow/models/gmflow/geometry.py b/ptlflow/models/gmflow/geometry.py new file mode 100644 index 0000000..207e98f --- /dev/null +++ b/ptlflow/models/gmflow/geometry.py @@ -0,0 +1,96 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ diff --git a/ptlflow/models/gmflow/gmflow.py b/ptlflow/models/gmflow/gmflow.py new file mode 100644 index 0000000..22184a3 --- /dev/null +++ b/ptlflow/models/gmflow/gmflow.py @@ -0,0 +1,240 @@ +from argparse import ArgumentParser, Namespace + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import CNNEncoder +from .transformer import FeatureTransformer, FeatureFlowAttention +from .matching import global_correlation_softmax, local_correlation_softmax +from .geometry import flow_warp +from .utils import normalize_img, feature_add_position +from ..base_model.base_model import BaseModel + + +class SequenceLoss(nn.Module): + def __init__(self, args): + super().__init__() + self.gamma = args.gamma + self.max_flow = args.max_flow + + def forward(self, outputs, inputs): + """ Loss function defined over sequence of flow predictions """ + + flow_preds = outputs['flow_preds'] + flow_gt = inputs['flows'][:, 0] + valid = inputs['valids'][:, 0] + + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exlude invalid pixels and extremely large diplacements + mag = torch.sum(flow_gt ** 2, dim=1).sqrt() # [B, H, W] + valid = (valid >= 0.5) & (mag < self.max_flow) + + for i in range(n_predictions): + i_weight = self.gamma ** (n_predictions - i - 1) + + i_loss = (flow_preds[i] - flow_gt).abs() + + flow_loss += i_weight * (valid[:, None] * i_loss).mean() + + return flow_loss + + +class GMFlow(BaseModel): + pretrained_checkpoints = { + 'chairs': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflow-chairs-4922131e.ckpt', + 'things': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflow-things-5a18a9e8.ckpt', + 'sintel': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflow-sintel-d6f83ccd.ckpt', + 'kitti': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflow-kitti-af50eb2e.ckpt' + } + + def __init__(self, + args: Namespace) -> None: + super().__init__( + args=args, + loss_fn=SequenceLoss(args), + output_stride=16) + + # CNN backbone + self.backbone = CNNEncoder(output_dim=self.args.feature_channels, num_output_scales=self.args.num_scales) + + # Transformer + self.transformer = FeatureTransformer(num_layers=self.args.num_transformer_layers, + d_model=self.args.feature_channels, + nhead=self.args.num_head, + attention_type=self.args.attention_type, + ffn_dim_expansion=self.args.ffn_dim_expansion, + ) + + # flow propagation with self-attn + self.feature_flow_attn = FeatureFlowAttention(in_channels=self.args.feature_channels) + + # convex upsampling: concat feature0 and flow as input + self.upsampler = nn.Sequential(nn.Conv2d(2 + self.args.feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, self.args.upsample_factor ** 2 * 9, 1, 1, 0)) + + @staticmethod + def add_model_specific_args(parent_parser=None): + parent_parser = BaseModel.add_model_specific_args(parent_parser) + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--attention_type', type=str, choices=('full', 'swin'), default='swin') + parser.add_argument('--attn_splits_list', type=int, nargs='+', default=(2,)) + parser.add_argument('--corr_radius_list', type=int, nargs='+', default=(-1,)) + parser.add_argument('--feature_channels', type=int, default=128) + parser.add_argument('--ffn_dim_expansion', type=int, default=4) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--max_flow', type=float, default=400.0) + parser.add_argument('--num_head', type=int, default=1) + parser.add_argument('--num_scales', type=int, default=1) + parser.add_argument('--num_transformer_layers', type=int, default=6) + parser.add_argument('--pred_bidir_flow', action='store_true') + parser.add_argument('--prop_radius_list', type=int, nargs='+', default=(-1,)) + parser.add_argument('--upsample_factor', type=int, default=8) + return parser + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, + ): + if bilinear: + up_flow = F.interpolate(flow, scale_factor=upsample_factor, + mode='bilinear', align_corners=True) * upsample_factor + + else: + # convex upsampling + concat = torch.cat((flow, feature), dim=1) + + mask = self.upsampler(concat) + b, flow_channel, h, w = flow.shape + mask = mask.view(b, 1, 9, self.args.upsample_factor, self.args.upsample_factor, h, w) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(self.args.upsample_factor * flow, [3, 3], padding=1) + up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape(b, flow_channel, self.args.upsample_factor * h, + self.args.upsample_factor * w) # [B, 2, K*H, K*W] + + return up_flow + + def forward(self, inputs): + """ Estimate optical flow between pair of frames """ + img0 = inputs['images'][:, 0] + img1 = inputs['images'][:, 1] + + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + flow_preds = [] + + # resolution low to high + feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features + + flow = None + + assert len(self.args.attn_splits_list) == len(self.args.corr_radius_list) == len(self.args.prop_radius_list) == self.args.num_scales + + for scale_idx in range(self.args.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if self.args.pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + + upsample_factor = self.args.upsample_factor * (2 ** (self.args.num_scales - 1 - scale_idx)) + + if scale_idx > 0: + flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 + + if flow is not None: + flow = flow.detach() + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + + attn_splits = self.args.attn_splits_list[scale_idx] + corr_radius = self.args.corr_radius_list[scale_idx] + prop_radius = self.args.prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.args.feature_channels) + + # Transformer + feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits) + + # correlation and softmax + if corr_radius == -1: # global matching + flow_pred = global_correlation_softmax(feature0, feature1, self.args.pred_bidir_flow)[0] + else: # local matching + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + # upsample to the original resolution for supervison + if self.training: # only need to upsample intermediate flow predictions at training time + flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor) + flow_preds.append(flow_bilinear) + + # flow propagation with self-attn + if self.args.pred_bidir_flow and scale_idx == 0: + feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation + flow = self.feature_flow_attn(feature0, flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius) + + # bilinear upsampling at training time except the last one + if self.training and scale_idx < self.args.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor) + flow_preds.append(flow_up) + + if scale_idx == self.args.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0) + flow_preds.append(flow_up) + + + if self.training: + outputs = { + 'flows': flow_up[:, None], + 'flow_preds': flow_preds + } + else: + outputs = { + 'flows': flow_up[:, None] + } + + return outputs + + +class GMFlowWithRefinement(GMFlow): + pretrained_checkpoints = { + 'chairs': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflow_refine-chairs-88cdc009.ckpt', + 'things': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflow_refine-things-e40899f5.ckpt', + 'sintel': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflow_refine-sintel-ee46a2c4.ckpt', + 'kitti': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflow_refine-kitti-b7bf2fda.ckpt' + } + + def __init__(self, args: Namespace) -> None: + args.attn_splits_list = (2, 2) + args.corr_radius_list = (-1, 4) + args.num_scales = 2 + args.prop_radius_list = (-1, 1) + args.upsample_factor = 4 + super().__init__(args) diff --git a/ptlflow/models/gmflow/matching.py b/ptlflow/models/gmflow/matching.py new file mode 100644 index 0000000..1740200 --- /dev/null +++ b/ptlflow/models/gmflow/matching.py @@ -0,0 +1,83 @@ +import torch +import torch.nn.functional as F + +from .geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax(feature0, feature1, + pred_bidir_flow=False, + ): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax(feature0, feature1, local_radius, + padding_mode='zeros', + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob diff --git a/ptlflow/models/gmflow/position.py b/ptlflow/models/gmflow/position.py new file mode 100644 index 0000000..14a6da4 --- /dev/null +++ b/ptlflow/models/gmflow/position.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import torch +import torch.nn as nn +import math + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/ptlflow/models/gmflow/transformer.py b/ptlflow/models/gmflow/transformer.py new file mode 100644 index 0000000..9a8f2ce --- /dev/null +++ b/ptlflow/models/gmflow/transformer.py @@ -0,0 +1,409 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, merge_splits + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, + shift_size_h, shift_size_w, device=torch.device('cuda')): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def single_head_split_window_attention(q, k, v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, channel_last=True) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +class TransformerLayer(nn.Module): + def __init__(self, + d_model=256, + nhead=1, + attention_type='swin', + no_ffn=False, + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.attention_type = attention_type + self.no_ffn = no_ffn + + self.with_shift = with_shift + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if self.attention_type == 'swin' and attn_num_splits > 1: + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=self.with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self attention + cross attention + FFN""" + + def __init__(self, + d_model=256, + nhead=1, + attention_type='swin', + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + self.cross_attn_ffn = TransformerLayer(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + + # self attention + source = self.self_attn(source, source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn(source, target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__(self, + num_layers=6, + d_model=128, + nhead=1, + attention_type='swin', + ffn_dim_expansion=4, + **kwargs, + ): + super(FeatureTransformer, self).__init__() + + self.attention_type = attention_type + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList([ + TransformerBlock(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=True if attention_type == 'swin' and i % 2 == 1 else False, + ) + for i in range(num_layers)]) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, feature1, + attn_num_splits=None, + **kwargs, + ): + + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + if self.attention_type == 'swin' and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # concat feature0 and feature1 in batch dimension to compute in parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for layer in self.layers: + concat0 = layer(concat0, concat1, + height=h, + width=w, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + + return feature0, feature1 + + +class FeatureFlowAttention(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__(self, in_channels, + **kwargs, + ): + super(FeatureFlowAttention, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn(feature0, flow, + local_window_radius=local_window_radius) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] + key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn(self, feature0, flow, + local_window_radius=1, + ): + assert flow.size(1) == 2 + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) + ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) + + feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, + padding=local_window_radius) # [B, C*(2R+1)^2), H*W] + + feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( + 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold(flow, kernel_size=kernel_size, + padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + + flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute( + 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + return out diff --git a/ptlflow/models/gmflow/trident_conv.py b/ptlflow/models/gmflow/trident_conv.py new file mode 100644 index 0000000..29a2a73 --- /dev/null +++ b/ptlflow/models/gmflow/trident_conv.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = dilation + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], + self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs diff --git a/ptlflow/models/gmflow/utils.py b/ptlflow/models/gmflow/utils.py new file mode 100644 index 0000000..7f9c830 --- /dev/null +++ b/ptlflow/models/gmflow/utils.py @@ -0,0 +1,86 @@ +import torch +from .position import PositionEmbeddingSine + + +def split_feature(feature, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c + ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits + ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits(splits, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( + new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( + new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] + + return merge + + +def normalize_img(img0, img1): + # loaded images are in [0, 1] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 - mean) / std + img1 = (img1 - mean) / std + + return img0, img1 + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 diff --git a/ptlflow/models/gmflownet/LICENSE b/ptlflow/models/gmflownet/LICENSE new file mode 100644 index 0000000..494d691 --- /dev/null +++ b/ptlflow/models/gmflownet/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Shiyu Zhao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/ptlflow/models/gmflownet/README.md b/ptlflow/models/gmflownet/README.md new file mode 100644 index 0000000..785f5fd --- /dev/null +++ b/ptlflow/models/gmflownet/README.md @@ -0,0 +1,24 @@ +# RAFT + +## Original code + +[https://github.com/xiaofeng94/GMFlowNet](https://github.com/xiaofeng94/GMFlowNet) + +## Code license + +See [LICENSE](LICENSE). + +## Pretrained weights license + +Not specified. + +## Citation + +``` +@inproceedings{xu2022gmflow, + title={Global Matching with Overlapping Attention for Optical Flow Estimation}, + author={Zhao, Shiyu and Zhao, Long and Zhang, Zhixing and Zhou, Enyu and Metaxas, Dimitris}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2022} +} +``` \ No newline at end of file diff --git a/ptlflow/models/gmflownet/__init__.py b/ptlflow/models/gmflownet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ptlflow/models/gmflownet/corr.py b/ptlflow/models/gmflownet/corr.py new file mode 100644 index 0000000..cc18885 --- /dev/null +++ b/ptlflow/models/gmflownet/corr.py @@ -0,0 +1,87 @@ +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + self.corrMap = corr.view(batch, h1*w1, h2*w2) + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/ptlflow/models/gmflownet/extractor.py b/ptlflow/models/gmflownet/extractor.py new file mode 100644 index 0000000..3cab097 --- /dev/null +++ b/ptlflow/models/gmflownet/extractor.py @@ -0,0 +1,1170 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class IPTHeadEncoder(nn.Module): + """docstring for IPTHead""" + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(IPTHeadEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + half_out_dim = max(output_dim // 2, 64) + self.layer1 = ResidualBlock(64, half_out_dim, self.norm_fn, stride=2) + self.layer2 = ResidualBlock(half_out_dim, output_dim, self.norm_fn, stride=2) + + # # output convolution; this can solve mixed memory warning, not know why + # self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.relu1(self.norm1(self.conv1(x))) + x = self.layer1(x) + x = self.layer2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class BasicConvEncoder(nn.Module): + """docstring for BasicConvEncoder""" + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicConvEncoder, self).__init__() + self.norm_fn = norm_fn + + half_out_dim = max(output_dim // 2, 64) + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=64) + self.norm3 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + self.norm2 = nn.BatchNorm2d(half_out_dim) + self.norm3 = nn.BatchNorm2d(output_dim) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + self.norm2 = nn.InstanceNorm2d(half_out_dim) + self.norm3 = nn.InstanceNorm2d(output_dim) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.conv2 = nn.Conv2d(64, half_out_dim, kernel_size=3, stride=2, padding=1) + self.conv3 = nn.Conv2d(half_out_dim, output_dim, kernel_size=3, stride=2, padding=1) + + # # output convolution; this can solve mixed memory warning, not know why + # self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = F.relu(self.norm1(self.conv1(x)), inplace=True) + x = F.relu(self.norm2(self.conv2(x)), inplace=True) + x = F.relu(self.norm3(self.conv3(x)), inplace=True) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + + + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__() + + assert d_model % n_head == 0 + + self.n_head = n_head + self.d_model = d_model + self.d_head = self.d_model // self.n_head + + self.w_qs = nn.Linear(self.d_model, self.d_model, bias=False) # TODO: enable bias + self.w_ks = nn.Linear(self.d_model, self.d_model, bias=False) + self.w_vs = nn.Linear(self.d_model, self.d_model, bias=False) + self.fc = nn.Linear(self.d_model, self.d_model) + + # self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) # TODO + + self.dropout = nn.Dropout(dropout) + # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + + def forward(self, q, k, v): + ''' + q: shape of N*len*C + ''' + d_head, n_head = self.d_head, self.n_head + sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) + + # residual = q + + # Pass through the pre-attention projection: b x lq x (n*dv) + # Separate different heads: b x lq x n x dv + q = self.w_qs(q).view(sz_b, len_q, n_head, d_head) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_head) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_head) + + # Transpose for attention dot product: b x n x lq x dv + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + attn = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(self.d_head) + attn = self.dropout(F.softmax(attn, dim=-1)) + q_updated = torch.matmul(attn, v) + + # Transpose to move the head dimension back: b x lq x n x dv + # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) + q_updated = q_updated.transpose(1, 2).contiguous().view(sz_b, len_q, -1) + q_updated = self.dropout(self.fc(q_updated)) + # q_updated += residual + + # q_updated = self.layer_norm(q_updated) + + return q_updated, attn + + + +class AnchorEncoderBlock(nn.Module): + + def __init__(self, anchor_dist, d_model, num_heads, d_ff, dropout=0.): + super().__init__() + + self.anchor_dist = anchor_dist + self.half_anchor_dist = anchor_dist // 2 + + self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) + self.dropout = nn.Dropout(dropout) + self.layer_norm_1 = nn.LayerNorm(d_model) + + self.FFN = nn.Sequential( + nn.Linear(d_model, d_ff), + nn.ReLU(), + nn.Linear(d_ff, d_model), + ) + + self.layer_norm_2 = nn.LayerNorm(d_model) + + + def forward(self, inputs): + ''' + inputs: batches with N*C*H*W + ''' + N, C, H, W = inputs.shape + + x = inputs + anchors = inputs[:,:, self.half_anchor_dist::self.anchor_dist, + self.half_anchor_dist::self.anchor_dist].clone() + + # flatten feature maps + x = x.reshape(N, C, H*W).transpose(-1,-2) + anchors = anchors.reshape(N, C, anchors.shape[2]* anchors.shape[3]).transpose(-1,-2) + + # two-stage multi-head self-attention + anchors_new = self.dropout(self.selfAttn(anchors, x, x)[0]) + residual = self.dropout(self.selfAttn(x, anchors_new, anchors_new)[0]) + + norm_1 = self.layer_norm_1(x + residual) + x_linear = self.dropout(self.FFN(norm_1)) + x_new = self.layer_norm_2(norm_1 + x_linear) + + outputs = x_new.transpose(-1,-2).reshape(N, C, H, W) + return outputs + + + +class EncoderBlock(nn.Module): + + def __init__(self, d_model, num_heads, d_ff, dropout=0.): + super().__init__() + + self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) + self.dropout = nn.Dropout(dropout) + self.layer_norm_1 = nn.LayerNorm(d_model) + + self.FFN = nn.Sequential( + nn.Linear(d_model, d_ff), + nn.ReLU(), + nn.Linear(d_ff, d_model), + ) + + self.layer_norm_2 = nn.LayerNorm(d_model) + + + def forward(self, x): + ''' + x: input batches with N*C*H*W + ''' + N, C, H, W = x.shape + + # update x + x = x.reshape(N, C, H*W).transpose(-1,-2) + + residual = self.dropout(self.selfAttn(x, x, x)[0]) + norm_1 = self.layer_norm_1(x + residual) + x_linear = self.dropout(self.FFN(norm_1)) + x_new = self.layer_norm_2(norm_1 + x_linear) + + outputs = x_new.transpose(-1,-2).reshape(N, C, H, W) + return outputs + + + +class ReduceEncoderBlock(nn.Module): + + def __init__(self, d_model, num_heads, d_ff, dropout=0.): + super().__init__() + + self.reduce = nn.Sequential( + nn.Conv2d(d_model, d_model, 2, 2), + nn.Conv2d(d_model, d_model, 2, 2) + ) + # self.reduce = nn.Sequential( + # nn.AvgPool2d(16, 16) + # ) + + self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) + self.dropout = nn.Dropout(dropout) + self.layer_norm_1 = nn.LayerNorm(d_model) + + self.FFN = nn.Sequential( + nn.Linear(d_model, d_ff), + nn.ReLU(), + nn.Linear(d_ff, d_model) + ) + + self.layer_norm_2 = nn.LayerNorm(d_model) + + + def forward(self, x): + ''' + x: input batches with N*C*H*W + ''' + N, C, H, W = x.shape + x_reduced = self.reduce(x) + + # update x + x = x.reshape(N, C, H*W).transpose(-1,-2) + x_reduced = x_reduced.reshape(N, C, -1).transpose(-1,-2) + + # print('x ', x.shape) + # print('x_reduced ', x_reduced.shape) + # exit() + + residual = self.dropout(self.selfAttn(x, x_reduced, x_reduced)[0]) + + norm_1 = self.layer_norm_1(x + residual) + x_linear = self.dropout(self.FFN(norm_1)) + x_new = self.layer_norm_2(norm_1 + x_linear) + + outputs = x_new.transpose(-1,-2).reshape(N, C, H, W) + return outputs + + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class LayerEncoderBlock(nn.Module): + + def __init__(self, win_size, d_model, num_heads, d_ff, dropout=0.): + super().__init__() + + self.win_size = win_size + self.down_factor = 4 + self.unfold_stride = int(self.win_size//self.down_factor) + + self.stride_list = [math.floor(win_size/self.down_factor**idx) for idx in range(8) if win_size/self.down_factor**idx >= 1] + # [16, 4, 1] + + self.reduce = nn.Sequential( + nn.AvgPool2d(self.down_factor, self.down_factor) + ) + + self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) + self.crossAttn = MultiHeadAttention(d_model, num_heads, dropout) + + self.dropout = nn.Dropout(dropout) + self.layerNormSelf = nn.LayerNorm(d_model) + self.layerNormCross = nn.LayerNorm(d_model) + + self.FFN = nn.Sequential( + nn.Linear(d_model, d_ff), + nn.ReLU(), + nn.Linear(d_ff, d_model) + ) + + self.layer_norm_out = nn.LayerNorm(d_model) + + + def Circular_pad2D(self, x, pad_right, pad_bottom): + ''' + x: (N, H, W, C) + x_pad: (N, H_pad, W_pad, C) + ''' + N, H, W, C = x.shape + + H_pad = H + pad_bottom + W_pad = W + pad_right + + H_repeat = math.ceil(H_pad/H) + W_repeat = math.ceil(W_pad/W) + x_repeat = x.repeat(1, H_repeat, W_repeat, 1) + + x_pad = x_repeat[:, :H_pad, :W_pad, :] + return x_pad + + + def pad_fit_win(self, x, win_size): + N, H, W, C = x.shape + + W_ = math.ceil(W/win_size)*win_size + H_ = math.ceil(H/win_size)*win_size + padRight = W_ - W + padBottom = H_ - H + + x_pad = self.Circular_pad2D(x, padRight, padBottom) # N*H_*W_*C + return x_pad + + + def self_attention(self, x): + ''' + x: (N, H, W, C) + out: (N, H, W, C) + ''' + N, H, W, C = x.shape + x_pad = self.pad_fit_win(x, self.win_size) # N*H_*W_*C + _, H_, W_, _ = x_pad.shape + + # x_pad = F.pad(x.permute(xxx), (0, padRight, 0, padBottom), mode='reflect') # N*C*H_*W_ + + x_window = window_partition(x_pad, self.win_size) # (num_win*B, win_size, win_size, C) + x_window = x_window.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C) + + # self-attention + residual = self.dropout(self.selfAttn(x_window, x_window, x_window)[0]) + residual = residual.view(-1, self.win_size, self.win_size, C) + residual = window_reverse(residual, self.win_size, H_, W_) # (N, H_, W_, C) + + out = x_pad + residual + out = out[:, :H, :W, :] + return out + + + def cross_attention(self, query, keyVal): + ''' + query: (N, qH, qW, C) + keyVal: (N, kH, kW, C) + out: (N, qH, qW, C) + ''' + _, qH, qW, C = query.shape + _, kH, kW, C = keyVal.shape + + # print('in query ', query.shape) + # print('in keyVal ', keyVal.shape) + # print('-') + + query = self.pad_fit_win(query, self.win_size) # N*H_*W_*C + _, qH_, qW_, C = query.shape + + query_win = window_partition(query, self.win_size) + query_win = query_win.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C) + + # pad and unfold keyVal + kW_ = (math.ceil(kW/self.unfold_stride) - 1)*self.unfold_stride + self.win_size + kH_ = (math.ceil(kH/self.unfold_stride) - 1)*self.unfold_stride + self.win_size + padRight = kW_ - kW + padBottom = kH_ - kH + + keyVal_pad = self.Circular_pad2D(keyVal, padRight, padBottom) + keyVal = F.unfold(keyVal_pad.permute(0, 3, 1, 2), self.win_size, stride=self.unfold_stride) # (N, C*win_size*win_size, num_win) + keyVal = keyVal.permute(0,2,1).reshape(-1, C, self.win_size*self.win_size).permute(0,2,1) # (num_win*B, win_size*win_size, C) + + # print('win query ', query_win.shape) + # print('win keyVal ', keyVal.shape) + # print('-') + + residual = self.dropout(self.crossAttn(query_win, keyVal, keyVal)[0]) + residual = residual.view(-1, self.win_size, self.win_size, C) + residual = window_reverse(residual, self.win_size, qH_, qW_) # (N, H, W, C) + + out = query + residual + out = out[:, :qH, :qW, :] + return out + + + def forward(self, x): + ''' + x: input batches with N*C*H*W + ''' + N, C, H, W = x.shape + x = x.permute(0, 2, 3, 1) # N*H*W*C + x = self.pad_fit_win(x, self.win_size) # pad + + # layered self-attention + layerAttnList = [] + strideListLen = len(self.stride_list) + for idx in range(strideListLen): + x_attn = self.self_attention(x) # built-in shortcut + x_attn = self.layerNormSelf(x_attn) + layerAttnList.append(x_attn) + + if idx < strideListLen - 1: + x = self.reduce(x_attn.permute(0, 3, 1 ,2)) # N*C*H*W + x = x.permute(0, 2, 3, 1) # N*H*W*C + + # layered cross-attention + KeyVal = layerAttnList[-1] + for idx in range(strideListLen-1, 0, -1): + Query = layerAttnList[idx-1] + Query = self.cross_attention(Query, KeyVal) # built-in shortcut + Query = self.layerNormCross(Query) + + KeyVal = Query + + Query = Query[:, :H, :W, :] # unpad + + q_residual = self.dropout(self.FFN(Query)) + x_new = self.layer_norm_out(Query + q_residual) + + outputs = x_new.permute(0, 3, 1, 2) + return outputs + + + +class BasicLayerEncoderBlock(nn.Module): + + def __init__(self, win_size, d_model, num_heads, d_ff, dropout=0.): + super().__init__() + + self.win_size = win_size + self.down_factor = 2 + self.unfold_stride = int(self.win_size//self.down_factor) + + self.stride_list = [math.floor(win_size/self.down_factor**idx) for idx in range(8) if win_size/self.down_factor**idx >= 1] + # [16, 8, 4, 2, 1] + + self.reduce = nn.Sequential( + nn.AvgPool2d(self.down_factor, self.down_factor) + ) + + self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) + self.crossAttn = MultiHeadAttention(d_model, num_heads, dropout) + + self.dropout = nn.Dropout(dropout) + self.layerNormSelf = nn.LayerNorm(d_model) + self.layerNormCross = nn.LayerNorm(d_model) + + self.FFN = nn.Sequential( + nn.Linear(d_model, d_ff), + nn.ReLU(), + nn.Linear(d_ff, d_model) + ) + + self.layer_norm_out = nn.LayerNorm(d_model) + + + def Circular_pad2D(self, x, pad_right, pad_bottom): + ''' + x: (N, H, W, C) + x_pad: (N, H_pad, W_pad, C) + ''' + N, H, W, C = x.shape + + H_pad = H + pad_bottom + W_pad = W + pad_right + + H_repeat = math.ceil(H_pad/H) + W_repeat = math.ceil(W_pad/W) + x_repeat = x.repeat(1, H_repeat, W_repeat, 1) + + x_pad = x_repeat[:, :H_pad, :W_pad, :] + return x_pad + + + def pad_fit_win(self, x, win_size): + N, H, W, C = x.shape + + W_ = math.ceil(W/win_size)*win_size + H_ = math.ceil(H/win_size)*win_size + padRight = W_ - W + padBottom = H_ - H + + x_pad = self.Circular_pad2D(x, padRight, padBottom) # N*H_*W_*C + return x_pad + + + def self_attention(self, x): + ''' + x: (N, H, W, C) + out: (N, H, W, C) + ''' + N, H, W, C = x.shape + x_pad = self.pad_fit_win(x, self.win_size) # N*H_*W_*C + _, H_, W_, _ = x_pad.shape + + # x_pad = F.pad(x.permute(xxx), (0, padRight, 0, padBottom), mode='reflect') # N*C*H_*W_ + + x_window = window_partition(x_pad, self.win_size) # (num_win*B, win_size, win_size, C) + x_window = x_window.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C) + + # self-attention + residual = self.dropout(self.selfAttn(x_window, x_window, x_window)[0]) + residual = residual.view(-1, self.win_size, self.win_size, C) + residual = window_reverse(residual, self.win_size, H_, W_) # (N, H_, W_, C) + + out = x_pad + residual + out = out[:, :H, :W, :] + return out + + + def cross_attention(self, query, keyVal, query_win_size): + ''' + query: (N, qH, qW, C) + keyVal: (N, kH, kW, C) + out: (N, qH, qW, C) + ''' + _, qH, qW, C = query.shape + + query_win = window_partition(query, query_win_size) + query_win = query_win.view(-1, query_win_size*query_win_size, C) # (num_win*B, win_size*win_size, C) + + keyWinSize = query_win_size // 2 + keyVal_win = window_partition(keyVal, keyWinSize) + keyVal_win = keyVal_win.view(-1, keyWinSize*keyWinSize, C) # (num_win*B, win_size*win_size, C) + + residual = self.dropout(self.crossAttn(query_win, keyVal_win, keyVal_win)[0]) + residual = residual.view(-1, query_win_size, query_win_size, C) + residual = window_reverse(residual, query_win_size, qH, qW) # (N, H, W, C) + + out = query + residual + return out + + + def forward(self, x): + ''' + x: input batches with N*C*H*W + ''' + N, C, H, W = x.shape + x = x.permute(0, 2, 3, 1) # N*H*W*C + x = self.pad_fit_win(x, self.win_size) # pad + + # layered self-attention + layerAttnList = [] + strideListLen = len(self.stride_list) + for idx in range(strideListLen): + x_attn = self.self_attention(x) # built-in shortcut + x_attn = self.layerNormSelf(x_attn) + layerAttnList.append(x_attn) + + if idx < strideListLen - 1: + x = self.reduce(x_attn.permute(0, 3, 1 ,2)) # N*C*H*W + x = x.permute(0, 2, 3, 1) # N*H*W*C + + # layered cross-attention + KeyVal = layerAttnList[-1] + for idx in range(strideListLen-1, 0, -1): + Query = layerAttnList[idx-1] + QueryWinSize = self.stride_list[idx-1] + + Query = self.cross_attention(Query, KeyVal, QueryWinSize) # built-in shortcut + Query = self.layerNormCross(Query) + + KeyVal = Query + + Query = Query[:, :H, :W, :] # unpad + + q_residual = self.dropout(self.FFN(Query)) + x_new = self.layer_norm_out(Query + q_residual) + + outputs = x_new.permute(0, 3, 1, 2) + return outputs + + + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.): + super().__init__() + + self.max_len = 256 + self.d_model = d_model + + self._update_PE_table(self.max_len, self.d_model//2) + + + def _update_PE_table(self, max_len, d_model): + self.PE_table = torch.zeros(max_len, d_model) + + pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + denominator = torch.pow(10000, torch.arange(0, d_model, 2).float()/d_model) + + self.PE_table[:, 0::2] = torch.sin(pos/denominator) + self.PE_table[:, 1::2] = torch.cos(pos/denominator) + + + def forward(self, x): + ''' x: image batches with N*C*H*W ''' + + N, C, H, W = x.shape + max_hw = max(H, W) + + if max_hw > self.max_len or self.d_model != C: + self.max_len = max_hw + self.d_model = C + + self._update_PE_table(self.max_len, self.d_model//2) + + if self.PE_table.device != x.device: + self.PE_table = self.PE_table.to(x.device) + + h_pos_emb = self.PE_table[:H, :].unsqueeze(1).repeat(1, W, 1) # H*W*C/2 + w_pos_emb = self.PE_table[:W, :].unsqueeze(0).repeat(H, 1, 1) # H*W*C/2 + pos_emb = torch.cat([h_pos_emb, w_pos_emb], dim=-1 + ).permute([2,0,1]).unsqueeze(0).repeat(N,1,1,1) # N*C*H*W + + output = x + pos_emb + return output + + + +class TransformerEncoder(nn.Module): + + def __init__(self, anchor_dist, num_blocks, d_model, num_heads, d_ff, dropout=0.): + super().__init__() + + self.anchor_dist = anchor_dist + + blocks_list = [] + for idx in range(num_blocks): + # blocks_list.append( AnchorEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) ) + # blocks_list.append( EncoderBlock(d_model, num_heads, d_ff, dropout) ) + blocks_list.append( ReduceEncoderBlock(d_model, num_heads, d_ff, dropout) ) + # blocks_list.append( BasicLayerEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) ) + + self.blocks = nn.Sequential(*blocks_list) + + self.posEmbedding = PositionalEncoding(d_model, dropout) + + + def forward(self, x): + x_w_pos = self.posEmbedding(x) + x_updated = self.blocks(x_w_pos) + + return x_updated + + + +class RawInputTransEncoder(nn.Module): + + def __init__(self, anchor_dist, num_blocks, d_model, num_heads, d_ff, dropout=0.): + super().__init__() + + self.anchor_dist = anchor_dist + + self.linear = nn.Conv2d(3, d_model, 8, 8) + + blocks_list = [] + for idx in range(num_blocks): + # blocks_list.append( AnchorEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) ) + # blocks_list.append( EncoderBlock(d_model, num_heads, d_ff, dropout) ) + # blocks_list.append( ReduceEncoderBlock(d_model, num_heads, d_ff, dropout) ) + blocks_list.append( LayerEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) ) + + self.blocks = nn.Sequential(*blocks_list) + + self.posEmbedding = PositionalEncoding(d_model, dropout) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm, nn.LayerNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.linear(x) + x_w_pos = self.posEmbedding(x) + x_updated = self.blocks(x_w_pos) + + if is_list: + x_updated = torch.split(x_updated, [batch_dim, batch_dim], dim=0) + + return x_updated + + + +class GlobalLocalBlock(nn.Module): + + def __init__(self, anchor_dist, d_model, num_heads, out_dim, dropout=0., stride=1): + super().__init__() + + self.anchor_dist = anchor_dist + self.half_anchor_dist = anchor_dist // 2 + self.d_model = d_model + self.out_dim = out_dim + + self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) + self.dropout = nn.Dropout(dropout) + self.layer_norm_1 = nn.LayerNorm(d_model) + + self.resBlock_1 = ResidualBlock(d_model, d_model, norm_fn='instance', stride=stride) + self.change_channel = nn.Linear(d_model, out_dim) + self.resBlock_2 = ResidualBlock(out_dim, out_dim, norm_fn='instance', stride=1) + + self.posEmbedding = PositionalEncoding(d_model, dropout) + + + def forward(self, inputs): + ''' + inputs: batches with N*H*W*C + ''' + + # local update 1 + x = self.resBlock_1(inputs) + x = self.posEmbedding(x) + anchors = x[:,:, self.half_anchor_dist::self.anchor_dist, + self.half_anchor_dist::self.anchor_dist].clone() + + # flatten feature maps + N, C, H, W = x.shape + x = x.reshape(N, C, H*W).transpose(-1,-2) + anchors = anchors.reshape(N, C, anchors.shape[2]* anchors.shape[3]).transpose(-1,-2) + + # gloabl update with two-stage multi-head self-attention + anchors_new = self.dropout(self.selfAttn(anchors, x, x)[0]) + residual = self.dropout(self.selfAttn(x, anchors_new, anchors_new)[0]) + norm_1 = self.layer_norm_1(x + residual) + + # local update 2 + norm_1 = self.change_channel(norm_1) + norm_1 = norm_1.transpose(-1,-2).reshape(N, self.out_dim, H, W) + outputs = self.resBlock_2(norm_1) + + return outputs + + + +class GlobalLocalEncoder(nn.Module): + + def __init__(self, anchor_dist, output_dim, dropout=0.): + super().__init__() + + self.anchor_dist = anchor_dist + self.output_dim = output_dim + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.norm1 = nn.InstanceNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + + self.layer1 = GlobalLocalBlock(self.anchor_dist, 64, 2, 96, dropout, stride=2) + self.layer2 = GlobalLocalBlock(self.anchor_dist, 96, 3, 96, dropout, stride=1) + self.layer3 = GlobalLocalBlock(self.anchor_dist//2, 96, 4, 128, dropout, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm, nn.LayerNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + # def _make_layer(self, in_dim, out_dim, dropout=0., stride=1): + # layer1 = GlobalLocalBlock(self.anchor_dist, in_dim, in_dim//32, out_dim, dropout=0., stride=stride) + # layer2 = GlobalLocalBlock(self.anchor_dist, out_dim, out_dim//32, out_dim, dropout=0., stride=1) + # layers = (layer1, layer2) + + # return nn.Sequential(*layers) + + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.relu1(self.norm1(self.conv1(x))) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x \ No newline at end of file diff --git a/ptlflow/models/gmflownet/gma.py b/ptlflow/models/gmflownet/gma.py new file mode 100644 index 0000000..c1c8449 --- /dev/null +++ b/ptlflow/models/gmflownet/gma.py @@ -0,0 +1,123 @@ +import torch +from torch import nn, einsum +from einops import rearrange + + +class RelPosEmb(nn.Module): + def __init__( + self, + max_pos_size, + dim_head + ): + super().__init__() + self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) + self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) + + deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) + rel_ind = deltas + max_pos_size - 1 + self.register_buffer('rel_ind', rel_ind) + + def forward(self, q): + batch, heads, h, w, c = q.shape + height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) + width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) + + height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) + width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) + + height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) + width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) + + return height_score + width_score + + +class Attention(nn.Module): + def __init__( + self, + *, + args, + dim, + max_pos_size = 100, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) + + self.pos_emb = RelPosEmb(max_pos_size, dim_head) + + def forward(self, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + q, k = self.to_qk(fmap).chunk(2, dim=1) + + q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) + q = self.scale * q + + if self.args.position_only: + sim = self.pos_emb(q) + + elif self.args.position_and_content: + sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + sim_pos = self.pos_emb(q) + sim = sim_content + sim_pos + + else: + sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + + sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') + attn = sim.softmax(dim=-1) + + return attn + + +class Aggregate(nn.Module): + def __init__( + self, + args, + dim, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) + + self.gamma = nn.Parameter(torch.zeros(1)) + + if dim != inner_dim: + self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) + else: + self.project = None + + def forward(self, attn, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + v = self.to_v(fmap) + v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) + + if self.project is not None: + out = self.project(out) + + out = fmap + self.gamma * out + + return out + + +if __name__ == "__main__": + att = Attention(dim=128, heads=1) + fmap = torch.randn(2, 128, 40, 90) + out = att(fmap) + + print(out.shape) diff --git a/ptlflow/models/gmflownet/gmflownet.py b/ptlflow/models/gmflownet/gmflownet.py new file mode 100644 index 0000000..81f2a1f --- /dev/null +++ b/ptlflow/models/gmflownet/gmflownet.py @@ -0,0 +1,254 @@ +from argparse import ArgumentParser, Namespace +import configparser +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock +from .extractor import BasicEncoder, BasicConvEncoder +from .corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 +from .swin_transformer import POLAUpdate, MixAxialPOLAUpdate +from .loss import compute_supervision_coarse, compute_coarse_loss, backwarp +from ..base_model.base_model import BaseModel + + +class SequenceLoss(nn.Module): + def __init__(self, args): + super().__init__() + self.gamma = args.gamma + self.max_flow = args.max_flow + self.use_matching_loss = args.use_matching_loss + + def forward(self, outputs, inputs): + """ Loss function defined over sequence of flow predictions """ + + flow_preds = outputs['flow_preds'] + soft_corr_map = outputs['soft_corr_map'] + image1 = inputs['images'][:, 0] + image2 = inputs['images'][:, 1] + flow_gt = inputs['flows'][:, 0] + valid = inputs['valids'][:, 0] + + # original RAFT loss + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exclude invalid pixels and extremely large displacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < self.max_flow) + + for i in range(n_predictions): + i_weight = self.gamma**(n_predictions - i - 1) + i_loss = (flow_preds[i] - flow_gt).abs() + flow_loss += i_weight * (valid[:, None].float() * i_loss).mean() + + if self.use_matching_loss: + # enable global matching loss. Try to use it in late stages of the trianing + img_2back1 = backwarp(image2, flow_gt) + occlusionMap = (image1 - img_2back1).mean(1, keepdims=True) #(N, H, W) + occlusionMap = torch.abs(occlusionMap) > 20 + occlusionMap = occlusionMap.float() + + conf_matrix_gt = compute_supervision_coarse(flow_gt, occlusionMap, 8) # 8 from RAFT downsample + + matchLossCfg = configparser.ConfigParser() + matchLossCfg.POS_WEIGHT = 1 + matchLossCfg.NEG_WEIGHT = 1 + matchLossCfg.FOCAL_ALPHA = 0.25 + matchLossCfg.FOCAL_GAMMA = 2.0 + matchLossCfg.COARSE_TYPE = 'cross_entropy' + match_loss = compute_coarse_loss(soft_corr_map, conf_matrix_gt, matchLossCfg) + + flow_loss = flow_loss + 0.01*match_loss + + return flow_loss + + +class GMFlowNet(BaseModel): + pretrained_checkpoints = { + 'things': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflownet-things-9f061ac7.ckpt', + 'kitti': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflownet-kitti-712b4660.ckpt' + } + + def __init__(self, + args: Namespace) -> None: + super().__init__( + args=args, + loss_fn=SequenceLoss(args), + output_stride=8) + + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + + if not hasattr(self.args, 'dropout'): + self.args.dropout = 0 + + # feature network, context network, and update block + if self.args.use_mix_attn: + self.fnet = nn.Sequential( + BasicConvEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout), + MixAxialPOLAUpdate(embed_dim=256, depth=6, num_head=8, window_size=7) + ) + else: + self.fnet = nn.Sequential( + BasicConvEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout), + POLAUpdate(embed_dim=256, depth=6, num_head=8, window_size=7, neig_win_num=1) + ) + + self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim) + + @staticmethod + def add_model_specific_args(parent_parser=None): + parent_parser = BaseModel.add_model_specific_args(parent_parser) + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--corr_levels', type=int, default=4) + parser.add_argument('--corr_radius', type=int, default=4) + parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--gamma', type=float, default=0.8) + parser.add_argument('--max_flow', type=float, default=400.0) + parser.add_argument('--iters', type=int, default=12) + parser.add_argument('--use_matching_loss', action='store_true') + parser.add_argument('--use_mix_attn', action='store_true') + return parser + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H//8, W//8).to(img.device) + coords1 = coords_grid(N, H//8, W//8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + def forward(self, inputs, flow_init=None): + """ Estimate optical flow between pair of frames """ + inputs['images'] = 2 * inputs['images'] - 1.0 + image1 = inputs['images'][:, 0] + image2 = inputs['images'][:, 1] + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + fmap1, fmap2 = self.fnet([image1, image2]) + fmap1 = fmap1.float() + fmap2 = fmap2.float() + + # # Self-attention update + # fmap1 = self.transEncoder(fmap1) + # fmap2 = self.transEncoder(fmap2) + + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + # Correlation as initialization + N, fC, fH, fW = fmap1.shape + corrMap = corr_fn.corrMap + + #_, coords_index = torch.max(corrMap, dim=-1) # no gradient here + softCorrMap = F.softmax(corrMap, dim=2) * F.softmax(corrMap, dim=1) # (N, fH*fW, fH*fW) + + if flow_init is not None: + coords1 = coords1 + flow_init + else: + # print('matching as init') + # mutual match selection + match12, match_idx12 = softCorrMap.max(dim=2) # (N, fH*fW) + match21, match_idx21 = softCorrMap.max(dim=1) + + for b_idx in range(N): + match21_b = match21[b_idx,:] + match_idx12_b = match_idx12[b_idx,:] + match21[b_idx,:] = match21_b[match_idx12_b] + + matched = (match12 - match21) == 0 # (N, fH*fW) + coords_index = torch.arange(fH*fW).unsqueeze(0).repeat(N,1).to(softCorrMap.device) + coords_index[matched] = match_idx12[matched] + + # matched coords + coords_index = coords_index.reshape(N, fH, fW) + coords_x = coords_index % fW + coords_y = coords_index // fW + + coords_xy = torch.stack([coords_x, coords_y], dim=1).float() + coords1 = coords_xy + + # Iterative update + flow_predictions = [] + for itr in range(self.args.iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if self.training: + outputs = { + 'flows': flow_up[:, None], + 'flow_preds': flow_predictions, + 'soft_corr_map': softCorrMap + } + else: + outputs = { + 'flows': flow_up[:, None], + 'flow_small': coords1 - coords0 + } + + return outputs + + +class GMFlowNetMix(GMFlowNet): + pretrained_checkpoints = { + 'things': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflownet_mix-things-8396f0a1.ckpt', + 'sintel': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/gmflownet_mix-sintel-33492618.ckpt', + } + + def __init__(self, + args: Namespace) -> None: + args.use_mix_attn = True + super().__init__( + args=args) diff --git a/ptlflow/models/gmflownet/loss.py b/ptlflow/models/gmflownet/loss.py new file mode 100644 index 0000000..973bedf --- /dev/null +++ b/ptlflow/models/gmflownet/loss.py @@ -0,0 +1,149 @@ +import torch +import numpy as np + + +from typing import Optional +def create_meshgrid( + height: int, + width: int, + normalized_coordinates: bool = True, + device: Optional[torch.device] = torch.device('cpu'), + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Generate a coordinate grid for an image. + When the flag ``normalized_coordinates`` is set to True, the grid is + normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch + function :py:func:`torch.nn.functional.grid_sample`. + Args: + height: the image height (rows). + width: the image width (cols). + normalized_coordinates: whether to normalize + coordinates in the range :math:`[-1,1]` in order to be consistent with the + PyTorch function :py:func:`torch.nn.functional.grid_sample`. + device: the device on which the grid will be generated. + dtype: the data type of the generated grid. + Return: + grid tensor with shape :math:`(1, H, W, 2)`. + Example: + >>> create_meshgrid(2, 2) + tensor([[[[-1., -1.], + [ 1., -1.]], + + [[-1., 1.], + [ 1., 1.]]]]) + >>> create_meshgrid(2, 2, normalized_coordinates=False) + tensor([[[[0., 0.], + [1., 0.]], + + [[0., 1.], + [1., 1.]]]]) + """ + xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype) + ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype) + # Fix TracerWarning + # Note: normalize_pixel_coordinates still gots TracerWarning since new width and height + # tensors will be generated. + # Below is the code using normalize_pixel_coordinates: + # base_grid: torch.Tensor = torch.stack(torch.meshgrid([xs, ys]), dim=2) + # if normalized_coordinates: + # base_grid = K.geometry.normalize_pixel_coordinates(base_grid, height, width) + # return torch.unsqueeze(base_grid.transpose(0, 1), dim=0) + if normalized_coordinates: + xs = (xs / (width - 1) - 0.5) * 2 + ys = (ys / (height - 1) - 0.5) * 2 + # generate grid by stacking coordinates + base_grid: torch.Tensor = torch.stack(torch.meshgrid([xs, ys])).transpose(1, 2) # 2xHxW + return torch.unsqueeze(base_grid, dim=0).permute(0, 2, 3, 1) # 1xHxWx2 + + +def backwarp(img, flow): + _, _, H, W = img.size() + + u = flow[:, 0, :, :] + v = flow[:, 1, :, :] + + gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) + gridX = torch.tensor(gridX, requires_grad=False,).to(flow.device) + gridY = torch.tensor(gridY, requires_grad=False,).to(flow.device) + x = gridX.unsqueeze(0).expand_as(u).float() + u + y = gridY.unsqueeze(0).expand_as(v).float() + v + # range -1 to 1 + x = 2*(x/W - 0.5) + y = 2*(y/H - 0.5) + # stacking X and Y + grid = torch.stack((x,y), dim=3) + # Sample pixels using bilinear interpolation. + imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=False) + + return imgOut + + +@torch.no_grad() +def compute_supervision_coarse(flow, occlusions, scale: int): + N, _, H, W = flow.shape + Hc, Wc = int(np.ceil(H / scale)), int(np.ceil(W / scale)) + + occlusions_c = occlusions[:, :, ::scale, ::scale] + flow_c = flow[:, :, ::scale, ::scale] / scale + occlusions_c = occlusions_c.reshape(N, Hc * Wc) + + grid_c = create_meshgrid(Hc, Wc, False, device=flow.device).reshape(1, Hc * Wc, 2).repeat(N, 1, 1) + warp_c = grid_c + flow_c.permute(0, 2, 3, 1).reshape(N, Hc * Wc, 2) + warp_c = warp_c.round().long() + + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + + occlusions_c[out_bound_mask(warp_c, Wc, Hc)] = 1 + warp_c = warp_c[..., 0] + warp_c[..., 1] * Wc + + b_ids, i_ids = torch.split(torch.nonzero(occlusions_c == 0), 1, dim=1) + conf_matrix_gt = torch.zeros(N, Hc * Wc, Hc * Wc, device=flow.device) + j_ids = warp_c[b_ids, i_ids] + conf_matrix_gt[b_ids, i_ids, j_ids] = 1 + + return conf_matrix_gt + + +def compute_coarse_loss(conf, conf_gt, cfg): + c_pos_w, c_neg_w = cfg.POS_WEIGHT, cfg.NEG_WEIGHT + pos_mask, neg_mask = conf_gt == 1, conf_gt == 0 + + if cfg.COARSE_TYPE == 'cross_entropy': + conf = torch.clamp(conf, 1e-6, 1 - 1e-6) + loss_pos = -torch.log(conf[pos_mask]) + loss_neg = -torch.log(1 - conf[neg_mask]) + + return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() + elif cfg.COARSE_TYPE == 'focal': + conf = torch.clamp(conf, 1e-6, 1 - 1e-6) + alpha = cfg.FOCAL_ALPHA + gamma = cfg.FOCAL_GAMMA + loss_pos = -alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log() + loss_neg = -alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log() + return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() + else: + raise ValueError('Unknown coarse loss: {type}'.format(type=cfg.COARSE_TYPE)) + + +def compute_fine_loss(kflow, kflow_gt, cfg): + fine_correct_thr = cfg.WINDOW_SIZE // 2 * 2 + error = (kflow - kflow_gt).abs() + correct = torch.max(error, dim=1)[0] < fine_correct_thr + rate = torch.sum(correct).float() / correct.shape[0] + num = correct.shape[0] + return error[correct].mean(), rate.item(), num + + +def compute_flow_loss(flow, flow_gt): + loss = (flow - flow_gt).abs().mean() + epe = torch.sum((flow - flow_gt)**2, dim=1).sqrt() + + metrics = { + 'epe': epe.mean().item(), + '1px': (epe < 1).float().mean().item(), + '3px': (epe < 3).float().mean().item(), + '5px': (epe < 5).float().mean().item(), + } + + return loss, metrics diff --git a/ptlflow/models/gmflownet/swin_transformer.py b/ptlflow/models/gmflownet/swin_transformer.py new file mode 100644 index 0000000..e9269cc --- /dev/null +++ b/ptlflow/models/gmflownet/swin_transformer.py @@ -0,0 +1,1467 @@ +# -------------------------------------------------------- +# This script is modified from the following source by Shiyu Zhao +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from .utils import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + use_shift_win = True, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0 if (i % 2 == 0) else window_size // 2) \ + if use_shift_win else 0, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransEncoder(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + use_shift_win=True, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + use_shift_win=use_shift_win, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.apply(self._init_weights) + self._freeze_stages() + + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + + def forward(self, x): + """Forward function.""" + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + x_out = outs[-1] + if is_list: + x_out = torch.split(x_out, [batch_dim, batch_dim], dim=0) + + return x_out + + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransEncoder, self).train(mode) + self._freeze_stages() + + +# -------------------------------------------------------- +# Backbones for GMFlowNet +# -------------------------------------------------------- +class NeighborWindowAttention(nn.Module): + """ Patch-based OverLapping multi-head self-Attention (POLA) module with relative position bias. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window (or patch). + num_heads (int): Number of attention heads. + neig_win_num (int): Number of neighbor windows. Default: 1 + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, neig_win_num=1, + qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., use_proj=True): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.use_proj = use_proj + + # define a parameter table of relative position bias + self.n_win = 2*neig_win_num + 1 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(((self.n_win + 1) * window_size[0] - 1) * ((self.n_win + 1) * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + + coords_h_neig = torch.arange(self.n_win * self.window_size[0]) + coords_w_neig = torch.arange(self.n_win * self.window_size[1]) + coords_neig = torch.stack(torch.meshgrid([coords_h_neig, coords_w_neig])) # 2, Wh, Ww + + coords_flat = torch.flatten(coords, 1) # 2, Wh*Ww + coords_neig_flat = torch.flatten(coords_neig, 1) # 2, (n_win*Wh)*(n_win*Ww) + relative_coords = coords_flat[:, :, None] - coords_neig_flat[:, None, :] # 2, Wh*Ww, (n_win*Wh)*(n_win*Ww) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww,(n_win*Wh)*(n_win*Ww), 2 + relative_coords[:, :, 0] += self.n_win * self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.n_win * self.window_size[1] - 1 + relative_coords[:, :, 0] *= (self.n_win + 1) * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.Wq = nn.Linear(dim, dim, bias=qkv_bias) + self.Wk = nn.Linear(dim, dim, bias=qkv_bias) + self.Wv = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + if self.use_proj: + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, q, k, v, mask=None): + """ Forward function. + Args: + q: input queries with shape of (num_windows*B, N, C) + k: input keys with shape of (num_windows*B, N, C) + v: input values with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N_q, C = q.shape + N_kv = k.shape[1] + # qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + dim_per_head = C // self.num_heads + q = self.Wq(q).reshape(B_, N_q, self.num_heads, dim_per_head).permute(0, 2, 1, 3) + k = self.Wk(k).reshape(B_, N_kv, self.num_heads, dim_per_head).permute(0, 2, 1, 3) + v = self.Wv(v).reshape(B_, N_kv, self.num_heads, dim_per_head).permute(0, 2, 1, 3) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.n_win*self.window_size[0] * self.n_win*self.window_size[1], -1) # Wh*Ww,(n_win*Wh)*(n_win*Ww),nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, (n_win*Wh)*(n_win*Ww) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N_q, N_kv) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N_q, N_kv) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N_q, C) + if self.use_proj: + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MultiHeadAttention(nn.Module): + """ MultiHeadAttention modified from SwinTransformer + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, + attn_drop=0., proj_drop=0., use_proj=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.use_proj = use_proj + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.Wq = nn.Linear(dim, dim, bias=qkv_bias) + self.Wk = nn.Linear(dim, dim, bias=qkv_bias) + self.Wv = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + if self.use_proj: + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, q, k, v, mask=None): + """ Forward function. + Args: + q: input queries with shape of (B, Nq, C) + k: input keys with shape of (B, Nk, C) + v: input values with shape of (B, Nk, C) + mask: (0/-inf) mask with shape of (Nq, Nk) or None + """ + B, N_q, C = q.shape + N_kv = k.shape[1] + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + dim_per_head = C // self.num_heads + q = self.Wq(q).reshape(B, N_q, self.num_heads, dim_per_head).permute(0, 2, 1, 3) + k = self.Wk(k).reshape(B, N_kv, self.num_heads, dim_per_head).permute(0, 2, 1, 3) + v = self.Wv(v).reshape(B, N_kv, self.num_heads, dim_per_head).permute(0, 2, 1, 3) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # B, num_heads, Nq, Nk + + if mask is not None: + attn = attn + mask.unsqueeze(0).unsqueeze(0) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N_q, C) + if self.use_proj: + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class POLATransBlock(nn.Module): + """ Transformer block with POLA. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window/patch size. + neig_win_num (int): Number of overlapped windows + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, neig_win_num=1, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.neig_win_num = neig_win_num + self.mlp_ratio = mlp_ratio + + self.n_win = 2 * neig_win_num + 1 + + self.norm1 = norm_layer(dim) + + self.attn = NeighborWindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, neig_win_num=neig_win_num, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + + def forward(self, x, H, W, attn_mask=None): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + # print('LocalTransBlock x.shape: ', x.shape) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # partition windows + x_win = window_partition(x, self.window_size) # nW*B, window_size, window_size, C + x_win = x_win.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # pad and unfold + pad_size = self.neig_win_num * self.window_size + key_val = F.pad(x, (0, 0, pad_size, pad_size, pad_size, pad_size)) # B, H'+2*1*win, W'+2*1*win, C + key_val = F.unfold(key_val.permute(0, 3, 1, 2), self.n_win*self.window_size, stride=self.window_size) + key_val = key_val.permute(0,2,1).reshape(-1, C, (self.n_win*self.window_size)**2).permute(0,2,1) # (B*num_win, (3*3)*win_size*win_size, C) + + # Local attention feature + attn_windows = self.attn(x_win, key_val, key_val, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class MixAxialPOLABlock(nn.Module): + """ Transformer block with mixture of POLA, vertical and horizontal axis self-attentions + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads=8, window_size=7, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + + self.dim_per_head = dim // self.num_heads + self.axis_head = 2 # self.num_heads // 4 + self.local_head = self.num_heads - 2 * self.axis_head + + self.local_chl = self.local_head * self.dim_per_head + self.axis_chl = self.axis_head * self.dim_per_head + + # for POLA + self.neig_win_num = 1 + self.n_win = 2 * self.neig_win_num + 1 + self.norm1 = norm_layer(dim) + self.localAttn = NeighborWindowAttention(self.local_chl, window_size=to_2tuple(self.window_size), + num_heads=self.local_head, neig_win_num=self.neig_win_num, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + # for axial attention + self.vertiAttn = MultiHeadAttention(self.axis_chl, num_heads=self.axis_head, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, use_proj=False) + + self.horizAttn = MultiHeadAttention(self.axis_chl, num_heads=self.axis_head, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, use_proj=False) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + + + def forward(self, x, H, W, attn_mask=None): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + # print('LocalTransBlock x.shape: ', x.shape) + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + x_local, x_horiz, x_verti = torch.split(x, [self.local_chl, self.axis_chl, self.axis_chl], dim=-1) + + # Local patch update + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x_local = F.pad(x_local, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x_local.shape + + # partition windows + x_windows = window_partition(x_local, self.window_size) # nW*B, window_size, window_size, C1 + x_windows = x_windows.view(-1, self.window_size * self.window_size, self.local_chl) # nW*B, window_size*window_size, C1 + + # pad and unfold + pad_size = self.neig_win_num * self.window_size + key_val = F.pad(x_local, (0, 0, pad_size, pad_size, pad_size, pad_size)) # B, H'+2*1*win, W'+2*1*win, C + key_val = F.unfold(key_val.permute(0, 3, 1, 2), self.n_win*self.window_size, stride=self.window_size) + key_val = key_val.permute(0,2,1).reshape(-1, self.local_chl, (self.n_win*self.window_size)**2).permute(0,2,1) # (B*num_win, (3*3)*win_size*win_size, C) + + # Local attention feature + attn_windows = self.localAttn(x_windows, key_val, key_val, mask=attn_mask) # nW*B, window_size*window_size, C1 + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.local_chl) + x_local = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C1 + + if pad_r > 0 or pad_b > 0: + x_local = x_local[:, :H, :W, :].contiguous() + + # Horizontal update + x_horiz = x_horiz.view(-1, W, self.axis_chl) # (B*H), W, C2 + x_horiz = self.horizAttn(x_horiz, x_horiz, x_horiz) + x_horiz = x_horiz.view(B, H, W, self.axis_chl) # B, H, W, C2 + + # Vertical update + x_verti = x_verti.transpose(1, 2).reshape(-1, H, self.axis_chl) # B, W, H, C3 -> (B*W), H, C3 + x_verti = self.vertiAttn(x_verti, x_verti, x_verti) + x_verti = x_verti.view(B, W, H, self.axis_chl).transpose(1, 2) # B, H, W, C3 + + x = torch.cat([x_local, x_horiz, x_verti], dim=-1) # B, H, W, C + x = x.view(B, H*W, C) + + x = self.proj(x) + x = self.proj_drop(x) # B, (H*W), C + + #FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + + +class BasicSwinUpdate(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + embed_dim (int): Number of linear projection output channels. Default: 96. + depth (int): number of Swin Transformer blocks. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + embed_dim=96, + depth=6, + num_head=3, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + use_shift_win=True, + use_checkpoint=False): + super().__init__() + + self.num_feature = embed_dim + self.window_size = window_size + self.shift_size = window_size // 2 + self.use_checkpoint = use_checkpoint + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim = self.num_feature, + num_heads = num_head, + window_size = self.window_size, + shift_size = (0 if (i % 2 == 0) else self.window_size // 2) \ + if use_shift_win else 0, + mlp_ratio = mlp_ratio, + qkv_bias = qkv_bias, + qk_scale = qk_scale, + drop = drop_rate, + attn_drop = attn_drop_rate, + drop_path = dpr[i], + norm_layer = norm_layer) + for i in range(depth) ]) + + self.norm = norm_layer(self.num_feature) + self.apply(self._init_weights) + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + + def forward(self, x): + """Forward function.""" + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B L C + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + + x = self.norm(x) + out = x.view(-1, H, W, self.num_feature).permute(0, 3, 1, 2).contiguous() + + if is_list: + out = torch.split(out, [batch_dim, batch_dim], dim=0) + + return out + + +class POLAUpdate(nn.Module): + """ POLA update for GMFlowNet. + A PyTorch impl of : `Global Matching with Overlapping Attention for Optical Flow Estimation` - + https://arxiv.org/abs/2203.11335 + Args: + embed_dim (int): Number of linear projection output channels. Default: 256. + depths (int): Number of POLA blocks. + num_heads (int): Number of attention head in each POLA block. + window_size (int): Window/patch size. Default: 7. + neig_win_num: Number of overlapped Windows/patches + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + embed_dim=256, + depth=6, + num_head=8, + window_size=7, + neig_win_num=1, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + use_checkpoint=False): + super().__init__() + + self.num_feature = embed_dim + self.num_head = num_head + self.win_size = window_size + self.neig_win_num = neig_win_num + + self.use_checkpoint = use_checkpoint + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # build blocks + self.blocks = nn.ModuleList([ + POLATransBlock( + dim=self.num_feature, + num_heads=self.num_head, + window_size=self.win_size, + neig_win_num=self.neig_win_num, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(self.num_feature) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + """Forward function.""" + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B L C + + # calculate attention mask for ConvAlikeLocalTransBlock + img_mask = torch.zeros((1, H, W, 1), device=x.device) # 1 H W 1 + + pad_r = (self.win_size - W % self.win_size) % self.win_size + pad_b = (self.win_size - H % self.win_size) % self.win_size + pad_extra = self.neig_win_num * self.win_size + img_mask = F.pad(img_mask, (0, 0, pad_extra, pad_r + pad_extra, + pad_extra, pad_b + pad_extra), + mode='constant', value=float(-100.0)) + + # unfold + n_win = 2 * self.neig_win_num + 1 + mask_windows = F.unfold(img_mask.permute(0, 3, 1, 2), n_win * self.win_size, stride=self.win_size) + mask_windows = mask_windows.permute(0, 2, 1).reshape(-1, ( + n_win * self.win_size) ** 2) # (num_win, (3*3)*win_size*win_size) + attn_mask = mask_windows.unsqueeze(1).repeat(1, self.win_size * self.win_size, 1) + + # update features + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, H, W, attn_mask) + else: + x = blk(x, H, W, attn_mask) + + x = self.norm(x) + out = x.view(-1, H, W, self.num_feature).permute(0, 3, 1, 2).contiguous() + + if is_list: + out = torch.split(out, [batch_dim, batch_dim], dim=0) + + return out + + +class MixSelfAttnUpdate(nn.Module): + """ MixSelfAttnUpdate + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + embed_dim=256, + depth=6, + num_head=8, + window_size=7, + neig_win_num=1, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + use_checkpoint=False): + super().__init__() + + self.num_feature = embed_dim + self.num_head = num_head + self.win_size = window_size + + self.use_checkpoint = use_checkpoint + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # build blocks + self.blocks = nn.ModuleList([ + MixAxialPOLABlock( + dim=self.num_feature, + num_heads=self.num_head, + window_size=self.win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer) + for i in range(depth) ]) + + self.norm = norm_layer(self.num_feature) + self.apply(self._init_weights) + + self.x_list = list() # for retbuttal + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + + def forward(self, x): + """Forward function.""" + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B L C + + # update features + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, H, W) + else: + x = blk(x, H, W) + + # for retbuttal, n,C,H,W, + self.x_list.append(x.view(-1, H, W, self.num_feature).permute(0, 3, 1, 2).contiguous()) + + x = self.norm(x) + out = x.view(-1, H, W, self.num_feature).permute(0, 3, 1, 2).contiguous() + + if is_list: + out = torch.split(out, [batch_dim, batch_dim], dim=0) + + return out + + +class MixAxialPOLAUpdate(nn.Module): + """ Mixture attention (POLA and axial attentions) update for GMFlowNet. + A PyTorch impl of : `Global Matching with Overlapping Attention for Optical Flow Estimation` - + https://arxiv.org/abs/2203.11335 + Args: + embed_dim (int): Number of linear projection output channels. Default: 96. + depth (tuple[int]): Number of mix attention blocks. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + neig_win_num (int): Number of overlapped windows for POLA + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + embed_dim=256, + depth=6, + num_head=8, + window_size=7, + neig_win_num=1, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + use_checkpoint=False): + super().__init__() + + self.num_feature = embed_dim + self.num_head = num_head + self.win_size = window_size + + self.use_checkpoint = use_checkpoint + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # build blocks + self.blocks = nn.ModuleList([ + MixAxialPOLABlock( + dim=self.num_feature, + num_heads=self.num_head, + window_size=self.win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer) + for i in range(depth) ]) + + self.norm = norm_layer(self.num_feature) + self.apply(self._init_weights) + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + + def forward(self, x): + """Forward function.""" + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B L C + + # update features + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, H, W) + else: + x = blk(x, H, W) + + x = self.norm(x) + out = x.view(-1, H, W, self.num_feature).permute(0, 3, 1, 2).contiguous() + + if is_list: + out = torch.split(out, [batch_dim, batch_dim], dim=0) + + return out diff --git a/ptlflow/models/gmflownet/update.py b/ptlflow/models/gmflownet/update.py new file mode 100644 index 0000000..fffee8c --- /dev/null +++ b/ptlflow/models/gmflownet/update.py @@ -0,0 +1,179 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .gma import Aggregate + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + if hasattr(args, 'motion_feat_indim'): + cor_planes = args.motion_feat_indim + else: + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+input_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + hidden_dim2 = (hidden_dim // 256)*128 + 256 + + self.mask = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim2, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim2, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + +class GMAUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128): + super().__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=self.args.num_heads) + + def forward(self, net, inp, corr, flow, attention): + motion_features = self.encoder(flow, corr) + motion_features_global = self.aggregator(attention, motion_features) + inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) + + # Attentional update + net = self.gru(net, inp_cat) + + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow diff --git a/ptlflow/models/gmflownet/utils/__init__.py b/ptlflow/models/gmflownet/utils/__init__.py new file mode 100644 index 0000000..5e4165f --- /dev/null +++ b/ptlflow/models/gmflownet/utils/__init__.py @@ -0,0 +1,4 @@ +# functions from timm +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible +from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ \ No newline at end of file diff --git a/ptlflow/models/gmflownet/utils/augmentor.py b/ptlflow/models/gmflownet/utils/augmentor.py new file mode 100644 index 0000000..e81c4f2 --- /dev/null +++ b/ptlflow/models/gmflownet/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/ptlflow/models/gmflownet/utils/drop.py b/ptlflow/models/gmflownet/utils/drop.py new file mode 100644 index 0000000..a43c865 --- /dev/null +++ b/ptlflow/models/gmflownet/utils/drop.py @@ -0,0 +1,166 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def drop_block_2d( + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + # Forces the block to be inside the feature map. + w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + block_mask = torch.empty_like(x).bernoulli_(gamma) + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.empty_like(x).normal_() + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. - block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ + + def __init__( + self, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, + fast: bool = True): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) \ No newline at end of file diff --git a/ptlflow/models/gmflownet/utils/flow_viz.py b/ptlflow/models/gmflownet/utils/flow_viz.py new file mode 100644 index 0000000..dcee65e --- /dev/null +++ b/ptlflow/models/gmflownet/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/ptlflow/models/gmflownet/utils/frame_utils.py b/ptlflow/models/gmflownet/utils/frame_utils.py new file mode 100644 index 0000000..6c49113 --- /dev/null +++ b/ptlflow/models/gmflownet/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/ptlflow/models/gmflownet/utils/helpers.py b/ptlflow/models/gmflownet/utils/helpers.py new file mode 100644 index 0000000..64573ef --- /dev/null +++ b/ptlflow/models/gmflownet/utils/helpers.py @@ -0,0 +1,31 @@ +""" Layer/Module Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +from itertools import repeat +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < round_limit * v: + new_v += divisor + return new_v \ No newline at end of file diff --git a/ptlflow/models/gmflownet/utils/utils.py b/ptlflow/models/gmflownet/utils/utils.py new file mode 100644 index 0000000..5f32d28 --- /dev/null +++ b/ptlflow/models/gmflownet/utils/utils.py @@ -0,0 +1,82 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/ptlflow/models/gmflownet/utils/weight_init.py b/ptlflow/models/gmflownet/utils/weight_init.py new file mode 100644 index 0000000..4626747 --- /dev/null +++ b/ptlflow/models/gmflownet/utils/weight_init.py @@ -0,0 +1,89 @@ +import torch +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') \ No newline at end of file diff --git a/ptlflow/models/raft/raft.py b/ptlflow/models/raft/raft.py index c4c97d5..ffa9195 100644 --- a/ptlflow/models/raft/raft.py +++ b/ptlflow/models/raft/raft.py @@ -1,7 +1,5 @@ from argparse import ArgumentParser, Namespace -from pathlib import Path -import pytorch_lightning as pl import torch import torch.nn as nn import torch.nn.functional as F @@ -57,8 +55,6 @@ def __init__(self, self.hidden_dim = hdim = 128 self.context_dim = cdim = 128 - args.corr_levels = 4 - args.corr_radius = 4 if 'dropout' not in self.args: self.args.dropout = 0 @@ -116,7 +112,6 @@ def upsample_flow(self, flow, mask): up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) return up_flow.reshape(N, 2, 8*H, 8*W) - def forward(self, inputs, flow_init=None): """ Estimate optical flow between pair of frames """ image1 = inputs['images'][:, 0] diff --git a/ptlflow/models/scv/knn.py b/ptlflow/models/scv/knn.py index 516cc4e..d76a8ac 100644 --- a/ptlflow/models/scv/knn.py +++ b/ptlflow/models/scv/knn.py @@ -1,16 +1,11 @@ try: import faiss + res = faiss.StandardGpuResources() + res.setDefaultNullStreamAllDevices() except ImportError: - raise ImportError( - 'ERROR: faiss not found.' - ' CSV requires faiss library to run.' - ' Install with pip install faiss-gpu' - ) + faiss = None import torch -res = faiss.StandardGpuResources() -res.setDefaultNullStreamAllDevices() - def swig_ptr_from_Tensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ @@ -26,7 +21,7 @@ def swig_ptr_from_Tensor(x): def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, - metric=faiss.METRIC_L2): + metric=faiss.METRIC_L2 if faiss is not None else 0): """search xq in xb, without building an index""" assert xb.device == xq.device diff --git a/ptlflow/models/scv/scv.py b/ptlflow/models/scv/scv.py index a801907..c1d6759 100644 --- a/ptlflow/models/scv/scv.py +++ b/ptlflow/models/scv/scv.py @@ -95,7 +95,32 @@ def forward(self, x): return self.flowpredictor(x) -class SCVQuarter(BaseModel): +class SCVBase(BaseModel): + def __init__(self, + args: Namespace) -> None: + super().__init__( + args=args, + loss_fn=SequenceLoss(args), + output_stride=8) + + try: + import torch_scatter + except ImportError: + raise ImportError( + 'ERROR: torch_scatter not found.' + ' SCV requires torch_scatter library to run.' + ' Check instructions at: https://github.com/rusty1s/pytorch_scatter' + ) + try: + import faiss + except ImportError: + raise ImportError( + 'ERROR: faiss not found.' + ' CSV requires faiss library to run.' + ' Install with pip install faiss-gpu' + ) + +class SCVQuarter(SCVBase): pretrained_checkpoints = { 'chairs': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/scv-quarter-chairs-4726627e.ckpt', 'kitti': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/scv-quarter-kitti-e86c7953.ckpt', @@ -105,10 +130,7 @@ class SCVQuarter(BaseModel): def __init__(self, args: Namespace) -> None: - super().__init__( - args=args, - loss_fn=SequenceLoss(args), - output_stride=8) + super().__init__(args=args) # feature network, context network, and update block self.fnet = BasicEncoderQuarter(output_dim=256, norm_fn='instance', dropout=False) @@ -257,7 +279,7 @@ def forward(self, inputs, flow_init=None): return outputs -class SCVEighth(BaseModel): +class SCVEighth(SCVBase): pretrained_checkpoints = { 'chairs': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/scv-eighth-chairs-8ba57294.ckpt', 'things': 'https://github.com/hmorimitsu/ptlflow/releases/download/weights1/scv-eighth-things-9c893323.ckpt' @@ -265,10 +287,7 @@ class SCVEighth(BaseModel): def __init__(self, args: Namespace) -> None: - super().__init__( - args=args, - loss_fn=SequenceLoss(args), - output_stride=8) + super().__init__(args=args) # feature network, context network, and update block self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=False) diff --git a/ptlflow/models/scv/utils.py b/ptlflow/models/scv/utils.py index 90cf029..660cda5 100644 --- a/ptlflow/models/scv/utils.py +++ b/ptlflow/models/scv/utils.py @@ -5,11 +5,7 @@ try: from torch_scatter import scatter_softmax, scatter_add except ImportError: - raise ImportError( - 'ERROR: torch_scatter not found.' - ' CSV requires torch_scatter library to run.' - ' Check instructions at: https://github.com/rusty1s/pytorch_scatter' - ) + pass class InputPadder: diff --git a/ptlflow/utils/flow_metrics.py b/ptlflow/utils/flow_metrics.py index b1b6cb3..65d4575 100644 --- a/ptlflow/utils/flow_metrics.py +++ b/ptlflow/utils/flow_metrics.py @@ -39,6 +39,8 @@ class FlowMetrics(Metric): A prefix string that will be attached to the metric names. """ + full_state_update = False + def __init__( self, dist_sync_on_step: bool = False, diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 2e26fd8..0000000 --- a/requirements.txt +++ /dev/null @@ -1,16 +0,0 @@ -einops >= 0.3.0 -numpy >= 1.17.0 -opencv-python >= 4.0.0.21 -packaging >= 20.0 -pandas >= 1.1.0 -pillow >= 5.0 -plotly >= 5.0.0 -pypng >= 0.0.16 -pytorch-lightning >= 1.1.0 -requests >= 2.0.0 -scipy >= 1.0.0 -tabulate >= 0.8.3 -torch >= 1.7.0 -torchmetrics >= 0.2 -torchvision >= 0.8.0 -tqdm >= 4.41.0 \ No newline at end of file diff --git a/requirements_flake8.txt b/requirements_flake8.txt index 49ad9b1..f0ab1cb 100644 --- a/requirements_flake8.txt +++ b/requirements_flake8.txt @@ -4,6 +4,5 @@ flake8-blind-except flake8-bugbear flake8-comprehensions flake8-docstrings -flake8-eradicate flake8-rst-docstrings pep8-naming \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 97946a7..5da582a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,23 +32,24 @@ classifiers = packages = find: include_package_data = True install_requires = - einops >= 0.3.0 - numpy >= 1.17.0 - opencv-python >= 4.0.0.21 - packaging >= 20.0 - pandas >= 1.1.0 - pillow >= 5.0 - plotly >= 5.0.0 - pypng >= 0.0.16 - pytorch-lightning >= 1.1.0 - requests >= 2.0.0 - scipy >= 1.0.0 - tabulate >= 0.8.3 - torch >= 1.7.0 - torchmetrics >= 0.2 - torchvision >= 0.8.0 - tqdm >= 4.41.0 -python_requires = >=3.6 + einops>=0.3.0,<=0.4.* + numpy>=1.17.0,<=1.22.* + opencv-python>=4.0.0.21,<=4.6.* + packaging>=20.0,<=21.* + pandas>=1.1.0,<=1.4.* + pillow>=5.0,<=9.2.* + plotly>=5.0.0,<=5.9.* + pypng~=0.0.16 + pytorch-lightning>=1.1.0,<=1.6.*,!=1.3.*,!=1.4.* + requests>=2.0.0,<=2.28.* + scipy>=1.0.0,<=1.9.* + tabulate~=0.8.3 + timm~=0.6.3 + torch>=1.8.1,<=1.12.* + torchmetrics>=0.2,<=0.9.* + torchvision>=0.9.2,<=0.13.* + tqdm>=4.41.0,<=4.64.* +python_requires = >=3.8 [options.packages.find] exclude = diff --git a/speed_benchmark.py b/speed_benchmark.py index 8f7bc6a..90396eb 100644 --- a/speed_benchmark.py +++ b/speed_benchmark.py @@ -17,6 +17,7 @@ # ============================================================================= import argparse +import logging from pathlib import Path from typing import Union @@ -29,10 +30,12 @@ import ptlflow from ptlflow.models.base_model.base_model import BaseModel from ptlflow.utils.timer import Timer -from ptlflow.utils.utils import count_parameters, get_list_of_available_models_list, make_divisible +from ptlflow.utils.utils import config_logging, count_parameters, get_list_of_available_models_list, make_divisible TABLE_COLS = ['Model', 'Params', 'Time(ms)'] +config_logging() + def _init_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() @@ -83,17 +86,21 @@ def benchmark( else: model_names = [args.model] for mname in tqdm(model_names): - model = ptlflow.get_model(mname) - model = model.eval() - if torch.cuda.is_available(): - model = model.cuda() - model_params = count_parameters(model) - infer_timer = estimate_inference_time(args, model) - values = [mname, model_params, infer_timer*1000] - df = df.append({c: v for c, v in zip(df.columns, values)}, ignore_index=True) - df = df.round(3) - df.to_csv(output_path / f'speed_benchmark-{args.model}.csv', index=False) - save_plot(output_path, args.model, df) + try: + model = ptlflow.get_model(mname) + model = model.eval() + if torch.cuda.is_available(): + model = model.cuda() + model_params = count_parameters(model) + infer_timer = estimate_inference_time(args, model) + values = [mname, model_params, infer_timer*1000] + new_df = pd.DataFrame({c: [v] for c, v in zip(df.columns, values)}) + df = pd.concat([df, new_df], ignore_index=True) + df = df.round(3) + df.to_csv(output_path / f'speed_benchmark-{args.model}.csv', index=False) + save_plot(output_path, args.model, df) + except Exception as e: # noqa: B902 + logging.warning('Skipping model %s due to exception %s', mname, e) return df diff --git a/tests/ptlflow/models/test_checkpoints.py b/tests/ptlflow/models/test_checkpoints.py index 1d32b48..317f482 100644 --- a/tests/ptlflow/models/test_checkpoints.py +++ b/tests/ptlflow/models/test_checkpoints.py @@ -31,6 +31,30 @@ # Results at scale_factor=0.66 reference_accuracy = { + 'craft_things_flyingchairs': 0.720, + 'craft_things_flyingthings3d': 3.317, + 'craft_things_kitti': 7.534, + 'craft_things_sintel': 0.186, + 'craft_sintel_flyingchairs': 0.680, + 'craft_sintel_flyingthings3d': 4.002, + 'craft_sintel_kitti': 1.863, + 'craft_sintel_sintel': 0.143, + 'craft_kitti_flyingchairs': 1.713, + 'craft_kitti_flyingthings3d': 19.485, + 'craft_kitti_kitti': 0.921, + 'craft_kitti_sintel': 0.238, + 'csflow_chairs_flyingchairs': 0.587, + 'csflow_chairs_flyingthings3d': 8.454, + 'csflow_chairs_kitti': 12.062, + 'csflow_chairs_sintel': 0.223, + 'csflow_things_flyingchairs': 0.843, + 'csflow_things_flyingthings3d': 3.620, + 'csflow_things_kitti': 4.639, + 'csflow_things_sintel': 0.185, + 'csflow_kitti_flyingchairs': 0.920, + 'csflow_kitti_flyingthings3d': 6.655, + 'csflow_kitti_kitti': 1.205, + 'csflow_kitti_sintel': 0.168, 'dicl_chairs_flyingchairs': 0.675, 'dicl_chairs_flyingthings3d': 20.257, 'dicl_chairs_kitti': 24.210, @@ -67,6 +91,22 @@ 'fastflownet_things_flyingthings3d': 20.497, 'fastflownet_things_kitti': 12.205, 'fastflownet_things_sintel': 0.434, + 'flowformer_chairs_flyingchairs': 0.558, + 'flowformer_chairs_flyingthings3d': 11.260, + 'flowformer_chairs_kitti': 13.684, + 'flowformer_chairs_sintel': 0.261, + 'flowformer_things_flyingchairs': 0.697, + 'flowformer_things_flyingthings3d': 3.104, + 'flowformer_things_kitti': 6.169, + 'flowformer_things_sintel': 0.245, + 'flowformer_sintel_flyingchairs': 0.650, + 'flowformer_sintel_flyingthings3d': 2.960, + 'flowformer_sintel_kitti': 2.377, + 'flowformer_sintel_sintel': 0.175, + 'flowformer_kitti_flyingchairs': 1.605, + 'flowformer_kitti_flyingthings3d': 17.919, + 'flowformer_kitti_kitti': 1.888, + 'flowformer_kitti_sintel': 0.325, 'flownet2_things_flyingchairs': 1.986, 'flownet2_things_flyingthings3d': 10.010, 'flownet2_things_kitti': 16.391, @@ -139,6 +179,54 @@ 'gma_kitti_flyingthings3d': 18.008, 'gma_kitti_kitti': 0.987, 'gma_kitti_sintel': 0.286, + 'gmflow_chairs_flyingchairs': 0.946, + 'gmflow_chairs_flyingthings3d': 8.914, + 'gmflow_chairs_kitti': 12.958, + 'gmflow_chairs_sintel': 0.803, + 'gmflow_things_flyingchairs': 0.939, + 'gmflow_things_flyingthings3d': 3.517, + 'gmflow_things_kitti': 11.070, + 'gmflow_things_sintel': 0.226, + 'gmflow_sintel_flyingchairs': 1.063, + 'gmflow_sintel_flyingthings3d': 3.824, + 'gmflow_sintel_kitti': 3.283, + 'gmflow_sintel_sintel': 0.286, + 'gmflow_kitti_flyingchairs': 2.058, + 'gmflow_kitti_flyingthings3d': 14.674, + 'gmflow_kitti_kitti': 1.801, + 'gmflow_kitti_sintel': 0.652, + 'gmflow_refine_chairs_flyingchairs': 1.012, + 'gmflow_refine_chairs_flyingthings3d': 8.609, + 'gmflow_refine_chairs_kitti': 12.410, + 'gmflow_refine_chairs_sintel': 0.997, + 'gmflow_refine_things_flyingchairs': 0.922, + 'gmflow_refine_things_flyingthings3d': 6.392, + 'gmflow_refine_things_kitti': 9.507, + 'gmflow_refine_things_sintel': 0.383, + 'gmflow_refine_sintel_flyingchairs': 1.070, + 'gmflow_refine_sintel_flyingthings3d': 6.533, + 'gmflow_refine_sintel_kitti': 5.099, + 'gmflow_refine_sintel_sintel': 0.298, + 'gmflow_refine_kitti_flyingchairs': 1.774, + 'gmflow_refine_kitti_flyingthings3d': 11.691, + 'gmflow_refine_kitti_kitti': 2.900, + 'gmflow_refine_kitti_sintel': 0.415, + 'gmflownet_things_flyingchairs': 0.693, + 'gmflownet_things_flyingthings3d': 3.020, + 'gmflownet_things_kitti': 6.168, + 'gmflownet_things_sintel': 0.197, + 'gmflownet_kitti_flyingchairs': 2.343, + 'gmflownet_kitti_flyingthings3d': 18.116, + 'gmflownet_kitti_kitti': 1.040, + 'gmflownet_kitti_sintel': 0.348, + 'gmflownet_mix_things_flyingchairs': 0.976, + 'gmflownet_mix_things_flyingthings3d': 4.514, + 'gmflownet_mix_things_kitti': 8.019, + 'gmflownet_mix_things_sintel': 0.204, + 'gmflownet_mix_sintel_flyingchairs': 2.053, + 'gmflownet_mix_sintel_flyingthings3d': 7.421, + 'gmflownet_mix_sintel_kitti': 5.687, + 'gmflownet_mix_sintel_sintel': 0.178, 'irr_pwc_chairs_occ_flyingchairs': 0.909, 'irr_pwc_chairs_occ_flyingthings3d': 10.531, 'irr_pwc_chairs_occ_kitti': 9.929, diff --git a/tests/ptlflow/models/test_models.py b/tests/ptlflow/models/test_models.py index 20b33e8..a36d964 100644 --- a/tests/ptlflow/models/test_models.py +++ b/tests/ptlflow/models/test_models.py @@ -45,19 +45,23 @@ def test_forward() -> None: if mname in EXCLUDE_MODELS: continue - model = ptlflow.get_model(mname) - model = model.eval() + try: + model = ptlflow.get_model(mname) + model = model.eval() - s = make_divisible(400, model.output_stride) - inputs = {'images': torch.rand(1, 2, 3, s, s)} + s = make_divisible(128, model.output_stride) + inputs = {'images': torch.rand(1, 2, 3, s, s)} - if torch.cuda.is_available(): - model = model.cuda() - inputs['images'] = inputs['images'].cuda() + if torch.cuda.is_available(): + model = model.cuda() + inputs['images'] = inputs['images'].cuda() - model(inputs) + model(inputs) + except (ImportError, RuntimeError): + continue +@pytest.mark.skip(reason='Requires too many resources. Use only on machines with large GPUs.') def test_train(tmp_path: Path): write_flying_chairs2(tmp_path) diff --git a/tests/test_train.py b/tests/test_train.py index 0006f78..ba011c1 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -62,13 +62,19 @@ def test_train(tmp_path: Path) -> None: train.train(args) - log_dirs = Path('default/version_0') - - assert (tmp_path / f'{TEST_MODEL}-{TRAIN_LOG_SUFFIX}' / log_dirs / 'hparams.yaml').exists() - - ckpt_last = list((tmp_path / f'{TEST_MODEL}-{TRAIN_LOG_SUFFIX}' / log_dirs / 'checkpoints').glob('*_last_*.ckpt')) - assert len(ckpt_last) > 0 - ckpt_train = list((tmp_path / f'{TEST_MODEL}-{TRAIN_LOG_SUFFIX}' / log_dirs / 'checkpoints').glob('*_train_*.ckpt')) - assert len(ckpt_train) > 0 + dir_names = ['default', 'lightning_logs'] # Name changes depending on PL version + hparams_res = [] + last_res = [] + train_res = [] + for dname in dir_names: + log_dirs = Path(f'{dname}/version_0') + + hparams_res.append((tmp_path / f'{TEST_MODEL}-{TRAIN_LOG_SUFFIX}' / log_dirs / 'hparams.yaml').exists()) + last_res.append(len(list((tmp_path / f'{TEST_MODEL}-{TRAIN_LOG_SUFFIX}' / log_dirs / 'checkpoints').glob('*_last_*.ckpt')))) + train_res.append(len(list((tmp_path / f'{TEST_MODEL}-{TRAIN_LOG_SUFFIX}' / log_dirs / 'checkpoints').glob('*_train_*.ckpt')))) + + assert max(hparams_res) is True + assert max(last_res) > 0 + assert max(train_res) > 0 shutil.rmtree(tmp_path) diff --git a/validate.py b/validate.py index d3343db..bbad89e 100644 --- a/validate.py +++ b/validate.py @@ -73,6 +73,9 @@ def _init_parser() -> ArgumentParser: parser.add_argument( '--max_samples', type=int, default=None, help=('Maximum number of samples per dataset will be used for calculating the metrics.')) + parser.add_argument( + '--reversed', action='store_true', + help='To be combined with model all or select. Iterates over the list of models in reversed order') return parser @@ -159,6 +162,55 @@ def validate( return metrics_df +def validate_list_of_models( + args: Namespace +) -> None: + """Perform the validation. + + Parameters + ---------- + args : Namespace + Arguments to configure the list of models and the validation. + """ + metrics_df = pd.DataFrame() + + model_names = _get_model_names(args) + if args.reversed: + model_names = reversed(model_names) + + for mname in model_names: + logging.info(mname) + model_ref = ptlflow.get_model_reference(mname) + + if hasattr(model_ref, 'pretrained_checkpoints'): + ckpt_names = model_ref.pretrained_checkpoints.keys() + for cname in ckpt_names: + try: + logging.info(cname) + parser_tmp = model_ref.add_model_specific_args(parser) + args = parser_tmp.parse_args() + + args.model = mname + args.pretrained_ckpt = cname + + model_id = args.model + if args.pretrained_ckpt is not None: + model_id += f'_{args.pretrained_ckpt}' + args.output_path = Path(args.output_path) / model_id + + model = get_model(mname, cname, args) + instance_metrics_df = validate(args, model) + metrics_df = pd.concat([metrics_df, instance_metrics_df]) + args.output_path.parent.mkdir(parents=True, exist_ok=True) + if args.reversed: + metrics_df.to_csv(args.output_path.parent / 'metrics_all_rev.csv', index=False) + else: + metrics_df.to_csv(args.output_path.parent / 'metrics_all.csv', index=False) + except Exception as e: # noqa: B902 + logging.warning('Skipping model %s due to exception %s', mname, e) + break + + @torch.no_grad() def validate_one_dataloader( args: Namespace, @@ -311,34 +363,6 @@ def _write_to_file( model = get_model(sys.argv[1], args.pretrained_ckpt, args) args.output_path.mkdir(parents=True, exist_ok=True) - metrics_df = validate(args, model) + validate(args, model) else: - # Run validation on all models and checkpoints - metrics_df = pd.DataFrame() - - model_names = _get_model_names(args) - - for mname in model_names: - logging.info(mname) - model_ref = ptlflow.get_model_reference(mname) - - if hasattr(model_ref, 'pretrained_checkpoints'): - ckpt_names = model_ref.pretrained_checkpoints.keys() - for cname in ckpt_names: - logging.info(cname) - parser_tmp = model_ref.add_model_specific_args(parser) - args = parser_tmp.parse_args() - - args.model = mname - args.pretrained_ckpt = cname - - model_id = args.model - if args.pretrained_ckpt is not None: - model_id += f'_{args.pretrained_ckpt}' - args.output_path = Path(args.output_path) / model_id - - model = get_model(mname, cname, args) - instance_metrics_df = validate(args, model) - metrics_df = pd.concat([metrics_df, instance_metrics_df]) - args.output_path.parent.mkdir(parents=True, exist_ok=True) - metrics_df.to_csv(args.output_path.parent / 'metrics_all.csv', index=False) + validate_list_of_models(args)